diff --git a/.venv/lib/python3.13/site-packages/sympy/calculus/__init__.py b/.venv/lib/python3.13/site-packages/sympy/calculus/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..865c0556769e35ca7e8a09a4c208a1cfbd17369c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/calculus/__init__.py @@ -0,0 +1,25 @@ +"""Calculus-related methods.""" + +from .euler import euler_equations +from .singularities import (singularities, is_increasing, + is_strictly_increasing, is_decreasing, + is_strictly_decreasing, is_monotonic) +from .finite_diff import finite_diff_weights, apply_finite_diff, differentiate_finite +from .util import (periodicity, not_empty_in, is_convex, + stationary_points, minimum, maximum) +from .accumulationbounds import AccumBounds + +__all__ = [ +'euler_equations', + +'singularities', 'is_increasing', +'is_strictly_increasing', 'is_decreasing', +'is_strictly_decreasing', 'is_monotonic', + +'finite_diff_weights', 'apply_finite_diff', 'differentiate_finite', + +'periodicity', 'not_empty_in', 'is_convex', 'stationary_points', +'minimum', 'maximum', + +'AccumBounds' +] diff --git a/.venv/lib/python3.13/site-packages/sympy/calculus/accumulationbounds.py b/.venv/lib/python3.13/site-packages/sympy/calculus/accumulationbounds.py new file mode 100644 index 0000000000000000000000000000000000000000..8a2b0e634da1028c25399c8b03884865d981a56e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/calculus/accumulationbounds.py @@ -0,0 +1,804 @@ +from sympy.core import Add, Mul, Pow, S +from sympy.core.basic import Basic +from sympy.core.expr import Expr +from sympy.core.numbers import _sympifyit, oo, zoo +from sympy.core.relational import is_le, is_lt, is_ge, is_gt +from sympy.core.sympify import _sympify +from sympy.functions.elementary.miscellaneous import Min, Max +from sympy.logic.boolalg import And +from sympy.multipledispatch import dispatch +from sympy.series.order import Order +from sympy.sets.sets import FiniteSet + + +class AccumulationBounds(Expr): + r"""An accumulation bounds. + + # Note AccumulationBounds has an alias: AccumBounds + + AccumulationBounds represent an interval `[a, b]`, which is always closed + at the ends. Here `a` and `b` can be any value from extended real numbers. + + The intended meaning of AccummulationBounds is to give an approximate + location of the accumulation points of a real function at a limit point. + + Let `a` and `b` be reals such that `a \le b`. + + `\left\langle a, b\right\rangle = \{x \in \mathbb{R} \mid a \le x \le b\}` + + `\left\langle -\infty, b\right\rangle = \{x \in \mathbb{R} \mid x \le b\} \cup \{-\infty, \infty\}` + + `\left\langle a, \infty \right\rangle = \{x \in \mathbb{R} \mid a \le x\} \cup \{-\infty, \infty\}` + + `\left\langle -\infty, \infty \right\rangle = \mathbb{R} \cup \{-\infty, \infty\}` + + ``oo`` and ``-oo`` are added to the second and third definition respectively, + since if either ``-oo`` or ``oo`` is an argument, then the other one should + be included (though not as an end point). This is forced, since we have, + for example, ``1/AccumBounds(0, 1) = AccumBounds(1, oo)``, and the limit at + `0` is not one-sided. As `x` tends to `0-`, then `1/x \rightarrow -\infty`, so `-\infty` + should be interpreted as belonging to ``AccumBounds(1, oo)`` though it need + not appear explicitly. + + In many cases it suffices to know that the limit set is bounded. + However, in some other cases more exact information could be useful. + For example, all accumulation values of `\cos(x) + 1` are non-negative. + (``AccumBounds(-1, 1) + 1 = AccumBounds(0, 2)``) + + A AccumulationBounds object is defined to be real AccumulationBounds, + if its end points are finite reals. + + Let `X`, `Y` be real AccumulationBounds, then their sum, difference, + product are defined to be the following sets: + + `X + Y = \{ x+y \mid x \in X \cap y \in Y\}` + + `X - Y = \{ x-y \mid x \in X \cap y \in Y\}` + + `X \times Y = \{ x \times y \mid x \in X \cap y \in Y\}` + + When an AccumBounds is raised to a negative power, if 0 is contained + between the bounds then an infinite range is returned, otherwise if an + endpoint is 0 then a semi-infinite range with consistent sign will be returned. + + AccumBounds in expressions behave a lot like Intervals but the + semantics are not necessarily the same. Division (or exponentiation + to a negative integer power) could be handled with *intervals* by + returning a union of the results obtained after splitting the + bounds between negatives and positives, but that is not done with + AccumBounds. In addition, bounds are assumed to be independent of + each other; if the same bound is used in more than one place in an + expression, the result may not be the supremum or infimum of the + expression (see below). Finally, when a boundary is ``1``, + exponentiation to the power of ``oo`` yields ``oo``, neither + ``1`` nor ``nan``. + + Examples + ======== + + >>> from sympy import AccumBounds, sin, exp, log, pi, E, S, oo + >>> from sympy.abc import x + + >>> AccumBounds(0, 1) + AccumBounds(1, 2) + AccumBounds(1, 3) + + >>> AccumBounds(0, 1) - AccumBounds(0, 2) + AccumBounds(-2, 1) + + >>> AccumBounds(-2, 3)*AccumBounds(-1, 1) + AccumBounds(-3, 3) + + >>> AccumBounds(1, 2)*AccumBounds(3, 5) + AccumBounds(3, 10) + + The exponentiation of AccumulationBounds is defined + as follows: + + If 0 does not belong to `X` or `n > 0` then + + `X^n = \{ x^n \mid x \in X\}` + + >>> AccumBounds(1, 4)**(S(1)/2) + AccumBounds(1, 2) + + otherwise, an infinite or semi-infinite result is obtained: + + >>> 1/AccumBounds(-1, 1) + AccumBounds(-oo, oo) + >>> 1/AccumBounds(0, 2) + AccumBounds(1/2, oo) + >>> 1/AccumBounds(-oo, 0) + AccumBounds(-oo, 0) + + A boundary of 1 will always generate all nonnegatives: + + >>> AccumBounds(1, 2)**oo + AccumBounds(0, oo) + >>> AccumBounds(0, 1)**oo + AccumBounds(0, oo) + + If the exponent is itself an AccumulationBounds or is not an + integer then unevaluated results will be returned unless the base + values are positive: + + >>> AccumBounds(2, 3)**AccumBounds(-1, 2) + AccumBounds(1/3, 9) + >>> AccumBounds(-2, 3)**AccumBounds(-1, 2) + AccumBounds(-2, 3)**AccumBounds(-1, 2) + + >>> AccumBounds(-2, -1)**(S(1)/2) + sqrt(AccumBounds(-2, -1)) + + Note: `\left\langle a, b\right\rangle^2` is not same as `\left\langle a, b\right\rangle \times \left\langle a, b\right\rangle` + + >>> AccumBounds(-1, 1)**2 + AccumBounds(0, 1) + + >>> AccumBounds(1, 3) < 4 + True + + >>> AccumBounds(1, 3) < -1 + False + + Some elementary functions can also take AccumulationBounds as input. + A function `f` evaluated for some real AccumulationBounds `\left\langle a, b \right\rangle` + is defined as `f(\left\langle a, b\right\rangle) = \{ f(x) \mid a \le x \le b \}` + + >>> sin(AccumBounds(pi/6, pi/3)) + AccumBounds(1/2, sqrt(3)/2) + + >>> exp(AccumBounds(0, 1)) + AccumBounds(1, E) + + >>> log(AccumBounds(1, E)) + AccumBounds(0, 1) + + Some symbol in an expression can be substituted for a AccumulationBounds + object. But it does not necessarily evaluate the AccumulationBounds for + that expression. + + The same expression can be evaluated to different values depending upon + the form it is used for substitution since each instance of an + AccumulationBounds is considered independent. For example: + + >>> (x**2 + 2*x + 1).subs(x, AccumBounds(-1, 1)) + AccumBounds(-1, 4) + + >>> ((x + 1)**2).subs(x, AccumBounds(-1, 1)) + AccumBounds(0, 4) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Interval_arithmetic + + .. [2] https://fab.cba.mit.edu/classes/S62.12/docs/Hickey_interval.pdf + + Notes + ===== + + Do not use ``AccumulationBounds`` for floating point interval arithmetic + calculations, use ``mpmath.iv`` instead. + """ + + is_extended_real = True + is_number = False + + def __new__(cls, min, max) -> Expr: # type: ignore + + min = _sympify(min) + max = _sympify(max) + + # Only allow real intervals (use symbols with 'is_extended_real=True'). + if not min.is_extended_real or not max.is_extended_real: + raise ValueError("Only real AccumulationBounds are supported") + + if max == min: + return max + + # Make sure that the created AccumBounds object will be valid. + if max.is_number and min.is_number: + bad = max.is_comparable and min.is_comparable and max < min + else: + bad = (max - min).is_extended_negative + if bad: + raise ValueError( + "Lower limit should be smaller than upper limit") + + return Basic.__new__(cls, min, max) + + # setting the operation priority + _op_priority = 11.0 + + def _eval_is_real(self): + if self.min.is_real and self.max.is_real: + return True + + @property + def min(self): + """ + Returns the minimum possible value attained by AccumulationBounds + object. + + Examples + ======== + + >>> from sympy import AccumBounds + >>> AccumBounds(1, 3).min + 1 + + """ + return self.args[0] + + @property + def max(self): + """ + Returns the maximum possible value attained by AccumulationBounds + object. + + Examples + ======== + + >>> from sympy import AccumBounds + >>> AccumBounds(1, 3).max + 3 + + """ + return self.args[1] + + @property + def delta(self): + """ + Returns the difference of maximum possible value attained by + AccumulationBounds object and minimum possible value attained + by AccumulationBounds object. + + Examples + ======== + + >>> from sympy import AccumBounds + >>> AccumBounds(1, 3).delta + 2 + + """ + return self.max - self.min + + @property + def mid(self): + """ + Returns the mean of maximum possible value attained by + AccumulationBounds object and minimum possible value + attained by AccumulationBounds object. + + Examples + ======== + + >>> from sympy import AccumBounds + >>> AccumBounds(1, 3).mid + 2 + + """ + return (self.min + self.max) / 2 + + @_sympifyit('other', NotImplemented) + def _eval_power(self, other): + return self.__pow__(other) + + @_sympifyit('other', NotImplemented) + def __add__(self, other): + if isinstance(other, Expr): + if isinstance(other, AccumBounds): + return AccumBounds( + Add(self.min, other.min), + Add(self.max, other.max)) + if other is S.Infinity and self.min is S.NegativeInfinity or \ + other is S.NegativeInfinity and self.max is S.Infinity: + return AccumBounds(-oo, oo) + elif other.is_extended_real: + if self.min is S.NegativeInfinity and self.max is S.Infinity: + return AccumBounds(-oo, oo) + elif self.min is S.NegativeInfinity: + return AccumBounds(-oo, self.max + other) + elif self.max is S.Infinity: + return AccumBounds(self.min + other, oo) + else: + return AccumBounds(Add(self.min, other), Add(self.max, other)) + return Add(self, other, evaluate=False) + return NotImplemented + + __radd__ = __add__ + + def __neg__(self): + return AccumBounds(-self.max, -self.min) + + @_sympifyit('other', NotImplemented) + def __sub__(self, other): + if isinstance(other, Expr): + if isinstance(other, AccumBounds): + return AccumBounds( + Add(self.min, -other.max), + Add(self.max, -other.min)) + if other is S.NegativeInfinity and self.min is S.NegativeInfinity or \ + other is S.Infinity and self.max is S.Infinity: + return AccumBounds(-oo, oo) + elif other.is_extended_real: + if self.min is S.NegativeInfinity and self.max is S.Infinity: + return AccumBounds(-oo, oo) + elif self.min is S.NegativeInfinity: + return AccumBounds(-oo, self.max - other) + elif self.max is S.Infinity: + return AccumBounds(self.min - other, oo) + else: + return AccumBounds( + Add(self.min, -other), + Add(self.max, -other)) + return Add(self, -other, evaluate=False) + return NotImplemented + + @_sympifyit('other', NotImplemented) + def __rsub__(self, other): + return self.__neg__() + other + + @_sympifyit('other', NotImplemented) + def __mul__(self, other): + if self.args == (-oo, oo): + return self + if isinstance(other, Expr): + if isinstance(other, AccumBounds): + if other.args == (-oo, oo): + return other + v = set() + for a in self.args: + vi = other*a + v.update(vi.args or (vi,)) + return AccumBounds(Min(*v), Max(*v)) + if other is S.Infinity: + if self.min.is_zero: + return AccumBounds(0, oo) + if self.max.is_zero: + return AccumBounds(-oo, 0) + if other is S.NegativeInfinity: + if self.min.is_zero: + return AccumBounds(-oo, 0) + if self.max.is_zero: + return AccumBounds(0, oo) + if other.is_extended_real: + if other.is_zero: + if self.max is S.Infinity: + return AccumBounds(0, oo) + if self.min is S.NegativeInfinity: + return AccumBounds(-oo, 0) + return S.Zero + if other.is_extended_positive: + return AccumBounds( + Mul(self.min, other), + Mul(self.max, other)) + elif other.is_extended_negative: + return AccumBounds( + Mul(self.max, other), + Mul(self.min, other)) + if isinstance(other, Order): + return other + return Mul(self, other, evaluate=False) + return NotImplemented + + __rmul__ = __mul__ + + @_sympifyit('other', NotImplemented) + def __truediv__(self, other): + if isinstance(other, Expr): + if isinstance(other, AccumBounds): + if other.min.is_positive or other.max.is_negative: + return self * AccumBounds(1/other.max, 1/other.min) + + if (self.min.is_extended_nonpositive and self.max.is_extended_nonnegative and + other.min.is_extended_nonpositive and other.max.is_extended_nonnegative): + if self.min.is_zero and other.min.is_zero: + return AccumBounds(0, oo) + if self.max.is_zero and other.min.is_zero: + return AccumBounds(-oo, 0) + return AccumBounds(-oo, oo) + + if self.max.is_extended_negative: + if other.min.is_extended_negative: + if other.max.is_zero: + return AccumBounds(self.max / other.min, oo) + if other.max.is_extended_positive: + # if we were dealing with intervals we would return + # Union(Interval(-oo, self.max/other.max), + # Interval(self.max/other.min, oo)) + return AccumBounds(-oo, oo) + + if other.min.is_zero and other.max.is_extended_positive: + return AccumBounds(-oo, self.max / other.max) + + if self.min.is_extended_positive: + if other.min.is_extended_negative: + if other.max.is_zero: + return AccumBounds(-oo, self.min / other.min) + if other.max.is_extended_positive: + # if we were dealing with intervals we would return + # Union(Interval(-oo, self.min/other.min), + # Interval(self.min/other.max, oo)) + return AccumBounds(-oo, oo) + + if other.min.is_zero and other.max.is_extended_positive: + return AccumBounds(self.min / other.max, oo) + + elif other.is_extended_real: + if other in (S.Infinity, S.NegativeInfinity): + if self == AccumBounds(-oo, oo): + return AccumBounds(-oo, oo) + if self.max is S.Infinity: + return AccumBounds(Min(0, other), Max(0, other)) + if self.min is S.NegativeInfinity: + return AccumBounds(Min(0, -other), Max(0, -other)) + if other.is_extended_positive: + return AccumBounds(self.min / other, self.max / other) + elif other.is_extended_negative: + return AccumBounds(self.max / other, self.min / other) + if (1 / other) is S.ComplexInfinity: + return Mul(self, 1 / other, evaluate=False) + else: + return Mul(self, 1 / other) + + return NotImplemented + + @_sympifyit('other', NotImplemented) + def __rtruediv__(self, other): + if isinstance(other, Expr): + if other.is_extended_real: + if other.is_zero: + return S.Zero + if (self.min.is_extended_nonpositive and self.max.is_extended_nonnegative): + if self.min.is_zero: + if other.is_extended_positive: + return AccumBounds(Mul(other, 1 / self.max), oo) + if other.is_extended_negative: + return AccumBounds(-oo, Mul(other, 1 / self.max)) + if self.max.is_zero: + if other.is_extended_positive: + return AccumBounds(-oo, Mul(other, 1 / self.min)) + if other.is_extended_negative: + return AccumBounds(Mul(other, 1 / self.min), oo) + return AccumBounds(-oo, oo) + else: + return AccumBounds(Min(other / self.min, other / self.max), + Max(other / self.min, other / self.max)) + return Mul(other, 1 / self, evaluate=False) + else: + return NotImplemented + + @_sympifyit('other', NotImplemented) + def __pow__(self, other): + if isinstance(other, Expr): + if other is S.Infinity: + if self.min.is_extended_nonnegative: + if self.max < 1: + return S.Zero + if self.min > 1: + return S.Infinity + return AccumBounds(0, oo) + elif self.max.is_extended_negative: + if self.min > -1: + return S.Zero + if self.max < -1: + return zoo + return S.NaN + else: + if self.min > -1: + if self.max < 1: + return S.Zero + return AccumBounds(0, oo) + return AccumBounds(-oo, oo) + + if other is S.NegativeInfinity: + return (1/self)**oo + + # generically true + if (self.max - self.min).is_nonnegative: + # well defined + if self.min.is_nonnegative: + # no 0 to worry about + if other.is_nonnegative: + # no infinity to worry about + return self.func(self.min**other, self.max**other) + + if other.is_zero: + return S.One # x**0 = 1 + + if other.is_Integer or other.is_integer: + if self.min.is_extended_positive: + return AccumBounds( + Min(self.min**other, self.max**other), + Max(self.min**other, self.max**other)) + elif self.max.is_extended_negative: + return AccumBounds( + Min(self.max**other, self.min**other), + Max(self.max**other, self.min**other)) + + if other % 2 == 0: + if other.is_extended_negative: + if self.min.is_zero: + return AccumBounds(self.max**other, oo) + if self.max.is_zero: + return AccumBounds(self.min**other, oo) + return (1/self)**(-other) + return AccumBounds( + S.Zero, Max(self.min**other, self.max**other)) + elif other % 2 == 1: + if other.is_extended_negative: + if self.min.is_zero: + return AccumBounds(self.max**other, oo) + if self.max.is_zero: + return AccumBounds(-oo, self.min**other) + return (1/self)**(-other) + return AccumBounds(self.min**other, self.max**other) + + # non-integer exponent + # 0**neg or neg**frac yields complex + if (other.is_number or other.is_rational) and ( + self.min.is_extended_nonnegative or ( + other.is_extended_nonnegative and + self.min.is_extended_nonnegative)): + num, den = other.as_numer_denom() + if num is S.One: + return AccumBounds(*[i**(1/den) for i in self.args]) + + elif den is not S.One: # e.g. if other is not Float + return (self**num)**(1/den) # ok for non-negative base + + if isinstance(other, AccumBounds): + if (self.min.is_extended_positive or + self.min.is_extended_nonnegative and + other.min.is_extended_nonnegative): + p = [self**i for i in other.args] + if not any(i.is_Pow for i in p): + a = [j for i in p for j in i.args or (i,)] + try: + return self.func(min(a), max(a)) + except TypeError: # can't sort + pass + + return Pow(self, other, evaluate=False) + + return NotImplemented + + @_sympifyit('other', NotImplemented) + def __rpow__(self, other): + if other.is_real and other.is_extended_nonnegative and ( + self.max - self.min).is_extended_positive: + if other is S.One: + return S.One + if other.is_extended_positive: + a, b = [other**i for i in self.args] + if min(a, b) != a: + a, b = b, a + return self.func(a, b) + if other.is_zero: + if self.min.is_zero: + return self.func(0, 1) + if self.min.is_extended_positive: + return S.Zero + + return Pow(other, self, evaluate=False) + + def __abs__(self): + if self.max.is_extended_negative: + return self.__neg__() + elif self.min.is_extended_negative: + return AccumBounds(S.Zero, Max(abs(self.min), self.max)) + else: + return self + + + def __contains__(self, other): + """ + Returns ``True`` if other is contained in self, where other + belongs to extended real numbers, ``False`` if not contained, + otherwise TypeError is raised. + + Examples + ======== + + >>> from sympy import AccumBounds, oo + >>> 1 in AccumBounds(-1, 3) + True + + -oo and oo go together as limits (in AccumulationBounds). + + >>> -oo in AccumBounds(1, oo) + True + + >>> oo in AccumBounds(-oo, 0) + True + + """ + other = _sympify(other) + + if other in (S.Infinity, S.NegativeInfinity): + if self.min is S.NegativeInfinity or self.max is S.Infinity: + return True + return False + + rv = And(self.min <= other, self.max >= other) + if rv not in (True, False): + raise TypeError("input failed to evaluate") + return rv + + def intersection(self, other): + """ + Returns the intersection of 'self' and 'other'. + Here other can be an instance of :py:class:`~.FiniteSet` or AccumulationBounds. + + Parameters + ========== + + other : AccumulationBounds + Another AccumulationBounds object with which the intersection + has to be computed. + + Returns + ======= + + AccumulationBounds + Intersection of ``self`` and ``other``. + + Examples + ======== + + >>> from sympy import AccumBounds, FiniteSet + >>> AccumBounds(1, 3).intersection(AccumBounds(2, 4)) + AccumBounds(2, 3) + + >>> AccumBounds(1, 3).intersection(AccumBounds(4, 6)) + EmptySet + + >>> AccumBounds(1, 4).intersection(FiniteSet(1, 2, 5)) + {1, 2} + + """ + if not isinstance(other, (AccumBounds, FiniteSet)): + raise TypeError( + "Input must be AccumulationBounds or FiniteSet object") + + if isinstance(other, FiniteSet): + fin_set = S.EmptySet + for i in other: + if i in self: + fin_set = fin_set + FiniteSet(i) + return fin_set + + if self.max < other.min or self.min > other.max: + return S.EmptySet + + if self.min <= other.min: + if self.max <= other.max: + return AccumBounds(other.min, self.max) + if self.max > other.max: + return other + + if other.min <= self.min: + if other.max < self.max: + return AccumBounds(self.min, other.max) + if other.max > self.max: + return self + + def union(self, other): + # TODO : Devise a better method for Union of AccumBounds + # this method is not actually correct and + # can be made better + if not isinstance(other, AccumBounds): + raise TypeError( + "Input must be AccumulationBounds or FiniteSet object") + + if self.min <= other.min and self.max >= other.min: + return AccumBounds(self.min, Max(self.max, other.max)) + + if other.min <= self.min and other.max >= self.min: + return AccumBounds(other.min, Max(self.max, other.max)) + + +@dispatch(AccumulationBounds, AccumulationBounds) # type: ignore # noqa:F811 +def _eval_is_le(lhs, rhs): # noqa:F811 + if is_le(lhs.max, rhs.min): + return True + if is_gt(lhs.min, rhs.max): + return False + + +@dispatch(AccumulationBounds, Basic) # type: ignore # noqa:F811 +def _eval_is_le(lhs, rhs): # noqa: F811 + + """ + Returns ``True `` if range of values attained by ``lhs`` AccumulationBounds + object is greater than the range of values attained by ``rhs``, + where ``rhs`` may be any value of type AccumulationBounds object or + extended real number value, ``False`` if ``rhs`` satisfies + the same property, else an unevaluated :py:class:`~.Relational`. + + Examples + ======== + + >>> from sympy import AccumBounds, oo + >>> AccumBounds(1, 3) > AccumBounds(4, oo) + False + >>> AccumBounds(1, 4) > AccumBounds(3, 4) + AccumBounds(1, 4) > AccumBounds(3, 4) + >>> AccumBounds(1, oo) > -1 + True + + """ + if not rhs.is_extended_real: + raise TypeError( + "Invalid comparison of %s %s" % + (type(rhs), rhs)) + elif rhs.is_comparable: + if is_le(lhs.max, rhs): + return True + if is_gt(lhs.min, rhs): + return False + + +@dispatch(AccumulationBounds, AccumulationBounds) +def _eval_is_ge(lhs, rhs): # noqa:F811 + if is_ge(lhs.min, rhs.max): + return True + if is_lt(lhs.max, rhs.min): + return False + + +@dispatch(AccumulationBounds, Expr) # type:ignore +def _eval_is_ge(lhs, rhs): # noqa: F811 + """ + Returns ``True`` if range of values attained by ``lhs`` AccumulationBounds + object is less that the range of values attained by ``rhs``, where + other may be any value of type AccumulationBounds object or extended + real number value, ``False`` if ``rhs`` satisfies the same + property, else an unevaluated :py:class:`~.Relational`. + + Examples + ======== + + >>> from sympy import AccumBounds, oo + >>> AccumBounds(1, 3) >= AccumBounds(4, oo) + False + >>> AccumBounds(1, 4) >= AccumBounds(3, 4) + AccumBounds(1, 4) >= AccumBounds(3, 4) + >>> AccumBounds(1, oo) >= 1 + True + """ + + if not rhs.is_extended_real: + raise TypeError( + "Invalid comparison of %s %s" % + (type(rhs), rhs)) + elif rhs.is_comparable: + if is_ge(lhs.min, rhs): + return True + if is_lt(lhs.max, rhs): + return False + + +@dispatch(Expr, AccumulationBounds) # type:ignore +def _eval_is_ge(lhs, rhs): # noqa:F811 + if not lhs.is_extended_real: + raise TypeError( + "Invalid comparison of %s %s" % + (type(lhs), lhs)) + elif lhs.is_comparable: + if is_le(rhs.max, lhs): + return True + if is_gt(rhs.min, lhs): + return False + + +@dispatch(AccumulationBounds, AccumulationBounds) # type:ignore +def _eval_is_ge(lhs, rhs): # noqa:F811 + if is_ge(lhs.min, rhs.max): + return True + if is_lt(lhs.max, rhs.min): + return False + +# setting an alias for AccumulationBounds +AccumBounds = AccumulationBounds diff --git a/.venv/lib/python3.13/site-packages/sympy/calculus/euler.py b/.venv/lib/python3.13/site-packages/sympy/calculus/euler.py new file mode 100644 index 0000000000000000000000000000000000000000..817acf76091dfba2dba40487ca7735e307c0fc15 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/calculus/euler.py @@ -0,0 +1,108 @@ +""" +This module implements a method to find +Euler-Lagrange Equations for given Lagrangian. +""" +from itertools import combinations_with_replacement +from sympy.core.function import (Derivative, Function, diff) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.utilities.iterables import iterable + + +def euler_equations(L, funcs=(), vars=()): + r""" + Find the Euler-Lagrange equations [1]_ for a given Lagrangian. + + Parameters + ========== + + L : Expr + The Lagrangian that should be a function of the functions listed + in the second argument and their derivatives. + + For example, in the case of two functions $f(x,y)$, $g(x,y)$ and + two independent variables $x$, $y$ the Lagrangian has the form: + + .. math:: L\left(f(x,y),g(x,y),\frac{\partial f(x,y)}{\partial x}, + \frac{\partial f(x,y)}{\partial y}, + \frac{\partial g(x,y)}{\partial x}, + \frac{\partial g(x,y)}{\partial y},x,y\right) + + In many cases it is not necessary to provide anything, except the + Lagrangian, it will be auto-detected (and an error raised if this + cannot be done). + + funcs : Function or an iterable of Functions + The functions that the Lagrangian depends on. The Euler equations + are differential equations for each of these functions. + + vars : Symbol or an iterable of Symbols + The Symbols that are the independent variables of the functions. + + Returns + ======= + + eqns : list of Eq + The list of differential equations, one for each function. + + Examples + ======== + + >>> from sympy import euler_equations, Symbol, Function + >>> x = Function('x') + >>> t = Symbol('t') + >>> L = (x(t).diff(t))**2/2 - x(t)**2/2 + >>> euler_equations(L, x(t), t) + [Eq(-x(t) - Derivative(x(t), (t, 2)), 0)] + >>> u = Function('u') + >>> x = Symbol('x') + >>> L = (u(t, x).diff(t))**2/2 - (u(t, x).diff(x))**2/2 + >>> euler_equations(L, u(t, x), [t, x]) + [Eq(-Derivative(u(t, x), (t, 2)) + Derivative(u(t, x), (x, 2)), 0)] + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Euler%E2%80%93Lagrange_equation + + """ + + funcs = tuple(funcs) if iterable(funcs) else (funcs,) + + if not funcs: + funcs = tuple(L.atoms(Function)) + else: + for f in funcs: + if not isinstance(f, Function): + raise TypeError('Function expected, got: %s' % f) + + vars = tuple(vars) if iterable(vars) else (vars,) + + if not vars: + vars = funcs[0].args + else: + vars = tuple(sympify(var) for var in vars) + + if not all(isinstance(v, Symbol) for v in vars): + raise TypeError('Variables are not symbols, got %s' % vars) + + for f in funcs: + if not vars == f.args: + raise ValueError("Variables %s do not match args: %s" % (vars, f)) + + order = max([len(d.variables) for d in L.atoms(Derivative) + if d.expr in funcs] + [0]) + + eqns = [] + for f in funcs: + eq = diff(L, f) + for i in range(1, order + 1): + for p in combinations_with_replacement(vars, i): + eq = eq + S.NegativeOne**i*diff(L, diff(f, *p), *p) + new_eq = Eq(eq, 0) + if isinstance(new_eq, Eq): + eqns.append(new_eq) + + return eqns diff --git a/.venv/lib/python3.13/site-packages/sympy/calculus/finite_diff.py b/.venv/lib/python3.13/site-packages/sympy/calculus/finite_diff.py new file mode 100644 index 0000000000000000000000000000000000000000..17eece149aadad236cefeb350e1ef4a383c84f01 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/calculus/finite_diff.py @@ -0,0 +1,476 @@ +""" +Finite difference weights +========================= + +This module implements an algorithm for efficient generation of finite +difference weights for ordinary differentials of functions for +derivatives from 0 (interpolation) up to arbitrary order. + +The core algorithm is provided in the finite difference weight generating +function (``finite_diff_weights``), and two convenience functions are provided +for: + +- estimating a derivative (or interpolate) directly from a series of points + is also provided (``apply_finite_diff``). +- differentiating by using finite difference approximations + (``differentiate_finite``). + +""" + +from sympy.core.function import Derivative +from sympy.core.singleton import S +from sympy.core.function import Subs +from sympy.core.traversal import preorder_traversal +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import iterable + + + +def finite_diff_weights(order, x_list, x0=S.One): + """ + Calculates the finite difference weights for an arbitrarily spaced + one-dimensional grid (``x_list``) for derivatives at ``x0`` of order + 0, 1, ..., up to ``order`` using a recursive formula. Order of accuracy + is at least ``len(x_list) - order``, if ``x_list`` is defined correctly. + + Parameters + ========== + + order: int + Up to what derivative order weights should be calculated. + 0 corresponds to interpolation. + x_list: sequence + Sequence of (unique) values for the independent variable. + It is useful (but not necessary) to order ``x_list`` from + nearest to furthest from ``x0``; see examples below. + x0: Number or Symbol + Root or value of the independent variable for which the finite + difference weights should be generated. Default is ``S.One``. + + Returns + ======= + + list + A list of sublists, each corresponding to coefficients for + increasing derivative order, and each containing lists of + coefficients for increasing subsets of x_list. + + Examples + ======== + + >>> from sympy import finite_diff_weights, S + >>> res = finite_diff_weights(1, [-S(1)/2, S(1)/2, S(3)/2, S(5)/2], 0) + >>> res + [[[1, 0, 0, 0], + [1/2, 1/2, 0, 0], + [3/8, 3/4, -1/8, 0], + [5/16, 15/16, -5/16, 1/16]], + [[0, 0, 0, 0], + [-1, 1, 0, 0], + [-1, 1, 0, 0], + [-23/24, 7/8, 1/8, -1/24]]] + >>> res[0][-1] # FD weights for 0th derivative, using full x_list + [5/16, 15/16, -5/16, 1/16] + >>> res[1][-1] # FD weights for 1st derivative + [-23/24, 7/8, 1/8, -1/24] + >>> res[1][-2] # FD weights for 1st derivative, using x_list[:-1] + [-1, 1, 0, 0] + >>> res[1][-1][0] # FD weight for 1st deriv. for x_list[0] + -23/24 + >>> res[1][-1][1] # FD weight for 1st deriv. for x_list[1], etc. + 7/8 + + Each sublist contains the most accurate formula at the end. + Note, that in the above example ``res[1][1]`` is the same as ``res[1][2]``. + Since res[1][2] has an order of accuracy of + ``len(x_list[:3]) - order = 3 - 1 = 2``, the same is true for ``res[1][1]``! + + >>> res = finite_diff_weights(1, [S(0), S(1), -S(1), S(2), -S(2)], 0)[1] + >>> res + [[0, 0, 0, 0, 0], + [-1, 1, 0, 0, 0], + [0, 1/2, -1/2, 0, 0], + [-1/2, 1, -1/3, -1/6, 0], + [0, 2/3, -2/3, -1/12, 1/12]] + >>> res[0] # no approximation possible, using x_list[0] only + [0, 0, 0, 0, 0] + >>> res[1] # classic forward step approximation + [-1, 1, 0, 0, 0] + >>> res[2] # classic centered approximation + [0, 1/2, -1/2, 0, 0] + >>> res[3:] # higher order approximations + [[-1/2, 1, -1/3, -1/6, 0], [0, 2/3, -2/3, -1/12, 1/12]] + + Let us compare this to a differently defined ``x_list``. Pay attention to + ``foo[i][k]`` corresponding to the gridpoint defined by ``x_list[k]``. + + >>> foo = finite_diff_weights(1, [-S(2), -S(1), S(0), S(1), S(2)], 0)[1] + >>> foo + [[0, 0, 0, 0, 0], + [-1, 1, 0, 0, 0], + [1/2, -2, 3/2, 0, 0], + [1/6, -1, 1/2, 1/3, 0], + [1/12, -2/3, 0, 2/3, -1/12]] + >>> foo[1] # not the same and of lower accuracy as res[1]! + [-1, 1, 0, 0, 0] + >>> foo[2] # classic double backward step approximation + [1/2, -2, 3/2, 0, 0] + >>> foo[4] # the same as res[4] + [1/12, -2/3, 0, 2/3, -1/12] + + Note that, unless you plan on using approximations based on subsets of + ``x_list``, the order of gridpoints does not matter. + + The capability to generate weights at arbitrary points can be + used e.g. to minimize Runge's phenomenon by using Chebyshev nodes: + + >>> from sympy import cos, symbols, pi, simplify + >>> N, (h, x) = 4, symbols('h x') + >>> x_list = [x+h*cos(i*pi/(N)) for i in range(N,-1,-1)] # chebyshev nodes + >>> print(x_list) + [-h + x, -sqrt(2)*h/2 + x, x, sqrt(2)*h/2 + x, h + x] + >>> mycoeffs = finite_diff_weights(1, x_list, 0)[1][4] + >>> [simplify(c) for c in mycoeffs] #doctest: +NORMALIZE_WHITESPACE + [(h**3/2 + h**2*x - 3*h*x**2 - 4*x**3)/h**4, + (-sqrt(2)*h**3 - 4*h**2*x + 3*sqrt(2)*h*x**2 + 8*x**3)/h**4, + (6*h**2*x - 8*x**3)/h**4, + (sqrt(2)*h**3 - 4*h**2*x - 3*sqrt(2)*h*x**2 + 8*x**3)/h**4, + (-h**3/2 + h**2*x + 3*h*x**2 - 4*x**3)/h**4] + + Notes + ===== + + If weights for a finite difference approximation of 3rd order + derivative is wanted, weights for 0th, 1st and 2nd order are + calculated "for free", so are formulae using subsets of ``x_list``. + This is something one can take advantage of to save computational cost. + Be aware that one should define ``x_list`` from nearest to furthest from + ``x0``. If not, subsets of ``x_list`` will yield poorer approximations, + which might not grand an order of accuracy of ``len(x_list) - order``. + + See also + ======== + + sympy.calculus.finite_diff.apply_finite_diff + + References + ========== + + .. [1] Generation of Finite Difference Formulas on Arbitrarily Spaced + Grids, Bengt Fornberg; Mathematics of computation; 51; 184; + (1988); 699-706; doi:10.1090/S0025-5718-1988-0935077-0 + + """ + # The notation below closely corresponds to the one used in the paper. + order = S(order) + if not order.is_number: + raise ValueError("Cannot handle symbolic order.") + if order < 0: + raise ValueError("Negative derivative order illegal.") + if int(order) != order: + raise ValueError("Non-integer order illegal") + M = order + N = len(x_list) - 1 + delta = [[[0 for nu in range(N+1)] for n in range(N+1)] for + m in range(M+1)] + delta[0][0][0] = S.One + c1 = S.One + for n in range(1, N+1): + c2 = S.One + for nu in range(n): + c3 = x_list[n] - x_list[nu] + c2 = c2 * c3 + if n <= M: + delta[n][n-1][nu] = 0 + for m in range(min(n, M)+1): + delta[m][n][nu] = (x_list[n]-x0)*delta[m][n-1][nu] -\ + m*delta[m-1][n-1][nu] + delta[m][n][nu] /= c3 + for m in range(min(n, M)+1): + delta[m][n][n] = c1/c2*(m*delta[m-1][n-1][n-1] - + (x_list[n-1]-x0)*delta[m][n-1][n-1]) + c1 = c2 + return delta + + +def apply_finite_diff(order, x_list, y_list, x0=S.Zero): + """ + Calculates the finite difference approximation of + the derivative of requested order at ``x0`` from points + provided in ``x_list`` and ``y_list``. + + Parameters + ========== + + order: int + order of derivative to approximate. 0 corresponds to interpolation. + x_list: sequence + Sequence of (unique) values for the independent variable. + y_list: sequence + The function value at corresponding values for the independent + variable in x_list. + x0: Number or Symbol + At what value of the independent variable the derivative should be + evaluated. Defaults to 0. + + Returns + ======= + + sympy.core.add.Add or sympy.core.numbers.Number + The finite difference expression approximating the requested + derivative order at ``x0``. + + Examples + ======== + + >>> from sympy import apply_finite_diff + >>> cube = lambda arg: (1.0*arg)**3 + >>> xlist = range(-3,3+1) + >>> apply_finite_diff(2, xlist, map(cube, xlist), 2) - 12 # doctest: +SKIP + -3.55271367880050e-15 + + we see that the example above only contain rounding errors. + apply_finite_diff can also be used on more abstract objects: + + >>> from sympy import IndexedBase, Idx + >>> x, y = map(IndexedBase, 'xy') + >>> i = Idx('i') + >>> x_list, y_list = zip(*[(x[i+j], y[i+j]) for j in range(-1,2)]) + >>> apply_finite_diff(1, x_list, y_list, x[i]) + ((x[i + 1] - x[i])/(-x[i - 1] + x[i]) - 1)*y[i]/(x[i + 1] - x[i]) - + (x[i + 1] - x[i])*y[i - 1]/((x[i + 1] - x[i - 1])*(-x[i - 1] + x[i])) + + (-x[i - 1] + x[i])*y[i + 1]/((x[i + 1] - x[i - 1])*(x[i + 1] - x[i])) + + Notes + ===== + + Order = 0 corresponds to interpolation. + Only supply so many points you think makes sense + to around x0 when extracting the derivative (the function + need to be well behaved within that region). Also beware + of Runge's phenomenon. + + See also + ======== + + sympy.calculus.finite_diff.finite_diff_weights + + References + ========== + + Fortran 90 implementation with Python interface for numerics: finitediff_ + + .. _finitediff: https://github.com/bjodah/finitediff + + """ + + # In the original paper the following holds for the notation: + # M = order + # N = len(x_list) - 1 + + N = len(x_list) - 1 + if len(x_list) != len(y_list): + raise ValueError("x_list and y_list not equal in length.") + + delta = finite_diff_weights(order, x_list, x0) + + derivative = 0 + for nu in range(len(x_list)): + derivative += delta[order][N][nu]*y_list[nu] + return derivative + + +def _as_finite_diff(derivative, points=1, x0=None, wrt=None): + """ + Returns an approximation of a derivative of a function in + the form of a finite difference formula. The expression is a + weighted sum of the function at a number of discrete values of + (one of) the independent variable(s). + + Parameters + ========== + + derivative: a Derivative instance + + points: sequence or coefficient, optional + If sequence: discrete values (length >= order+1) of the + independent variable used for generating the finite + difference weights. + If it is a coefficient, it will be used as the step-size + for generating an equidistant sequence of length order+1 + centered around ``x0``. default: 1 (step-size 1) + + x0: number or Symbol, optional + the value of the independent variable (``wrt``) at which the + derivative is to be approximated. Default: same as ``wrt``. + + wrt: Symbol, optional + "with respect to" the variable for which the (partial) + derivative is to be approximated for. If not provided it + is required that the Derivative is ordinary. Default: ``None``. + + Examples + ======== + + >>> from sympy import symbols, Function, exp, sqrt, Symbol + >>> from sympy.calculus.finite_diff import _as_finite_diff + >>> x, h = symbols('x h') + >>> f = Function('f') + >>> _as_finite_diff(f(x).diff(x)) + -f(x - 1/2) + f(x + 1/2) + + The default step size and number of points are 1 and ``order + 1`` + respectively. We can change the step size by passing a symbol + as a parameter: + + >>> _as_finite_diff(f(x).diff(x), h) + -f(-h/2 + x)/h + f(h/2 + x)/h + + We can also specify the discretized values to be used in a sequence: + + >>> _as_finite_diff(f(x).diff(x), [x, x+h, x+2*h]) + -3*f(x)/(2*h) + 2*f(h + x)/h - f(2*h + x)/(2*h) + + The algorithm is not restricted to use equidistant spacing, nor + do we need to make the approximation around ``x0``, but we can get + an expression estimating the derivative at an offset: + + >>> e, sq2 = exp(1), sqrt(2) + >>> xl = [x-h, x+h, x+e*h] + >>> _as_finite_diff(f(x).diff(x, 1), xl, x+h*sq2) + 2*h*((h + sqrt(2)*h)/(2*h) - (-sqrt(2)*h + h)/(2*h))*f(E*h + x)/((-h + E*h)*(h + E*h)) + + (-(-sqrt(2)*h + h)/(2*h) - (-sqrt(2)*h + E*h)/(2*h))*f(-h + x)/(h + E*h) + + (-(h + sqrt(2)*h)/(2*h) + (-sqrt(2)*h + E*h)/(2*h))*f(h + x)/(-h + E*h) + + Partial derivatives are also supported: + + >>> y = Symbol('y') + >>> d2fdxdy=f(x,y).diff(x,y) + >>> _as_finite_diff(d2fdxdy, wrt=x) + -Derivative(f(x - 1/2, y), y) + Derivative(f(x + 1/2, y), y) + + See also + ======== + + sympy.calculus.finite_diff.apply_finite_diff + sympy.calculus.finite_diff.finite_diff_weights + + """ + if derivative.is_Derivative: + pass + elif derivative.is_Atom: + return derivative + else: + return derivative.fromiter( + [_as_finite_diff(ar, points, x0, wrt) for ar + in derivative.args], **derivative.assumptions0) + + if wrt is None: + old = None + for v in derivative.variables: + if old is v: + continue + derivative = _as_finite_diff(derivative, points, x0, v) + old = v + return derivative + + order = derivative.variables.count(wrt) + + if x0 is None: + x0 = wrt + + if not iterable(points): + if getattr(points, 'is_Function', False) and wrt in points.args: + points = points.subs(wrt, x0) + # points is simply the step-size, let's make it a + # equidistant sequence centered around x0 + if order % 2 == 0: + # even order => odd number of points, grid point included + points = [x0 + points*i for i + in range(-order//2, order//2 + 1)] + else: + # odd order => even number of points, half-way wrt grid point + points = [x0 + points*S(i)/2 for i + in range(-order, order + 1, 2)] + others = [wrt, 0] + for v in set(derivative.variables): + if v == wrt: + continue + others += [v, derivative.variables.count(v)] + if len(points) < order+1: + raise ValueError("Too few points for order %d" % order) + return apply_finite_diff(order, points, [ + Derivative(derivative.expr.subs({wrt: x}), *others) for + x in points], x0) + + +def differentiate_finite(expr, *symbols, + points=1, x0=None, wrt=None, evaluate=False): + r""" Differentiate expr and replace Derivatives with finite differences. + + Parameters + ========== + + expr : expression + \*symbols : differentiate with respect to symbols + points: sequence, coefficient or undefined function, optional + see ``Derivative.as_finite_difference`` + x0: number or Symbol, optional + see ``Derivative.as_finite_difference`` + wrt: Symbol, optional + see ``Derivative.as_finite_difference`` + + Examples + ======== + + >>> from sympy import sin, Function, differentiate_finite + >>> from sympy.abc import x, y, h + >>> f, g = Function('f'), Function('g') + >>> differentiate_finite(f(x)*g(x), x, points=[x-h, x+h]) + -f(-h + x)*g(-h + x)/(2*h) + f(h + x)*g(h + x)/(2*h) + + ``differentiate_finite`` works on any expression, including the expressions + with embedded derivatives: + + >>> differentiate_finite(f(x) + sin(x), x, 2) + -2*f(x) + f(x - 1) + f(x + 1) - 2*sin(x) + sin(x - 1) + sin(x + 1) + >>> differentiate_finite(f(x, y), x, y) + f(x - 1/2, y - 1/2) - f(x - 1/2, y + 1/2) - f(x + 1/2, y - 1/2) + f(x + 1/2, y + 1/2) + >>> differentiate_finite(f(x)*g(x).diff(x), x) + (-g(x) + g(x + 1))*f(x + 1/2) - (g(x) - g(x - 1))*f(x - 1/2) + + To make finite difference with non-constant discretization step use + undefined functions: + + >>> dx = Function('dx') + >>> differentiate_finite(f(x)*g(x).diff(x), points=dx(x)) + -(-g(x - dx(x)/2 - dx(x - dx(x)/2)/2)/dx(x - dx(x)/2) + + g(x - dx(x)/2 + dx(x - dx(x)/2)/2)/dx(x - dx(x)/2))*f(x - dx(x)/2)/dx(x) + + (-g(x + dx(x)/2 - dx(x + dx(x)/2)/2)/dx(x + dx(x)/2) + + g(x + dx(x)/2 + dx(x + dx(x)/2)/2)/dx(x + dx(x)/2))*f(x + dx(x)/2)/dx(x) + + """ + if any(term.is_Derivative for term in list(preorder_traversal(expr))): + evaluate = False + + Dexpr = expr.diff(*symbols, evaluate=evaluate) + if evaluate: + sympy_deprecation_warning(""" + The evaluate flag to differentiate_finite() is deprecated. + + evaluate=True expands the intermediate derivatives before computing + differences, but this usually not what you want, as it does not + satisfy the product rule. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-differentiate_finite-evaluate", + ) + return Dexpr.replace( + lambda arg: arg.is_Derivative, + lambda arg: arg.as_finite_difference(points=points, x0=x0, wrt=wrt)) + else: + DFexpr = Dexpr.as_finite_difference(points=points, x0=x0, wrt=wrt) + return DFexpr.replace( + lambda arg: isinstance(arg, Subs), + lambda arg: arg.expr.as_finite_difference( + points=points, x0=arg.point[0], wrt=arg.variables[0])) diff --git a/.venv/lib/python3.13/site-packages/sympy/calculus/singularities.py b/.venv/lib/python3.13/site-packages/sympy/calculus/singularities.py new file mode 100644 index 0000000000000000000000000000000000000000..5adafc59efaf0bff44707f6b5e3f074be6bc1f32 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/calculus/singularities.py @@ -0,0 +1,406 @@ +""" +Singularities +============= + +This module implements algorithms for finding singularities for a function +and identifying types of functions. + +The differential calculus methods in this module include methods to identify +the following function types in the given ``Interval``: +- Increasing +- Strictly Increasing +- Decreasing +- Strictly Decreasing +- Monotonic + +""" + +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.trigonometric import sec, csc, cot, tan, cos +from sympy.functions.elementary.hyperbolic import ( + sech, csch, coth, tanh, cosh, asech, acsch, atanh, acoth) +from sympy.utilities.misc import filldedent + + +def singularities(expression, symbol, domain=None): + """ + Find singularities of a given function. + + Parameters + ========== + + expression : Expr + The target function in which singularities need to be found. + symbol : Symbol + The symbol over the values of which the singularity in + expression in being searched for. + + Returns + ======= + + Set + A set of values for ``symbol`` for which ``expression`` has a + singularity. An ``EmptySet`` is returned if ``expression`` has no + singularities for any given value of ``Symbol``. + + Raises + ====== + + NotImplementedError + Methods for determining the singularities of this function have + not been developed. + + Notes + ===== + + This function does not find non-isolated singularities + nor does it find branch points of the expression. + + Currently supported functions are: + - univariate continuous (real or complex) functions + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Mathematical_singularity + + Examples + ======== + + >>> from sympy import singularities, Symbol, log + >>> x = Symbol('x', real=True) + >>> y = Symbol('y', real=False) + >>> singularities(x**2 + x + 1, x) + EmptySet + >>> singularities(1/(x + 1), x) + {-1} + >>> singularities(1/(y**2 + 1), y) + {-I, I} + >>> singularities(1/(y**3 + 1), y) + {-1, 1/2 - sqrt(3)*I/2, 1/2 + sqrt(3)*I/2} + >>> singularities(log(x), x) + {0} + + """ + from sympy.solvers.solveset import solveset + + if domain is None: + domain = S.Reals if symbol.is_real else S.Complexes + try: + sings = S.EmptySet + e = expression.rewrite([sec, csc, cot, tan], cos) + e = e.rewrite([sech, csch, coth, tanh], cosh) + for i in e.atoms(Pow): + if i.exp.is_infinite: + raise NotImplementedError + if i.exp.is_negative: + # XXX: exponent of varying sign not handled + sings += solveset(i.base, symbol, domain) + for i in expression.atoms(log, asech, acsch): + sings += solveset(i.args[0], symbol, domain) + for i in expression.atoms(atanh, acoth): + sings += solveset(i.args[0] - 1, symbol, domain) + sings += solveset(i.args[0] + 1, symbol, domain) + return sings + except NotImplementedError: + raise NotImplementedError(filldedent(''' + Methods for determining the singularities + of this function have not been developed.''')) + + +########################################################################### +# DIFFERENTIAL CALCULUS METHODS # +########################################################################### + + +def monotonicity_helper(expression, predicate, interval=S.Reals, symbol=None): + """ + Helper function for functions checking function monotonicity. + + Parameters + ========== + + expression : Expr + The target function which is being checked + predicate : function + The property being tested for. The function takes in an integer + and returns a boolean. The integer input is the derivative and + the boolean result should be true if the property is being held, + and false otherwise. + interval : Set, optional + The range of values in which we are testing, defaults to all reals. + symbol : Symbol, optional + The symbol present in expression which gets varied over the given range. + + It returns a boolean indicating whether the interval in which + the function's derivative satisfies given predicate is a superset + of the given interval. + + Returns + ======= + + Boolean + True if ``predicate`` is true for all the derivatives when ``symbol`` + is varied in ``range``, False otherwise. + + """ + from sympy.solvers.solveset import solveset + + expression = sympify(expression) + free = expression.free_symbols + + if symbol is None: + if len(free) > 1: + raise NotImplementedError( + 'The function has not yet been implemented' + ' for all multivariate expressions.' + ) + + variable = symbol or (free.pop() if free else Symbol('x')) + derivative = expression.diff(variable) + predicate_interval = solveset(predicate(derivative), variable, S.Reals) + return interval.is_subset(predicate_interval) + + +def is_increasing(expression, interval=S.Reals, symbol=None): + """ + Return whether the function is increasing in the given interval. + + Parameters + ========== + + expression : Expr + The target function which is being checked. + interval : Set, optional + The range of values in which we are testing (defaults to set of + all real numbers). + symbol : Symbol, optional + The symbol present in expression which gets varied over the given range. + + Returns + ======= + + Boolean + True if ``expression`` is increasing (either strictly increasing or + constant) in the given ``interval``, False otherwise. + + Examples + ======== + + >>> from sympy import is_increasing + >>> from sympy.abc import x, y + >>> from sympy import S, Interval, oo + >>> is_increasing(x**3 - 3*x**2 + 4*x, S.Reals) + True + >>> is_increasing(-x**2, Interval(-oo, 0)) + True + >>> is_increasing(-x**2, Interval(0, oo)) + False + >>> is_increasing(4*x**3 - 6*x**2 - 72*x + 30, Interval(-2, 3)) + False + >>> is_increasing(x**2 + y, Interval(1, 2), x) + True + + """ + return monotonicity_helper(expression, lambda x: x >= 0, interval, symbol) + + +def is_strictly_increasing(expression, interval=S.Reals, symbol=None): + """ + Return whether the function is strictly increasing in the given interval. + + Parameters + ========== + + expression : Expr + The target function which is being checked. + interval : Set, optional + The range of values in which we are testing (defaults to set of + all real numbers). + symbol : Symbol, optional + The symbol present in expression which gets varied over the given range. + + Returns + ======= + + Boolean + True if ``expression`` is strictly increasing in the given ``interval``, + False otherwise. + + Examples + ======== + + >>> from sympy import is_strictly_increasing + >>> from sympy.abc import x, y + >>> from sympy import Interval, oo + >>> is_strictly_increasing(4*x**3 - 6*x**2 - 72*x + 30, Interval.Ropen(-oo, -2)) + True + >>> is_strictly_increasing(4*x**3 - 6*x**2 - 72*x + 30, Interval.Lopen(3, oo)) + True + >>> is_strictly_increasing(4*x**3 - 6*x**2 - 72*x + 30, Interval.open(-2, 3)) + False + >>> is_strictly_increasing(-x**2, Interval(0, oo)) + False + >>> is_strictly_increasing(-x**2 + y, Interval(-oo, 0), x) + False + + """ + return monotonicity_helper(expression, lambda x: x > 0, interval, symbol) + + +def is_decreasing(expression, interval=S.Reals, symbol=None): + """ + Return whether the function is decreasing in the given interval. + + Parameters + ========== + + expression : Expr + The target function which is being checked. + interval : Set, optional + The range of values in which we are testing (defaults to set of + all real numbers). + symbol : Symbol, optional + The symbol present in expression which gets varied over the given range. + + Returns + ======= + + Boolean + True if ``expression`` is decreasing (either strictly decreasing or + constant) in the given ``interval``, False otherwise. + + Examples + ======== + + >>> from sympy import is_decreasing + >>> from sympy.abc import x, y + >>> from sympy import S, Interval, oo + >>> is_decreasing(1/(x**2 - 3*x), Interval.open(S(3)/2, 3)) + True + >>> is_decreasing(1/(x**2 - 3*x), Interval.open(1.5, 3)) + True + >>> is_decreasing(1/(x**2 - 3*x), Interval.Lopen(3, oo)) + True + >>> is_decreasing(1/(x**2 - 3*x), Interval.Ropen(-oo, S(3)/2)) + False + >>> is_decreasing(1/(x**2 - 3*x), Interval.Ropen(-oo, 1.5)) + False + >>> is_decreasing(-x**2, Interval(-oo, 0)) + False + >>> is_decreasing(-x**2 + y, Interval(-oo, 0), x) + False + + """ + return monotonicity_helper(expression, lambda x: x <= 0, interval, symbol) + + +def is_strictly_decreasing(expression, interval=S.Reals, symbol=None): + """ + Return whether the function is strictly decreasing in the given interval. + + Parameters + ========== + + expression : Expr + The target function which is being checked. + interval : Set, optional + The range of values in which we are testing (defaults to set of + all real numbers). + symbol : Symbol, optional + The symbol present in expression which gets varied over the given range. + + Returns + ======= + + Boolean + True if ``expression`` is strictly decreasing in the given ``interval``, + False otherwise. + + Examples + ======== + + >>> from sympy import is_strictly_decreasing + >>> from sympy.abc import x, y + >>> from sympy import S, Interval, oo + >>> is_strictly_decreasing(1/(x**2 - 3*x), Interval.Lopen(3, oo)) + True + >>> is_strictly_decreasing(1/(x**2 - 3*x), Interval.Ropen(-oo, S(3)/2)) + False + >>> is_strictly_decreasing(1/(x**2 - 3*x), Interval.Ropen(-oo, 1.5)) + False + >>> is_strictly_decreasing(-x**2, Interval(-oo, 0)) + False + >>> is_strictly_decreasing(-x**2 + y, Interval(-oo, 0), x) + False + + """ + return monotonicity_helper(expression, lambda x: x < 0, interval, symbol) + + +def is_monotonic(expression, interval=S.Reals, symbol=None): + """ + Return whether the function is monotonic in the given interval. + + Parameters + ========== + + expression : Expr + The target function which is being checked. + interval : Set, optional + The range of values in which we are testing (defaults to set of + all real numbers). + symbol : Symbol, optional + The symbol present in expression which gets varied over the given range. + + Returns + ======= + + Boolean + True if ``expression`` is monotonic in the given ``interval``, + False otherwise. + + Raises + ====== + + NotImplementedError + Monotonicity check has not been implemented for the queried function. + + Examples + ======== + + >>> from sympy import is_monotonic + >>> from sympy.abc import x, y + >>> from sympy import S, Interval, oo + >>> is_monotonic(1/(x**2 - 3*x), Interval.open(S(3)/2, 3)) + True + >>> is_monotonic(1/(x**2 - 3*x), Interval.open(1.5, 3)) + True + >>> is_monotonic(1/(x**2 - 3*x), Interval.Lopen(3, oo)) + True + >>> is_monotonic(x**3 - 3*x**2 + 4*x, S.Reals) + True + >>> is_monotonic(-x**2, S.Reals) + False + >>> is_monotonic(x**2 + y + 1, Interval(1, 2), x) + True + + """ + from sympy.solvers.solveset import solveset + + expression = sympify(expression) + + free = expression.free_symbols + if symbol is None and len(free) > 1: + raise NotImplementedError( + 'is_monotonic has not yet been implemented' + ' for all multivariate expressions.' + ) + + variable = symbol or (free.pop() if free else Symbol('x')) + turning_points = solveset(expression.diff(variable), variable, interval) + return interval.intersection(turning_points) is S.EmptySet diff --git a/.venv/lib/python3.13/site-packages/sympy/calculus/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/calculus/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/calculus/tests/test_accumulationbounds.py b/.venv/lib/python3.13/site-packages/sympy/calculus/tests/test_accumulationbounds.py new file mode 100644 index 0000000000000000000000000000000000000000..bcc47c66327fe21ddca3a6b73ca5914e0441b38e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/calculus/tests/test_accumulationbounds.py @@ -0,0 +1,336 @@ +from sympy.core.numbers import (E, Rational, oo, pi, zoo) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import (Max, Min, sqrt) +from sympy.functions.elementary.trigonometric import (cos, sin, tan) +from sympy.calculus.accumulationbounds import AccumBounds +from sympy.core import Add, Mul, Pow +from sympy.core.expr import unchanged +from sympy.testing.pytest import raises, XFAIL +from sympy.abc import x + +a = Symbol('a', real=True) +B = AccumBounds + + +def test_AccumBounds(): + assert B(1, 2).args == (1, 2) + assert B(1, 2).delta is S.One + assert B(1, 2).mid == Rational(3, 2) + assert B(1, 3).is_real == True + + assert B(1, 1) is S.One + + assert B(1, 2) + 1 == B(2, 3) + assert 1 + B(1, 2) == B(2, 3) + assert B(1, 2) + B(2, 3) == B(3, 5) + + assert -B(1, 2) == B(-2, -1) + + assert B(1, 2) - 1 == B(0, 1) + assert 1 - B(1, 2) == B(-1, 0) + assert B(2, 3) - B(1, 2) == B(0, 2) + + assert x + B(1, 2) == Add(B(1, 2), x) + assert a + B(1, 2) == B(1 + a, 2 + a) + assert B(1, 2) - x == Add(B(1, 2), -x) + + assert B(-oo, 1) + oo == B(-oo, oo) + assert B(1, oo) + oo is oo + assert B(1, oo) - oo == B(-oo, oo) + assert (-oo - B(-1, oo)) is -oo + assert B(-oo, 1) - oo is -oo + + assert B(1, oo) - oo == B(-oo, oo) + assert B(-oo, 1) - (-oo) == B(-oo, oo) + assert (oo - B(1, oo)) == B(-oo, oo) + assert (-oo - B(1, oo)) is -oo + + assert B(1, 2)/2 == B(S.Half, 1) + assert 2/B(2, 3) == B(Rational(2, 3), 1) + assert 1/B(-1, 1) == B(-oo, oo) + + assert abs(B(1, 2)) == B(1, 2) + assert abs(B(-2, -1)) == B(1, 2) + assert abs(B(-2, 1)) == B(0, 2) + assert abs(B(-1, 2)) == B(0, 2) + c = Symbol('c') + raises(ValueError, lambda: B(0, c)) + raises(ValueError, lambda: B(1, -1)) + r = Symbol('r', real=True) + raises(ValueError, lambda: B(r, r - 1)) + + +def test_AccumBounds_mul(): + assert B(1, 2)*2 == B(2, 4) + assert 2*B(1, 2) == B(2, 4) + assert B(1, 2)*B(2, 3) == B(2, 6) + assert B(0, 2)*B(2, oo) == B(0, oo) + l, r = B(-oo, oo), B(-a, a) + assert l*r == B(-oo, oo) + assert r*l == B(-oo, oo) + l, r = B(1, oo), B(-3, -2) + assert l*r == B(-oo, -2) + assert r*l == B(-oo, -2) + assert B(1, 2)*0 == 0 + assert B(1, oo)*0 == B(0, oo) + assert B(-oo, 1)*0 == B(-oo, 0) + assert B(-oo, oo)*0 == B(-oo, oo) + + assert B(1, 2)*x == Mul(B(1, 2), x, evaluate=False) + + assert B(0, 2)*oo == B(0, oo) + assert B(-2, 0)*oo == B(-oo, 0) + assert B(0, 2)*(-oo) == B(-oo, 0) + assert B(-2, 0)*(-oo) == B(0, oo) + assert B(-1, 1)*oo == B(-oo, oo) + assert B(-1, 1)*(-oo) == B(-oo, oo) + assert B(-oo, oo)*oo == B(-oo, oo) + + +def test_AccumBounds_div(): + assert B(-1, 3)/B(3, 4) == B(Rational(-1, 3), 1) + assert B(-2, 4)/B(-3, 4) == B(-oo, oo) + assert B(-3, -2)/B(-4, 0) == B(S.Half, oo) + + # these two tests can have a better answer + # after Union of B is improved + assert B(-3, -2)/B(-2, 1) == B(-oo, oo) + assert B(2, 3)/B(-2, 2) == B(-oo, oo) + + assert B(-3, -2)/B(0, 4) == B(-oo, Rational(-1, 2)) + assert B(2, 4)/B(-3, 0) == B(-oo, Rational(-2, 3)) + assert B(2, 4)/B(0, 3) == B(Rational(2, 3), oo) + + assert B(0, 1)/B(0, 1) == B(0, oo) + assert B(-1, 0)/B(0, 1) == B(-oo, 0) + assert B(-1, 2)/B(-2, 2) == B(-oo, oo) + + assert 1/B(-1, 2) == B(-oo, oo) + assert 1/B(0, 2) == B(S.Half, oo) + assert (-1)/B(0, 2) == B(-oo, Rational(-1, 2)) + assert 1/B(-oo, 0) == B(-oo, 0) + assert 1/B(-1, 0) == B(-oo, -1) + assert (-2)/B(-oo, 0) == B(0, oo) + assert 1/B(-oo, -1) == B(-1, 0) + + assert B(1, 2)/a == Mul(B(1, 2), 1/a, evaluate=False) + + assert B(1, 2)/0 == B(1, 2)*zoo + assert B(1, oo)/oo == B(0, oo) + assert B(1, oo)/(-oo) == B(-oo, 0) + assert B(-oo, -1)/oo == B(-oo, 0) + assert B(-oo, -1)/(-oo) == B(0, oo) + assert B(-oo, oo)/oo == B(-oo, oo) + assert B(-oo, oo)/(-oo) == B(-oo, oo) + assert B(-1, oo)/oo == B(0, oo) + assert B(-1, oo)/(-oo) == B(-oo, 0) + assert B(-oo, 1)/oo == B(-oo, 0) + assert B(-oo, 1)/(-oo) == B(0, oo) + + +def test_issue_18795(): + r = Symbol('r', real=True) + a = B(-1,1) + c = B(7, oo) + b = B(-oo, oo) + assert c - tan(r) == B(7-tan(r), oo) + assert b + tan(r) == B(-oo, oo) + assert (a + r)/a == B(-oo, oo)*B(r - 1, r + 1) + assert (b + a)/a == B(-oo, oo) + + +def test_AccumBounds_func(): + assert (x**2 + 2*x + 1).subs(x, B(-1, 1)) == B(-1, 4) + assert exp(B(0, 1)) == B(1, E) + assert exp(B(-oo, oo)) == B(0, oo) + assert log(B(3, 6)) == B(log(3), log(6)) + + +@XFAIL +def test_AccumBounds_powf(): + nn = Symbol('nn', nonnegative=True) + assert B(1 + nn, 2 + nn)**B(1, 2) == B(1 + nn, (2 + nn)**2) + i = Symbol('i', integer=True, negative=True) + assert B(1, 2)**i == B(2**i, 1) + + +def test_AccumBounds_pow(): + assert B(0, 2)**2 == B(0, 4) + assert B(-1, 1)**2 == B(0, 1) + assert B(1, 2)**2 == B(1, 4) + assert B(-1, 2)**3 == B(-1, 8) + assert B(-1, 1)**0 == 1 + + assert B(1, 2)**Rational(5, 2) == B(1, 4*sqrt(2)) + assert B(0, 2)**S.Half == B(0, sqrt(2)) + + neg = Symbol('neg', negative=True) + assert unchanged(Pow, B(neg, 1), S.Half) + nn = Symbol('nn', nonnegative=True) + assert B(nn, nn + 1)**S.Half == B(sqrt(nn), sqrt(nn + 1)) + assert B(nn, nn + 1)**nn == B(nn**nn, (nn + 1)**nn) + assert unchanged(Pow, B(nn, nn + 1), x) + i = Symbol('i', integer=True) + assert B(1, 2)**i == B(Min(1, 2**i), Max(1, 2**i)) + i = Symbol('i', integer=True, nonnegative=True) + assert B(1, 2)**i == B(1, 2**i) + assert B(0, 1)**i == B(0**i, 1) + + assert B(1, 5)**(-2) == B(Rational(1, 25), 1) + assert B(-1, 3)**(-2) == B(0, oo) + assert B(0, 2)**(-3) == B(Rational(1, 8), oo) + assert B(-2, 0)**(-3) == B(-oo, -Rational(1, 8)) + assert B(0, 2)**(-2) == B(Rational(1, 4), oo) + assert B(-1, 2)**(-3) == B(-oo, oo) + assert B(-3, -2)**(-3) == B(Rational(-1, 8), Rational(-1, 27)) + assert B(-3, -2)**(-2) == B(Rational(1, 9), Rational(1, 4)) + assert B(0, oo)**S.Half == B(0, oo) + assert B(-oo, 0)**(-2) == B(0, oo) + assert B(-2, 0)**(-2) == B(Rational(1, 4), oo) + + assert B(Rational(1, 3), S.Half)**oo is S.Zero + assert B(0, S.Half)**oo is S.Zero + assert B(S.Half, 1)**oo == B(0, oo) + assert B(0, 1)**oo == B(0, oo) + assert B(2, 3)**oo is oo + assert B(1, 2)**oo == B(0, oo) + assert B(S.Half, 3)**oo == B(0, oo) + assert B(Rational(-1, 3), Rational(-1, 4))**oo is S.Zero + assert B(-1, Rational(-1, 2))**oo is S.NaN + assert B(-3, -2)**oo is zoo + assert B(-2, -1)**oo is S.NaN + assert B(-2, Rational(-1, 2))**oo is S.NaN + assert B(Rational(-1, 2), S.Half)**oo is S.Zero + assert B(Rational(-1, 2), 1)**oo == B(0, oo) + assert B(Rational(-2, 3), 2)**oo == B(0, oo) + assert B(-1, 1)**oo == B(-oo, oo) + assert B(-1, S.Half)**oo == B(-oo, oo) + assert B(-1, 2)**oo == B(-oo, oo) + assert B(-2, S.Half)**oo == B(-oo, oo) + + assert B(1, 2)**x == Pow(B(1, 2), x, evaluate=False) + + assert B(2, 3)**(-oo) is S.Zero + assert B(0, 2)**(-oo) == B(0, oo) + assert B(-1, 2)**(-oo) == B(-oo, oo) + + assert (tan(x)**sin(2*x)).subs(x, B(0, pi/2)) == \ + Pow(B(-oo, oo), B(0, 1)) + + +def test_AccumBounds_exponent(): + # base is 0 + z = 0**B(a, a + S.Half) + assert z.subs(a, 0) == B(0, 1) + assert z.subs(a, 1) == 0 + p = z.subs(a, -1) + assert p.is_Pow and p.args == (0, B(-1, -S.Half)) + # base > 0 + # when base is 1 the type of bounds does not matter + assert 1**B(a, a + 1) == 1 + # otherwise we need to know if 0 is in the bounds + assert S.Half**B(-2, 2) == B(S(1)/4, 4) + assert 2**B(-2, 2) == B(S(1)/4, 4) + + # +eps may introduce +oo + # if there is a negative integer exponent + assert B(0, 1)**B(S(1)/2, 1) == B(0, 1) + assert B(0, 1)**B(0, 1) == B(0, 1) + + # positive bases have positive bounds + assert B(2, 3)**B(-3, -2) == B(S(1)/27, S(1)/4) + assert B(2, 3)**B(-3, 2) == B(S(1)/27, 9) + + # bounds generating imaginary parts unevaluated + assert unchanged(Pow, B(-1, 1), B(1, 2)) + assert B(0, S(1)/2)**B(1, oo) == B(0, S(1)/2) + assert B(0, 1)**B(1, oo) == B(0, oo) + assert B(0, 2)**B(1, oo) == B(0, oo) + assert B(0, oo)**B(1, oo) == B(0, oo) + assert B(S(1)/2, 1)**B(1, oo) == B(0, oo) + assert B(S(1)/2, 1)**B(-oo, -1) == B(0, oo) + assert B(S(1)/2, 1)**B(-oo, oo) == B(0, oo) + assert B(S(1)/2, 2)**B(1, oo) == B(0, oo) + assert B(S(1)/2, 2)**B(-oo, -1) == B(0, oo) + assert B(S(1)/2, 2)**B(-oo, oo) == B(0, oo) + assert B(S(1)/2, oo)**B(1, oo) == B(0, oo) + assert B(S(1)/2, oo)**B(-oo, -1) == B(0, oo) + assert B(S(1)/2, oo)**B(-oo, oo) == B(0, oo) + assert B(1, 2)**B(1, oo) == B(0, oo) + assert B(1, 2)**B(-oo, -1) == B(0, oo) + assert B(1, 2)**B(-oo, oo) == B(0, oo) + assert B(1, oo)**B(1, oo) == B(0, oo) + assert B(1, oo)**B(-oo, -1) == B(0, oo) + assert B(1, oo)**B(-oo, oo) == B(0, oo) + assert B(2, oo)**B(1, oo) == B(2, oo) + assert B(2, oo)**B(-oo, -1) == B(0, S(1)/2) + assert B(2, oo)**B(-oo, oo) == B(0, oo) + + +def test_comparison_AccumBounds(): + assert (B(1, 3) < 4) == S.true + assert (B(1, 3) < -1) == S.false + assert (B(1, 3) < 2).rel_op == '<' + assert (B(1, 3) <= 2).rel_op == '<=' + + assert (B(1, 3) > 4) == S.false + assert (B(1, 3) > -1) == S.true + assert (B(1, 3) > 2).rel_op == '>' + assert (B(1, 3) >= 2).rel_op == '>=' + + assert (B(1, 3) < B(4, 6)) == S.true + assert (B(1, 3) < B(2, 4)).rel_op == '<' + assert (B(1, 3) < B(-2, 0)) == S.false + + assert (B(1, 3) <= B(4, 6)) == S.true + assert (B(1, 3) <= B(-2, 0)) == S.false + + assert (B(1, 3) > B(4, 6)) == S.false + assert (B(1, 3) > B(-2, 0)) == S.true + + assert (B(1, 3) >= B(4, 6)) == S.false + assert (B(1, 3) >= B(-2, 0)) == S.true + + # issue 13499 + assert (cos(x) > 0).subs(x, oo) == (B(-1, 1) > 0) + + c = Symbol('c') + raises(TypeError, lambda: (B(0, 1) < c)) + raises(TypeError, lambda: (B(0, 1) <= c)) + raises(TypeError, lambda: (B(0, 1) > c)) + raises(TypeError, lambda: (B(0, 1) >= c)) + + +def test_contains_AccumBounds(): + assert (1 in B(1, 2)) == S.true + raises(TypeError, lambda: a in B(1, 2)) + assert 0 in B(-1, 0) + raises(TypeError, lambda: + (cos(1)**2 + sin(1)**2 - 1) in B(-1, 0)) + assert (-oo in B(1, oo)) == S.true + assert (oo in B(-oo, 0)) == S.true + + # issue 13159 + assert Mul(0, B(-1, 1)) == Mul(B(-1, 1), 0) == 0 + import itertools + for perm in itertools.permutations([0, B(-1, 1), x]): + assert Mul(*perm) == 0 + + +def test_intersection_AccumBounds(): + assert B(0, 3).intersection(B(1, 2)) == B(1, 2) + assert B(0, 3).intersection(B(1, 4)) == B(1, 3) + assert B(0, 3).intersection(B(-1, 2)) == B(0, 2) + assert B(0, 3).intersection(B(-1, 4)) == B(0, 3) + assert B(0, 1).intersection(B(2, 3)) == S.EmptySet + raises(TypeError, lambda: B(0, 3).intersection(1)) + + +def test_union_AccumBounds(): + assert B(0, 3).union(B(1, 2)) == B(0, 3) + assert B(0, 3).union(B(1, 4)) == B(0, 4) + assert B(0, 3).union(B(-1, 2)) == B(-1, 3) + assert B(0, 3).union(B(-1, 4)) == B(-1, 4) + raises(TypeError, lambda: B(0, 3).union(1)) diff --git a/.venv/lib/python3.13/site-packages/sympy/calculus/tests/test_euler.py b/.venv/lib/python3.13/site-packages/sympy/calculus/tests/test_euler.py new file mode 100644 index 0000000000000000000000000000000000000000..56371c8c787d9459d1390e18c306fddde94d2745 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/calculus/tests/test_euler.py @@ -0,0 +1,74 @@ +from sympy.core.function import (Derivative as D, Function) +from sympy.core.relational import Eq +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.testing.pytest import raises +from sympy.calculus.euler import euler_equations as euler + + +def test_euler_interface(): + x = Function('x') + y = Symbol('y') + t = Symbol('t') + raises(TypeError, lambda: euler()) + raises(TypeError, lambda: euler(D(x(t), t)*y(t), [x(t), y])) + raises(ValueError, lambda: euler(D(x(t), t)*x(y), [x(t), x(y)])) + raises(TypeError, lambda: euler(D(x(t), t)**2, x(0))) + raises(TypeError, lambda: euler(D(x(t), t)*y(t), [t])) + assert euler(D(x(t), t)**2/2, {x(t)}) == [Eq(-D(x(t), t, t), 0)] + assert euler(D(x(t), t)**2/2, x(t), {t}) == [Eq(-D(x(t), t, t), 0)] + + +def test_euler_pendulum(): + x = Function('x') + t = Symbol('t') + L = D(x(t), t)**2/2 + cos(x(t)) + assert euler(L, x(t), t) == [Eq(-sin(x(t)) - D(x(t), t, t), 0)] + + +def test_euler_henonheiles(): + x = Function('x') + y = Function('y') + t = Symbol('t') + L = sum(D(z(t), t)**2/2 - z(t)**2/2 for z in [x, y]) + L += -x(t)**2*y(t) + y(t)**3/3 + assert euler(L, [x(t), y(t)], t) == [Eq(-2*x(t)*y(t) - x(t) - + D(x(t), t, t), 0), + Eq(-x(t)**2 + y(t)**2 - + y(t) - D(y(t), t, t), 0)] + + +def test_euler_sineg(): + psi = Function('psi') + t = Symbol('t') + x = Symbol('x') + L = D(psi(t, x), t)**2/2 - D(psi(t, x), x)**2/2 + cos(psi(t, x)) + assert euler(L, psi(t, x), [t, x]) == [Eq(-sin(psi(t, x)) - + D(psi(t, x), t, t) + + D(psi(t, x), x, x), 0)] + + +def test_euler_high_order(): + # an example from hep-th/0309038 + m = Symbol('m') + k = Symbol('k') + x = Function('x') + y = Function('y') + t = Symbol('t') + L = (m*D(x(t), t)**2/2 + m*D(y(t), t)**2/2 - + k*D(x(t), t)*D(y(t), t, t) + k*D(y(t), t)*D(x(t), t, t)) + assert euler(L, [x(t), y(t)]) == [Eq(2*k*D(y(t), t, t, t) - + m*D(x(t), t, t), 0), + Eq(-2*k*D(x(t), t, t, t) - + m*D(y(t), t, t), 0)] + + w = Symbol('w') + L = D(x(t, w), t, w)**2/2 + assert euler(L) == [Eq(D(x(t, w), t, t, w, w), 0)] + +def test_issue_18653(): + x, y, z = symbols("x y z") + f, g, h = symbols("f g h", cls=Function, args=(x, y)) + f, g, h = f(), g(), h() + expr2 = f.diff(x)*h.diff(z) + assert euler(expr2, (f,), (x, y)) == [] diff --git a/.venv/lib/python3.13/site-packages/sympy/calculus/tests/test_finite_diff.py b/.venv/lib/python3.13/site-packages/sympy/calculus/tests/test_finite_diff.py new file mode 100644 index 0000000000000000000000000000000000000000..e9ecfbdd61b15f516c54bd6d716ba1f264ee2ca0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/calculus/tests/test_finite_diff.py @@ -0,0 +1,164 @@ +from itertools import product + +from sympy.core.function import (Function, diff) +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.calculus.finite_diff import ( + apply_finite_diff, differentiate_finite, finite_diff_weights, + _as_finite_diff +) +from sympy.testing.pytest import raises, warns_deprecated_sympy + + +def test_apply_finite_diff(): + x, h = symbols('x h') + f = Function('f') + assert (apply_finite_diff(1, [x-h, x+h], [f(x-h), f(x+h)], x) - + (f(x+h)-f(x-h))/(2*h)).simplify() == 0 + + assert (apply_finite_diff(1, [5, 6, 7], [f(5), f(6), f(7)], 5) - + (Rational(-3, 2)*f(5) + 2*f(6) - S.Half*f(7))).simplify() == 0 + raises(ValueError, lambda: apply_finite_diff(1, [x, h], [f(x)])) + + +def test_finite_diff_weights(): + + d = finite_diff_weights(1, [5, 6, 7], 5) + assert d[1][2] == [Rational(-3, 2), 2, Rational(-1, 2)] + + # Table 1, p. 702 in doi:10.1090/S0025-5718-1988-0935077-0 + # -------------------------------------------------------- + xl = [0, 1, -1, 2, -2, 3, -3, 4, -4] + + # d holds all coefficients + d = finite_diff_weights(4, xl, S.Zero) + + # Zeroeth derivative + for i in range(5): + assert d[0][i] == [S.One] + [S.Zero]*8 + + # First derivative + assert d[1][0] == [S.Zero]*9 + assert d[1][2] == [S.Zero, S.Half, Rational(-1, 2)] + [S.Zero]*6 + assert d[1][4] == [S.Zero, Rational(2, 3), Rational(-2, 3), Rational(-1, 12), Rational(1, 12)] + [S.Zero]*4 + assert d[1][6] == [S.Zero, Rational(3, 4), Rational(-3, 4), Rational(-3, 20), Rational(3, 20), + Rational(1, 60), Rational(-1, 60)] + [S.Zero]*2 + assert d[1][8] == [S.Zero, Rational(4, 5), Rational(-4, 5), Rational(-1, 5), Rational(1, 5), + Rational(4, 105), Rational(-4, 105), Rational(-1, 280), Rational(1, 280)] + + # Second derivative + for i in range(2): + assert d[2][i] == [S.Zero]*9 + assert d[2][2] == [-S(2), S.One, S.One] + [S.Zero]*6 + assert d[2][4] == [Rational(-5, 2), Rational(4, 3), Rational(4, 3), Rational(-1, 12), Rational(-1, 12)] + [S.Zero]*4 + assert d[2][6] == [Rational(-49, 18), Rational(3, 2), Rational(3, 2), Rational(-3, 20), Rational(-3, 20), + Rational(1, 90), Rational(1, 90)] + [S.Zero]*2 + assert d[2][8] == [Rational(-205, 72), Rational(8, 5), Rational(8, 5), Rational(-1, 5), Rational(-1, 5), + Rational(8, 315), Rational(8, 315), Rational(-1, 560), Rational(-1, 560)] + + # Third derivative + for i in range(3): + assert d[3][i] == [S.Zero]*9 + assert d[3][4] == [S.Zero, -S.One, S.One, S.Half, Rational(-1, 2)] + [S.Zero]*4 + assert d[3][6] == [S.Zero, Rational(-13, 8), Rational(13, 8), S.One, -S.One, + Rational(-1, 8), Rational(1, 8)] + [S.Zero]*2 + assert d[3][8] == [S.Zero, Rational(-61, 30), Rational(61, 30), Rational(169, 120), Rational(-169, 120), + Rational(-3, 10), Rational(3, 10), Rational(7, 240), Rational(-7, 240)] + + # Fourth derivative + for i in range(4): + assert d[4][i] == [S.Zero]*9 + assert d[4][4] == [S(6), -S(4), -S(4), S.One, S.One] + [S.Zero]*4 + assert d[4][6] == [Rational(28, 3), Rational(-13, 2), Rational(-13, 2), S(2), S(2), + Rational(-1, 6), Rational(-1, 6)] + [S.Zero]*2 + assert d[4][8] == [Rational(91, 8), Rational(-122, 15), Rational(-122, 15), Rational(169, 60), Rational(169, 60), + Rational(-2, 5), Rational(-2, 5), Rational(7, 240), Rational(7, 240)] + + # Table 2, p. 703 in doi:10.1090/S0025-5718-1988-0935077-0 + # -------------------------------------------------------- + xl = [[j/S(2) for j in list(range(-i*2+1, 0, 2))+list(range(1, i*2+1, 2))] + for i in range(1, 5)] + + # d holds all coefficients + d = [finite_diff_weights({0: 1, 1: 2, 2: 4, 3: 4}[i], xl[i], 0) for + i in range(4)] + + # Zeroth derivative + assert d[0][0][1] == [S.Half, S.Half] + assert d[1][0][3] == [Rational(-1, 16), Rational(9, 16), Rational(9, 16), Rational(-1, 16)] + assert d[2][0][5] == [Rational(3, 256), Rational(-25, 256), Rational(75, 128), Rational(75, 128), + Rational(-25, 256), Rational(3, 256)] + assert d[3][0][7] == [Rational(-5, 2048), Rational(49, 2048), Rational(-245, 2048), Rational(1225, 2048), + Rational(1225, 2048), Rational(-245, 2048), Rational(49, 2048), Rational(-5, 2048)] + + # First derivative + assert d[0][1][1] == [-S.One, S.One] + assert d[1][1][3] == [Rational(1, 24), Rational(-9, 8), Rational(9, 8), Rational(-1, 24)] + assert d[2][1][5] == [Rational(-3, 640), Rational(25, 384), Rational(-75, 64), + Rational(75, 64), Rational(-25, 384), Rational(3, 640)] + assert d[3][1][7] == [Rational(5, 7168), Rational(-49, 5120), + Rational(245, 3072), Rational(-1225, 1024), + Rational(1225, 1024), Rational(-245, 3072), + Rational(49, 5120), Rational(-5, 7168)] + + # Reasonably the rest of the table is also correct... (testing of that + # deemed excessive at the moment) + raises(ValueError, lambda: finite_diff_weights(-1, [1, 2])) + raises(ValueError, lambda: finite_diff_weights(1.2, [1, 2])) + x = symbols('x') + raises(ValueError, lambda: finite_diff_weights(x, [1, 2])) + + +def test_as_finite_diff(): + x = symbols('x') + f = Function('f') + dx = Function('dx') + + _as_finite_diff(f(x).diff(x), [x-2, x-1, x, x+1, x+2]) + + # Use of undefined functions in ``points`` + df_true = -f(x+dx(x)/2-dx(x+dx(x)/2)/2) / dx(x+dx(x)/2) \ + + f(x+dx(x)/2+dx(x+dx(x)/2)/2) / dx(x+dx(x)/2) + df_test = diff(f(x), x).as_finite_difference(points=dx(x), x0=x+dx(x)/2) + assert (df_test - df_true).simplify() == 0 + + +def test_differentiate_finite(): + x, y, h = symbols('x y h') + f = Function('f') + with warns_deprecated_sympy(): + res0 = differentiate_finite(f(x, y) + exp(42), x, y, evaluate=True) + xm, xp, ym, yp = [v + sign*S.Half for v, sign in product([x, y], [-1, 1])] + ref0 = f(xm, ym) + f(xp, yp) - f(xm, yp) - f(xp, ym) + assert (res0 - ref0).simplify() == 0 + + g = Function('g') + with warns_deprecated_sympy(): + res1 = differentiate_finite(f(x)*g(x) + 42, x, evaluate=True) + ref1 = (-f(x - S.Half) + f(x + S.Half))*g(x) + \ + (-g(x - S.Half) + g(x + S.Half))*f(x) + assert (res1 - ref1).simplify() == 0 + + res2 = differentiate_finite(f(x) + x**3 + 42, x, points=[x-1, x+1]) + ref2 = (f(x + 1) + (x + 1)**3 - f(x - 1) - (x - 1)**3)/2 + assert (res2 - ref2).simplify() == 0 + raises(TypeError, lambda: differentiate_finite(f(x)*g(x), x, + pints=[x-1, x+1])) + + res3 = differentiate_finite(f(x)*g(x).diff(x), x) + ref3 = (-g(x) + g(x + 1))*f(x + S.Half) - (g(x) - g(x - 1))*f(x - S.Half) + assert res3 == ref3 + + res4 = differentiate_finite(f(x)*g(x).diff(x).diff(x), x) + ref4 = -((g(x - Rational(3, 2)) - 2*g(x - S.Half) + g(x + S.Half))*f(x - S.Half)) \ + + (g(x - S.Half) - 2*g(x + S.Half) + g(x + Rational(3, 2)))*f(x + S.Half) + assert res4 == ref4 + + res5_expr = f(x).diff(x)*g(x).diff(x) + res5 = differentiate_finite(res5_expr, points=[x-h, x, x+h]) + ref5 = (-2*f(x)/h + f(-h + x)/(2*h) + 3*f(h + x)/(2*h))*(-2*g(x)/h + g(-h + x)/(2*h) \ + + 3*g(h + x)/(2*h))/(2*h) - (2*f(x)/h - 3*f(-h + x)/(2*h) - \ + f(h + x)/(2*h))*(2*g(x)/h - 3*g(-h + x)/(2*h) - g(h + x)/(2*h))/(2*h) + assert res5 == ref5 diff --git a/.venv/lib/python3.13/site-packages/sympy/calculus/tests/test_singularities.py b/.venv/lib/python3.13/site-packages/sympy/calculus/tests/test_singularities.py new file mode 100644 index 0000000000000000000000000000000000000000..19a042332326658021ce12a38f4e058f55903869 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/calculus/tests/test_singularities.py @@ -0,0 +1,122 @@ +from sympy.core.numbers import (I, Rational, pi, oo) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol, Dummy +from sympy.core.function import Lambda +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.trigonometric import sec, csc +from sympy.functions.elementary.hyperbolic import (coth, sech, + atanh, asech, acoth, acsch) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.calculus.singularities import ( + singularities, + is_increasing, + is_strictly_increasing, + is_decreasing, + is_strictly_decreasing, + is_monotonic +) +from sympy.sets import Interval, FiniteSet, Union, ImageSet +from sympy.testing.pytest import raises +from sympy.abc import x, y + + +def test_singularities(): + x = Symbol('x') + assert singularities(x**2, x) == S.EmptySet + assert singularities(x/(x**2 + 3*x + 2), x) == FiniteSet(-2, -1) + assert singularities(1/(x**2 + 1), x) == FiniteSet(I, -I) + assert singularities(x/(x**3 + 1), x) == \ + FiniteSet(-1, (1 - sqrt(3) * I) / 2, (1 + sqrt(3) * I) / 2) + assert singularities(1/(y**2 + 2*I*y + 1), y) == \ + FiniteSet(-I + sqrt(2)*I, -I - sqrt(2)*I) + _n = Dummy('n') + assert singularities(sech(x), x).dummy_eq(Union( + ImageSet(Lambda(_n, 2*_n*I*pi + I*pi/2), S.Integers), + ImageSet(Lambda(_n, 2*_n*I*pi + 3*I*pi/2), S.Integers))) + assert singularities(coth(x), x).dummy_eq(Union( + ImageSet(Lambda(_n, 2*_n*I*pi + I*pi), S.Integers), + ImageSet(Lambda(_n, 2*_n*I*pi), S.Integers))) + assert singularities(atanh(x), x) == FiniteSet(-1, 1) + assert singularities(acoth(x), x) == FiniteSet(-1, 1) + assert singularities(asech(x), x) == FiniteSet(0) + assert singularities(acsch(x), x) == FiniteSet(0) + + x = Symbol('x', real=True) + assert singularities(1/(x**2 + 1), x) == S.EmptySet + assert singularities(exp(1/x), x, S.Reals) == FiniteSet(0) + assert singularities(exp(1/x), x, Interval(1, 2)) == S.EmptySet + assert singularities(log((x - 2)**2), x, Interval(1, 3)) == FiniteSet(2) + raises(NotImplementedError, lambda: singularities(x**-oo, x)) + assert singularities(sec(x), x, Interval(0, 3*pi)) == FiniteSet( + pi/2, 3*pi/2, 5*pi/2) + assert singularities(csc(x), x, Interval(0, 3*pi)) == FiniteSet( + 0, pi, 2*pi, 3*pi) + + +def test_is_increasing(): + """Test whether is_increasing returns correct value.""" + a = Symbol('a', negative=True) + + assert is_increasing(x**3 - 3*x**2 + 4*x, S.Reals) + assert is_increasing(-x**2, Interval(-oo, 0)) + assert not is_increasing(-x**2, Interval(0, oo)) + assert not is_increasing(4*x**3 - 6*x**2 - 72*x + 30, Interval(-2, 3)) + assert is_increasing(x**2 + y, Interval(1, oo), x) + assert is_increasing(-x**2*a, Interval(1, oo), x) + assert is_increasing(1) + + assert is_increasing(4*x**3 - 6*x**2 - 72*x + 30, Interval(-2, 3)) is False + + +def test_is_strictly_increasing(): + """Test whether is_strictly_increasing returns correct value.""" + assert is_strictly_increasing( + 4*x**3 - 6*x**2 - 72*x + 30, Interval.Ropen(-oo, -2)) + assert is_strictly_increasing( + 4*x**3 - 6*x**2 - 72*x + 30, Interval.Lopen(3, oo)) + assert not is_strictly_increasing( + 4*x**3 - 6*x**2 - 72*x + 30, Interval.open(-2, 3)) + assert not is_strictly_increasing(-x**2, Interval(0, oo)) + assert not is_strictly_decreasing(1) + + assert is_strictly_increasing(4*x**3 - 6*x**2 - 72*x + 30, Interval.open(-2, 3)) is False + + +def test_is_decreasing(): + """Test whether is_decreasing returns correct value.""" + b = Symbol('b', positive=True) + + assert is_decreasing(1/(x**2 - 3*x), Interval.open(Rational(3,2), 3)) + assert is_decreasing(1/(x**2 - 3*x), Interval.open(1.5, 3)) + assert is_decreasing(1/(x**2 - 3*x), Interval.Lopen(3, oo)) + assert not is_decreasing(1/(x**2 - 3*x), Interval.Ropen(-oo, Rational(3, 2))) + assert not is_decreasing(-x**2, Interval(-oo, 0)) + assert not is_decreasing(-x**2*b, Interval(-oo, 0), x) + + +def test_is_strictly_decreasing(): + """Test whether is_strictly_decreasing returns correct value.""" + assert is_strictly_decreasing(1/(x**2 - 3*x), Interval.Lopen(3, oo)) + assert not is_strictly_decreasing( + 1/(x**2 - 3*x), Interval.Ropen(-oo, Rational(3, 2))) + assert not is_strictly_decreasing(-x**2, Interval(-oo, 0)) + assert not is_strictly_decreasing(1) + assert is_strictly_decreasing(1/(x**2 - 3*x), Interval.open(Rational(3,2), 3)) + assert is_strictly_decreasing(1/(x**2 - 3*x), Interval.open(1.5, 3)) + + +def test_is_monotonic(): + """Test whether is_monotonic returns correct value.""" + assert is_monotonic(1/(x**2 - 3*x), Interval.open(Rational(3,2), 3)) + assert is_monotonic(1/(x**2 - 3*x), Interval.open(1.5, 3)) + assert is_monotonic(1/(x**2 - 3*x), Interval.Lopen(3, oo)) + assert is_monotonic(x**3 - 3*x**2 + 4*x, S.Reals) + assert not is_monotonic(-x**2, S.Reals) + assert is_monotonic(x**2 + y + 1, Interval(1, 2), x) + raises(NotImplementedError, lambda: is_monotonic(x**2 + y + 1)) + + +def test_issue_23401(): + x = Symbol('x') + expr = (x + 1)/(-1.0e-3*x**2 + 0.1*x + 0.1) + assert is_increasing(expr, Interval(1,2), x) diff --git a/.venv/lib/python3.13/site-packages/sympy/calculus/tests/test_util.py b/.venv/lib/python3.13/site-packages/sympy/calculus/tests/test_util.py new file mode 100644 index 0000000000000000000000000000000000000000..c18b7a79fd54fdb2638cc746d43ab26753fc72a9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/calculus/tests/test_util.py @@ -0,0 +1,392 @@ +from sympy.core.function import Lambda +from sympy.core.numbers import (E, I, Rational, oo, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol) +from sympy.functions.elementary.complexes import (Abs, re) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.integers import frac +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import ( + cos, cot, csc, sec, sin, tan, asin, acos, atan, acot, asec, acsc) +from sympy.functions.elementary.hyperbolic import (sinh, cosh, tanh, coth, + sech, csch, asinh, acosh, atanh, acoth, asech, acsch) +from sympy.functions.special.gamma_functions import gamma +from sympy.functions.special.error_functions import expint +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.simplify.simplify import simplify +from sympy.calculus.util import (function_range, continuous_domain, not_empty_in, + periodicity, lcim, is_convex, + stationary_points, minimum, maximum) +from sympy.sets.sets import (Interval, FiniteSet, Complement, Union) +from sympy.sets.fancysets import ImageSet +from sympy.sets.conditionset import ConditionSet +from sympy.testing.pytest import XFAIL, raises, _both_exp_pow, slow +from sympy.abc import x, y + +a = Symbol('a', real=True) + +def test_function_range(): + assert function_range(sin(x), x, Interval(-pi/2, pi/2) + ) == Interval(-1, 1) + assert function_range(sin(x), x, Interval(0, pi) + ) == Interval(0, 1) + assert function_range(tan(x), x, Interval(0, pi) + ) == Interval(-oo, oo) + assert function_range(tan(x), x, Interval(pi/2, pi) + ) == Interval(-oo, 0) + assert function_range((x + 3)/(x - 2), x, Interval(-5, 5) + ) == Union(Interval(-oo, Rational(2, 7)), Interval(Rational(8, 3), oo)) + assert function_range(1/(x**2), x, Interval(-1, 1) + ) == Interval(1, oo) + assert function_range(exp(x), x, Interval(-1, 1) + ) == Interval(exp(-1), exp(1)) + assert function_range(log(x) - x, x, S.Reals + ) == Interval(-oo, -1) + assert function_range(sqrt(3*x - 1), x, Interval(0, 2) + ) == Interval(0, sqrt(5)) + assert function_range(x*(x - 1) - (x**2 - x), x, S.Reals + ) == FiniteSet(0) + assert function_range(x*(x - 1) - (x**2 - x) + y, x, S.Reals + ) == FiniteSet(y) + assert function_range(sin(x), x, Union(Interval(-5, -3), FiniteSet(4)) + ) == Union(Interval(-sin(3), 1), FiniteSet(sin(4))) + assert function_range(cos(x), x, Interval(-oo, -4) + ) == Interval(-1, 1) + assert function_range(cos(x), x, S.EmptySet) == S.EmptySet + assert function_range(x/sqrt(x**2+1), x, S.Reals) == Interval.open(-1,1) + raises(NotImplementedError, lambda : function_range( + exp(x)*(sin(x) - cos(x))/2 - x, x, S.Reals)) + raises(NotImplementedError, lambda : function_range( + sin(x) + x, x, S.Reals)) # issue 13273 + raises(NotImplementedError, lambda : function_range( + log(x), x, S.Integers)) + raises(NotImplementedError, lambda : function_range( + sin(x)/2, x, S.Naturals)) + + +@slow +def test_function_range1(): + assert function_range(tan(x)**2 + tan(3*x)**2 + 1, x, S.Reals) == Interval(1,oo) + + +def test_continuous_domain(): + assert continuous_domain(sin(x), x, Interval(0, 2*pi)) == Interval(0, 2*pi) + assert continuous_domain(tan(x), x, Interval(0, 2*pi)) == \ + Union(Interval(0, pi/2, False, True), Interval(pi/2, pi*Rational(3, 2), True, True), + Interval(pi*Rational(3, 2), 2*pi, True, False)) + assert continuous_domain(cot(x), x, Interval(0, 2*pi)) == Union( + Interval.open(0, pi), Interval.open(pi, 2*pi)) + assert continuous_domain((x - 1)/((x - 1)**2), x, S.Reals) == \ + Union(Interval(-oo, 1, True, True), Interval(1, oo, True, True)) + assert continuous_domain(log(x) + log(4*x - 1), x, S.Reals) == \ + Interval(Rational(1, 4), oo, True, True) + assert continuous_domain(1/sqrt(x - 3), x, S.Reals) == Interval(3, oo, True, True) + assert continuous_domain(1/x - 2, x, S.Reals) == \ + Union(Interval.open(-oo, 0), Interval.open(0, oo)) + assert continuous_domain(1/(x**2 - 4) + 2, x, S.Reals) == \ + Union(Interval.open(-oo, -2), Interval.open(-2, 2), Interval.open(2, oo)) + assert continuous_domain((x+1)**pi, x, S.Reals) == Interval(-1, oo) + assert continuous_domain((x+1)**(pi/2), x, S.Reals) == Interval(-1, oo) + assert continuous_domain(x**x, x, S.Reals) == Interval(0, oo) + assert continuous_domain((x+1)**log(x**2), x, S.Reals) == Union( + Interval.Ropen(-1, 0), Interval.open(0, oo)) + domain = continuous_domain(log(tan(x)**2 + 1), x, S.Reals) + assert not domain.contains(3*pi/2) + assert domain.contains(5) + d = Symbol('d', even=True, zero=False) + assert continuous_domain(x**(1/d), x, S.Reals) == Interval(0, oo) + n = Dummy('n') + assert continuous_domain(1/sin(x), x, S.Reals).dummy_eq(Complement( + S.Reals, Union(ImageSet(Lambda(n, 2*n*pi + pi), S.Integers), + ImageSet(Lambda(n, 2*n*pi), S.Integers)))) + assert continuous_domain(sin(x) + cos(x), x, S.Reals) == S.Reals + assert continuous_domain(asin(x), x, S.Reals) == Interval(-1, 1) # issue #21786 + assert continuous_domain(1/acos(log(x)), x, S.Reals) == Interval.Ropen(exp(-1), E) + assert continuous_domain(sinh(x)+cosh(x), x, S.Reals) == S.Reals + assert continuous_domain(tanh(x)+sech(x), x, S.Reals) == S.Reals + assert continuous_domain(atan(x)+asinh(x), x, S.Reals) == S.Reals + assert continuous_domain(acosh(x), x, S.Reals) == Interval(1, oo) + assert continuous_domain(atanh(x), x, S.Reals) == Interval.open(-1, 1) + assert continuous_domain(atanh(x)+acosh(x), x, S.Reals) == S.EmptySet + assert continuous_domain(asech(x), x, S.Reals) == Interval.Lopen(0, 1) + assert continuous_domain(acoth(x), x, S.Reals) == Union( + Interval.open(-oo, -1), Interval.open(1, oo)) + assert continuous_domain(asec(x), x, S.Reals) == Union( + Interval(-oo, -1), Interval(1, oo)) + assert continuous_domain(acsc(x), x, S.Reals) == Union( + Interval(-oo, -1), Interval(1, oo)) + for f in (coth, acsch, csch): + assert continuous_domain(f(x), x, S.Reals) == Union( + Interval.open(-oo, 0), Interval.open(0, oo)) + assert continuous_domain(acot(x), x, S.Reals).contains(0) == False + assert continuous_domain(1/(exp(x) - x), x, S.Reals) == Complement( + S.Reals, ConditionSet(x, Eq(-x + exp(x), 0), S.Reals)) + assert continuous_domain(frac(x**2), x, Interval(-2,-1)) == Union( + Interval.open(-2, -sqrt(3)), Interval.open(-sqrt(2), -1), + Interval.open(-sqrt(3), -sqrt(2))) + assert continuous_domain(frac(x), x, S.Reals) == Complement( + S.Reals, S.Integers) + raises(NotImplementedError, lambda : continuous_domain( + 1/(x**2+1), x, S.Complexes)) + raises(NotImplementedError, lambda : continuous_domain( + gamma(x), x, Interval(-5,0))) + assert continuous_domain(x + gamma(pi), x, S.Reals) == S.Reals + + +@XFAIL +def test_continuous_domain_acot(): + acot_cont = Piecewise((pi+acot(x), x<0), (acot(x), True)) + assert continuous_domain(acot_cont, x, S.Reals) == S.Reals + +@XFAIL +def test_continuous_domain_gamma(): + assert continuous_domain(gamma(x), x, S.Reals).contains(-1) == False + +@XFAIL +def test_continuous_domain_neg_power(): + assert continuous_domain((x-2)**(1-x), x, S.Reals) == Interval.open(2, oo) + + +def test_not_empty_in(): + assert not_empty_in(FiniteSet(x, 2*x).intersect(Interval(1, 2, True, False)), x) == \ + Interval(S.Half, 2, True, False) + assert not_empty_in(FiniteSet(x, x**2).intersect(Interval(1, 2)), x) == \ + Union(Interval(-sqrt(2), -1), Interval(1, 2)) + assert not_empty_in(FiniteSet(x**2 + x, x).intersect(Interval(2, 4)), x) == \ + Union(Interval(-sqrt(17)/2 - S.Half, -2), + Interval(1, Rational(-1, 2) + sqrt(17)/2), Interval(2, 4)) + assert not_empty_in(FiniteSet(x/(x - 1)).intersect(S.Reals), x) == \ + Complement(S.Reals, FiniteSet(1)) + assert not_empty_in(FiniteSet(a/(a - 1)).intersect(S.Reals), a) == \ + Complement(S.Reals, FiniteSet(1)) + assert not_empty_in(FiniteSet((x**2 - 3*x + 2)/(x - 1)).intersect(S.Reals), x) == \ + Complement(S.Reals, FiniteSet(1)) + assert not_empty_in(FiniteSet(3, 4, x/(x - 1)).intersect(Interval(2, 3)), x) == \ + Interval(-oo, oo) + assert not_empty_in(FiniteSet(4, x/(x - 1)).intersect(Interval(2, 3)), x) == \ + Interval(S(3)/2, 2) + assert not_empty_in(FiniteSet(x/(x**2 - 1)).intersect(S.Reals), x) == \ + Complement(S.Reals, FiniteSet(-1, 1)) + assert not_empty_in(FiniteSet(x, x**2).intersect(Union(Interval(1, 3, True, True), + Interval(4, 5))), x) == \ + Union(Interval(-sqrt(5), -2), Interval(-sqrt(3), -1, True, True), + Interval(1, 3, True, True), Interval(4, 5)) + assert not_empty_in(FiniteSet(1).intersect(Interval(3, 4)), x) == S.EmptySet + assert not_empty_in(FiniteSet(x**2/(x + 2)).intersect(Interval(1, oo)), x) == \ + Union(Interval(-2, -1, True, False), Interval(2, oo)) + raises(ValueError, lambda: not_empty_in(x)) + raises(ValueError, lambda: not_empty_in(Interval(0, 1), x)) + raises(NotImplementedError, + lambda: not_empty_in(FiniteSet(x).intersect(S.Reals), x, a)) + + +@_both_exp_pow +def test_periodicity(): + assert periodicity(sin(2*x), x) == pi + assert periodicity((-2)*tan(4*x), x) == pi/4 + assert periodicity(sin(x)**2, x) == 2*pi + assert periodicity(3**tan(3*x), x) == pi/3 + assert periodicity(tan(x)*cos(x), x) == 2*pi + assert periodicity(sin(x)**(tan(x)), x) == 2*pi + assert periodicity(tan(x)*sec(x), x) == 2*pi + assert periodicity(sin(2*x)*cos(2*x) - y, x) == pi/2 + assert periodicity(tan(x) + cot(x), x) == pi + assert periodicity(sin(x) - cos(2*x), x) == 2*pi + assert periodicity(sin(x) - 1, x) == 2*pi + assert periodicity(sin(4*x) + sin(x)*cos(x), x) == pi + assert periodicity(exp(sin(x)), x) == 2*pi + assert periodicity(log(cot(2*x)) - sin(cos(2*x)), x) == pi + assert periodicity(sin(2*x)*exp(tan(x) - csc(2*x)), x) == pi + assert periodicity(cos(sec(x) - csc(2*x)), x) == 2*pi + assert periodicity(tan(sin(2*x)), x) == pi + assert periodicity(2*tan(x)**2, x) == pi + assert periodicity(sin(x%4), x) == 4 + assert periodicity(sin(x)%4, x) == 2*pi + assert periodicity(tan((3*x-2)%4), x) == Rational(4, 3) + assert periodicity((sqrt(2)*(x+1)+x) % 3, x) == 3 / (sqrt(2)+1) + assert periodicity((x**2+1) % x, x) is None + assert periodicity(sin(re(x)), x) == 2*pi + assert periodicity(sin(x)**2 + cos(x)**2, x) is S.Zero + assert periodicity(tan(x), y) is S.Zero + assert periodicity(sin(x) + I*cos(x), x) == 2*pi + assert periodicity(x - sin(2*y), y) == pi + + assert periodicity(exp(x), x) is None + assert periodicity(exp(I*x), x) == 2*pi + assert periodicity(exp(I*a), a) == 2*pi + assert periodicity(exp(a), a) is None + assert periodicity(exp(log(sin(a) + I*cos(2*a)), evaluate=False), a) == 2*pi + assert periodicity(exp(log(sin(2*a) + I*cos(a)), evaluate=False), a) == 2*pi + assert periodicity(exp(sin(a)), a) == 2*pi + assert periodicity(exp(2*I*a), a) == pi + assert periodicity(exp(a + I*sin(a)), a) is None + assert periodicity(exp(cos(a/2) + sin(a)), a) == 4*pi + assert periodicity(log(x), x) is None + assert periodicity(exp(x)**sin(x), x) is None + assert periodicity(sin(x)**y, y) is None + + assert periodicity(Abs(sin(Abs(sin(x)))), x) == pi + assert all(periodicity(Abs(f(x)), x) == pi for f in ( + cos, sin, sec, csc, tan, cot)) + assert periodicity(Abs(sin(tan(x))), x) == pi + assert periodicity(Abs(sin(sin(x) + tan(x))), x) == 2*pi + assert periodicity(sin(x) > S.Half, x) == 2*pi + + assert periodicity(x > 2, x) is None + assert periodicity(x**3 - x**2 + 1, x) is None + assert periodicity(Abs(x), x) is None + assert periodicity(Abs(x**2 - 1), x) is None + + assert periodicity((x**2 + 4)%2, x) is None + assert periodicity((E**x)%3, x) is None + + assert periodicity(sin(expint(1, x))/expint(1, x), x) is None + # returning `None` for any Piecewise + p = Piecewise((0, x < -1), (x**2, x <= 1), (log(x), True)) + assert periodicity(p, x) is None + + m = MatrixSymbol('m', 3, 3) + raises(NotImplementedError, lambda: periodicity(sin(m), m)) + raises(NotImplementedError, lambda: periodicity(sin(m[0, 0]), m)) + raises(NotImplementedError, lambda: periodicity(sin(m), m[0, 0])) + raises(NotImplementedError, lambda: periodicity(sin(m[0, 0]), m[0, 0])) + + +def test_periodicity_check(): + assert periodicity(tan(x), x, check=True) == pi + assert periodicity(sin(x) + cos(x), x, check=True) == 2*pi + assert periodicity(sec(x), x) == 2*pi + assert periodicity(sin(x*y), x) == 2*pi/abs(y) + assert periodicity(Abs(sec(sec(x))), x) == pi + + +def test_lcim(): + assert lcim([S.Half, S(2), S(3)]) == 6 + assert lcim([pi/2, pi/4, pi]) == pi + assert lcim([2*pi, pi/2]) == 2*pi + assert lcim([S.One, 2*pi]) is None + assert lcim([S(2) + 2*E, E/3 + Rational(1, 3), S.One + E]) == S(2) + 2*E + + +def test_is_convex(): + assert is_convex(1/x, x, domain=Interval.open(0, oo)) == True + assert is_convex(1/x, x, domain=Interval(-oo, 0)) == False + assert is_convex(x**2, x, domain=Interval(0, oo)) == True + assert is_convex(1/x**3, x, domain=Interval.Lopen(0, oo)) == True + assert is_convex(-1/x**3, x, domain=Interval.Ropen(-oo, 0)) == True + assert is_convex(log(x) ,x) == False + assert is_convex(x**2+y**2, x, y) == True + assert is_convex(cos(x) + cos(y), x) == False + assert is_convex(8*x**2 - 2*y**2, x, y) == False + + +def test_stationary_points(): + assert stationary_points(sin(x), x, Interval(-pi/2, pi/2) + ) == {-pi/2, pi/2} + assert stationary_points(sin(x), x, Interval.Ropen(0, pi/4) + ) is S.EmptySet + assert stationary_points(tan(x), x, + ) is S.EmptySet + assert stationary_points(sin(x)*cos(x), x, Interval(0, pi) + ) == {pi/4, pi*Rational(3, 4)} + assert stationary_points(sec(x), x, Interval(0, pi) + ) == {0, pi} + assert stationary_points((x+3)*(x-2), x + ) == FiniteSet(Rational(-1, 2)) + assert stationary_points((x + 3)/(x - 2), x, Interval(-5, 5) + ) is S.EmptySet + assert stationary_points((x**2+3)/(x-2), x + ) == {2 - sqrt(7), 2 + sqrt(7)} + assert stationary_points((x**2+3)/(x-2), x, Interval(0, 5) + ) == {2 + sqrt(7)} + assert stationary_points(x**4 + x**3 - 5*x**2, x, S.Reals + ) == FiniteSet(-2, 0, Rational(5, 4)) + assert stationary_points(exp(x), x + ) is S.EmptySet + assert stationary_points(log(x) - x, x, S.Reals + ) == {1} + assert stationary_points(cos(x), x, Union(Interval(0, 5), Interval(-6, -3)) + ) == {0, -pi, pi} + assert stationary_points(y, x, S.Reals + ) == S.Reals + assert stationary_points(y, x, S.EmptySet) == S.EmptySet + + +def test_maximum(): + assert maximum(sin(x), x) is S.One + assert maximum(sin(x), x, Interval(0, 1)) == sin(1) + assert maximum(tan(x), x) is oo + assert maximum(tan(x), x, Interval(-pi/4, pi/4)) is S.One + assert maximum(sin(x)*cos(x), x, S.Reals) == S.Half + assert simplify(maximum(sin(x)*cos(x), x, Interval(pi*Rational(3, 8), pi*Rational(5, 8))) + ) == sqrt(2)/4 + assert maximum((x+3)*(x-2), x) is oo + assert maximum((x+3)*(x-2), x, Interval(-5, 0)) == S(14) + assert maximum((x+3)/(x-2), x, Interval(-5, 0)) == Rational(2, 7) + assert simplify(maximum(-x**4-x**3+x**2+10, x) + ) == 41*sqrt(41)/512 + Rational(5419, 512) + assert maximum(exp(x), x, Interval(-oo, 2)) == exp(2) + assert maximum(log(x) - x, x, S.Reals) is S.NegativeOne + assert maximum(cos(x), x, Union(Interval(0, 5), Interval(-6, -3)) + ) is S.One + assert maximum(cos(x)-sin(x), x, S.Reals) == sqrt(2) + assert maximum(y, x, S.Reals) == y + assert maximum(abs(a**3 + a), a, Interval(0, 2)) == 10 + assert maximum(abs(60*a**3 + 24*a), a, Interval(0, 2)) == 528 + assert maximum(abs(12*a*(5*a**2 + 2)), a, Interval(0, 2)) == 528 + assert maximum(x/sqrt(x**2+1), x, S.Reals) == 1 + + raises(ValueError, lambda : maximum(sin(x), x, S.EmptySet)) + raises(ValueError, lambda : maximum(log(cos(x)), x, S.EmptySet)) + raises(ValueError, lambda : maximum(1/(x**2 + y**2 + 1), x, S.EmptySet)) + raises(ValueError, lambda : maximum(sin(x), sin(x))) + raises(ValueError, lambda : maximum(sin(x), x*y, S.EmptySet)) + raises(ValueError, lambda : maximum(sin(x), S.One)) + + +def test_minimum(): + assert minimum(sin(x), x) is S.NegativeOne + assert minimum(sin(x), x, Interval(1, 4)) == sin(4) + assert minimum(tan(x), x) is -oo + assert minimum(tan(x), x, Interval(-pi/4, pi/4)) is S.NegativeOne + assert minimum(sin(x)*cos(x), x, S.Reals) == Rational(-1, 2) + assert simplify(minimum(sin(x)*cos(x), x, Interval(pi*Rational(3, 8), pi*Rational(5, 8))) + ) == -sqrt(2)/4 + assert minimum((x+3)*(x-2), x) == Rational(-25, 4) + assert minimum((x+3)/(x-2), x, Interval(-5, 0)) == Rational(-3, 2) + assert minimum(x**4-x**3+x**2+10, x) == S(10) + assert minimum(exp(x), x, Interval(-2, oo)) == exp(-2) + assert minimum(log(x) - x, x, S.Reals) is -oo + assert minimum(cos(x), x, Union(Interval(0, 5), Interval(-6, -3)) + ) is S.NegativeOne + assert minimum(cos(x)-sin(x), x, S.Reals) == -sqrt(2) + assert minimum(y, x, S.Reals) == y + assert minimum(x/sqrt(x**2+1), x, S.Reals) == -1 + + raises(ValueError, lambda : minimum(sin(x), x, S.EmptySet)) + raises(ValueError, lambda : minimum(log(cos(x)), x, S.EmptySet)) + raises(ValueError, lambda : minimum(1/(x**2 + y**2 + 1), x, S.EmptySet)) + raises(ValueError, lambda : minimum(sin(x), sin(x))) + raises(ValueError, lambda : minimum(sin(x), x*y, S.EmptySet)) + raises(ValueError, lambda : minimum(sin(x), S.One)) + + +def test_issue_19869(): + assert (maximum(sqrt(3)*(x - 1)/(3*sqrt(x**2 + 1)), x) + ) == sqrt(3)/3 + + +def test_issue_16469(): + f = abs(a) + assert function_range(f, a, S.Reals) == Interval(0, oo, False, True) + + +@_both_exp_pow +def test_issue_18747(): + assert periodicity(exp(pi*I*(x/4 + S.Half/2)), x) == 8 + + +def test_issue_25942(): + assert (acos(x) > pi/3).as_set() == Interval.Ropen(-1, S(1)/2) diff --git a/.venv/lib/python3.13/site-packages/sympy/calculus/util.py b/.venv/lib/python3.13/site-packages/sympy/calculus/util.py new file mode 100644 index 0000000000000000000000000000000000000000..dac316358873de2418dd8ab56445823f07a37b1b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/calculus/util.py @@ -0,0 +1,895 @@ +from .accumulationbounds import AccumBounds, AccumulationBounds # noqa: F401 +from .singularities import singularities +from sympy.core import Pow, S +from sympy.core.function import diff, expand_mul, Function +from sympy.core.kind import NumberKind +from sympy.core.mod import Mod +from sympy.core.numbers import equal_valued +from sympy.core.relational import Relational +from sympy.core.symbol import Symbol, Dummy +from sympy.core.sympify import _sympify +from sympy.functions.elementary.complexes import Abs, im, re +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.integers import frac +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import ( + TrigonometricFunction, sin, cos, tan, cot, csc, sec, + asin, acos, acot, atan, asec, acsc) +from sympy.functions.elementary.hyperbolic import (sinh, cosh, tanh, coth, + sech, csch, asinh, acosh, atanh, acoth, asech, acsch) +from sympy.polys.polytools import degree, lcm_list +from sympy.sets.sets import (Interval, Intersection, FiniteSet, Union, + Complement) +from sympy.sets.fancysets import ImageSet +from sympy.sets.conditionset import ConditionSet +from sympy.utilities import filldedent +from sympy.utilities.iterables import iterable +from sympy.matrices.dense import hessian + + +def continuous_domain(f, symbol, domain): + """ + Returns the domain on which the function expression f is continuous. + + This function is limited by the ability to determine the various + singularities and discontinuities of the given function. + The result is either given as a union of intervals or constructed using + other set operations. + + Parameters + ========== + + f : :py:class:`~.Expr` + The concerned function. + symbol : :py:class:`~.Symbol` + The variable for which the intervals are to be determined. + domain : :py:class:`~.Interval` + The domain over which the continuity of the symbol has to be checked. + + Examples + ======== + + >>> from sympy import Interval, Symbol, S, tan, log, pi, sqrt + >>> from sympy.calculus.util import continuous_domain + >>> x = Symbol('x') + >>> continuous_domain(1/x, x, S.Reals) + Union(Interval.open(-oo, 0), Interval.open(0, oo)) + >>> continuous_domain(tan(x), x, Interval(0, pi)) + Union(Interval.Ropen(0, pi/2), Interval.Lopen(pi/2, pi)) + >>> continuous_domain(sqrt(x - 2), x, Interval(-5, 5)) + Interval(2, 5) + >>> continuous_domain(log(2*x - 1), x, S.Reals) + Interval.open(1/2, oo) + + Returns + ======= + + :py:class:`~.Interval` + Union of all intervals where the function is continuous. + + Raises + ====== + + NotImplementedError + If the method to determine continuity of such a function + has not yet been developed. + + """ + from sympy.solvers.inequalities import solve_univariate_inequality + + if not domain.is_subset(S.Reals): + raise NotImplementedError(filldedent(''' + Domain must be a subset of S.Reals. + ''')) + implemented = [Pow, exp, log, Abs, frac, + sin, cos, tan, cot, sec, csc, + asin, acos, atan, acot, asec, acsc, + sinh, cosh, tanh, coth, sech, csch, + asinh, acosh, atanh, acoth, asech, acsch] + used = [fct.func for fct in f.atoms(Function) if fct.has(symbol)] + if any(func not in implemented for func in used): + raise NotImplementedError(filldedent(''' + Unable to determine the domain of the given function. + ''')) + + x = Symbol('x') + constraints = { + log: (x > 0,), + asin: (x >= -1, x <= 1), + acos: (x >= -1, x <= 1), + acosh: (x >= 1,), + atanh: (x > -1, x < 1), + asech: (x > 0, x <= 1) + } + constraints_union = { + asec: (x <= -1, x >= 1), + acsc: (x <= -1, x >= 1), + acoth: (x < -1, x > 1) + } + + cont_domain = domain + for atom in f.atoms(Pow): + den = atom.exp.as_numer_denom()[1] + if atom.exp.is_rational and den.is_odd: + pass # 0**negative handled by singularities() + else: + constraint = solve_univariate_inequality(atom.base >= 0, + symbol).as_set() + cont_domain = Intersection(constraint, cont_domain) + + for atom in f.atoms(Function): + if atom.func in constraints: + for c in constraints[atom.func]: + constraint_relational = c.subs(x, atom.args[0]) + constraint_set = solve_univariate_inequality( + constraint_relational, symbol).as_set() + cont_domain = Intersection(constraint_set, cont_domain) + elif atom.func in constraints_union: + constraint_set = S.EmptySet + for c in constraints_union[atom.func]: + constraint_relational = c.subs(x, atom.args[0]) + constraint_set += solve_univariate_inequality( + constraint_relational, symbol).as_set() + cont_domain = Intersection(constraint_set, cont_domain) + # XXX: the discontinuities below could be factored out in + # a new "discontinuities()". + elif atom.func == acot: + from sympy.solvers.solveset import solveset_real + # Sympy's acot() has a step discontinuity at 0. Since it's + # neither an essential singularity nor a pole, singularities() + # will not report it. But it's still relevant for determining + # the continuity of the function f. + cont_domain -= solveset_real(atom.args[0], symbol) + # Note that the above may introduce spurious discontinuities, e.g. + # for abs(acot(x)) at 0. + elif atom.func == frac: + from sympy.solvers.solveset import solveset_real + r = function_range(atom.args[0], symbol, domain) + r = Intersection(r, S.Integers) + if r.is_finite_set: + discont = S.EmptySet + for n in r: + discont += solveset_real(atom.args[0]-n, symbol) + else: + discont = ConditionSet( + symbol, S.Integers.contains(atom.args[0]), cont_domain) + cont_domain -= discont + + return cont_domain - singularities(f, symbol, domain) + + +def function_range(f, symbol, domain): + """ + Finds the range of a function in a given domain. + This method is limited by the ability to determine the singularities and + determine limits. + + Parameters + ========== + + f : :py:class:`~.Expr` + The concerned function. + symbol : :py:class:`~.Symbol` + The variable for which the range of function is to be determined. + domain : :py:class:`~.Interval` + The domain under which the range of the function has to be found. + + Examples + ======== + + >>> from sympy import Interval, Symbol, S, exp, log, pi, sqrt, sin, tan + >>> from sympy.calculus.util import function_range + >>> x = Symbol('x') + >>> function_range(sin(x), x, Interval(0, 2*pi)) + Interval(-1, 1) + >>> function_range(tan(x), x, Interval(-pi/2, pi/2)) + Interval(-oo, oo) + >>> function_range(1/x, x, S.Reals) + Union(Interval.open(-oo, 0), Interval.open(0, oo)) + >>> function_range(exp(x), x, S.Reals) + Interval.open(0, oo) + >>> function_range(log(x), x, S.Reals) + Interval(-oo, oo) + >>> function_range(sqrt(x), x, Interval(-5, 9)) + Interval(0, 3) + + Returns + ======= + + :py:class:`~.Interval` + Union of all ranges for all intervals under domain where function is + continuous. + + Raises + ====== + + NotImplementedError + If any of the intervals, in the given domain, for which function + is continuous are not finite or real, + OR if the critical points of the function on the domain cannot be found. + """ + + if domain is S.EmptySet: + return S.EmptySet + + period = periodicity(f, symbol) + if period == S.Zero: + # the expression is constant wrt symbol + return FiniteSet(f.expand()) + + from sympy.series.limits import limit + from sympy.solvers.solveset import solveset + + if period is not None: + if isinstance(domain, Interval): + if (domain.inf - domain.sup).is_infinite: + domain = Interval(0, period) + elif isinstance(domain, Union): + for sub_dom in domain.args: + if isinstance(sub_dom, Interval) and \ + ((sub_dom.inf - sub_dom.sup).is_infinite): + domain = Interval(0, period) + + intervals = continuous_domain(f, symbol, domain) + range_int = S.EmptySet + if isinstance(intervals,(Interval, FiniteSet)): + interval_iter = (intervals,) + elif isinstance(intervals, Union): + interval_iter = intervals.args + else: + raise NotImplementedError("Unable to find range for the given domain.") + + for interval in interval_iter: + if isinstance(interval, FiniteSet): + for singleton in interval: + if singleton in domain: + range_int += FiniteSet(f.subs(symbol, singleton)) + elif isinstance(interval, Interval): + vals = S.EmptySet + critical_values = S.EmptySet + bounds = ((interval.left_open, interval.inf, '+'), + (interval.right_open, interval.sup, '-')) + + for is_open, limit_point, direction in bounds: + if is_open: + critical_values += FiniteSet(limit(f, symbol, limit_point, direction)) + vals += critical_values + else: + vals += FiniteSet(f.subs(symbol, limit_point)) + + critical_points = solveset(f.diff(symbol), symbol, interval) + + if not iterable(critical_points): + raise NotImplementedError( + 'Unable to find critical points for {}'.format(f)) + if isinstance(critical_points, ImageSet): + raise NotImplementedError( + 'Infinite number of critical points for {}'.format(f)) + + for critical_point in critical_points: + vals += FiniteSet(f.subs(symbol, critical_point)) + + left_open, right_open = False, False + + if critical_values is not S.EmptySet: + if critical_values.inf == vals.inf: + left_open = True + + if critical_values.sup == vals.sup: + right_open = True + + range_int += Interval(vals.inf, vals.sup, left_open, right_open) + else: + raise NotImplementedError("Unable to find range for the given domain.") + + return range_int + + +def not_empty_in(finset_intersection, *syms): + """ + Finds the domain of the functions in ``finset_intersection`` in which the + ``finite_set`` is not-empty. + + Parameters + ========== + + finset_intersection : Intersection of FiniteSet + The unevaluated intersection of FiniteSet containing + real-valued functions with Union of Sets + syms : Tuple of symbols + Symbol for which domain is to be found + + Raises + ====== + + NotImplementedError + The algorithms to find the non-emptiness of the given FiniteSet are + not yet implemented. + ValueError + The input is not valid. + RuntimeError + It is a bug, please report it to the github issue tracker + (https://github.com/sympy/sympy/issues). + + Examples + ======== + + >>> from sympy import FiniteSet, Interval, not_empty_in, oo + >>> from sympy.abc import x + >>> not_empty_in(FiniteSet(x/2).intersect(Interval(0, 1)), x) + Interval(0, 2) + >>> not_empty_in(FiniteSet(x, x**2).intersect(Interval(1, 2)), x) + Union(Interval(1, 2), Interval(-sqrt(2), -1)) + >>> not_empty_in(FiniteSet(x**2/(x + 2)).intersect(Interval(1, oo)), x) + Union(Interval.Lopen(-2, -1), Interval(2, oo)) + """ + + # TODO: handle piecewise defined functions + # TODO: handle transcendental functions + # TODO: handle multivariate functions + if len(syms) == 0: + raise ValueError("One or more symbols must be given in syms.") + + if finset_intersection is S.EmptySet: + return S.EmptySet + + if isinstance(finset_intersection, Union): + elm_in_sets = finset_intersection.args[0] + return Union(not_empty_in(finset_intersection.args[1], *syms), + elm_in_sets) + + if isinstance(finset_intersection, FiniteSet): + finite_set = finset_intersection + _sets = S.Reals + else: + finite_set = finset_intersection.args[1] + _sets = finset_intersection.args[0] + + if not isinstance(finite_set, FiniteSet): + raise ValueError('A FiniteSet must be given, not %s: %s' % + (type(finite_set), finite_set)) + + if len(syms) == 1: + symb = syms[0] + else: + raise NotImplementedError('more than one variables %s not handled' % + (syms,)) + + def elm_domain(expr, intrvl): + """ Finds the domain of an expression in any given interval """ + from sympy.solvers.solveset import solveset + + _start = intrvl.start + _end = intrvl.end + _singularities = solveset(expr.as_numer_denom()[1], symb, + domain=S.Reals) + + if intrvl.right_open: + if _end is S.Infinity: + _domain1 = S.Reals + else: + _domain1 = solveset(expr < _end, symb, domain=S.Reals) + else: + _domain1 = solveset(expr <= _end, symb, domain=S.Reals) + + if intrvl.left_open: + if _start is S.NegativeInfinity: + _domain2 = S.Reals + else: + _domain2 = solveset(expr > _start, symb, domain=S.Reals) + else: + _domain2 = solveset(expr >= _start, symb, domain=S.Reals) + + # domain in the interval + expr_with_sing = Intersection(_domain1, _domain2) + expr_domain = Complement(expr_with_sing, _singularities) + return expr_domain + + if isinstance(_sets, Interval): + return Union(*[elm_domain(element, _sets) for element in finite_set]) + + if isinstance(_sets, Union): + _domain = S.EmptySet + for intrvl in _sets.args: + _domain_element = Union(*[elm_domain(element, intrvl) + for element in finite_set]) + _domain = Union(_domain, _domain_element) + return _domain + + +def periodicity(f, symbol, check=False): + """ + Tests the given function for periodicity in the given symbol. + + Parameters + ========== + + f : :py:class:`~.Expr` + The concerned function. + symbol : :py:class:`~.Symbol` + The variable for which the period is to be determined. + check : bool, optional + The flag to verify whether the value being returned is a period or not. + + Returns + ======= + + period + The period of the function is returned. + ``None`` is returned when the function is aperiodic or has a complex period. + The value of $0$ is returned as the period of a constant function. + + Raises + ====== + + NotImplementedError + The value of the period computed cannot be verified. + + + Notes + ===== + + Currently, we do not support functions with a complex period. + The period of functions having complex periodic values such + as ``exp``, ``sinh`` is evaluated to ``None``. + + The value returned might not be the "fundamental" period of the given + function i.e. it may not be the smallest periodic value of the function. + + The verification of the period through the ``check`` flag is not reliable + due to internal simplification of the given expression. Hence, it is set + to ``False`` by default. + + Examples + ======== + >>> from sympy import periodicity, Symbol, sin, cos, tan, exp + >>> x = Symbol('x') + >>> f = sin(x) + sin(2*x) + sin(3*x) + >>> periodicity(f, x) + 2*pi + >>> periodicity(sin(x)*cos(x), x) + pi + >>> periodicity(exp(tan(2*x) - 1), x) + pi/2 + >>> periodicity(sin(4*x)**cos(2*x), x) + pi + >>> periodicity(exp(x), x) + """ + if symbol.kind is not NumberKind: + raise NotImplementedError("Cannot use symbol of kind %s" % symbol.kind) + temp = Dummy('x', real=True) + f = f.subs(symbol, temp) + symbol = temp + + def _check(orig_f, period): + '''Return the checked period or raise an error.''' + new_f = orig_f.subs(symbol, symbol + period) + if new_f.equals(orig_f): + return period + else: + raise NotImplementedError(filldedent(''' + The period of the given function cannot be verified. + When `%s` was replaced with `%s + %s` in `%s`, the result + was `%s` which was not recognized as being the same as + the original function. + So either the period was wrong or the two forms were + not recognized as being equal. + Set check=False to obtain the value.''' % + (symbol, symbol, period, orig_f, new_f))) + + orig_f = f + period = None + + if isinstance(f, Relational): + f = f.lhs - f.rhs + + f = f.simplify() + + if symbol not in f.free_symbols: + return S.Zero + + if isinstance(f, TrigonometricFunction): + try: + period = f.period(symbol) + except NotImplementedError: + pass + + if isinstance(f, Abs): + arg = f.args[0] + if isinstance(arg, (sec, csc, cos)): + # all but tan and cot might have a + # a period that is half as large + # so recast as sin + arg = sin(arg.args[0]) + period = periodicity(arg, symbol) + if period is not None and isinstance(arg, sin): + # the argument of Abs was a trigonometric other than + # cot or tan; test to see if the half-period + # is valid. Abs(arg) has behaviour equivalent to + # orig_f, so use that for test: + orig_f = Abs(arg) + try: + return _check(orig_f, period/2) + except NotImplementedError as err: + if check: + raise NotImplementedError(err) + # else let new orig_f and period be + # checked below + + if isinstance(f, exp) or (f.is_Pow and f.base == S.Exp1): + f = Pow(S.Exp1, expand_mul(f.exp)) + if im(f) != 0: + period_real = periodicity(re(f), symbol) + period_imag = periodicity(im(f), symbol) + if period_real is not None and period_imag is not None: + period = lcim([period_real, period_imag]) + + if f.is_Pow and f.base != S.Exp1: + base, expo = f.args + base_has_sym = base.has(symbol) + expo_has_sym = expo.has(symbol) + + if base_has_sym and not expo_has_sym: + period = periodicity(base, symbol) + + elif expo_has_sym and not base_has_sym: + period = periodicity(expo, symbol) + + else: + period = _periodicity(f.args, symbol) + + elif f.is_Mul: + coeff, g = f.as_independent(symbol, as_Add=False) + if isinstance(g, TrigonometricFunction) or not equal_valued(coeff, 1): + period = periodicity(g, symbol) + else: + period = _periodicity(g.args, symbol) + + elif f.is_Add: + k, g = f.as_independent(symbol) + if k is not S.Zero: + return periodicity(g, symbol) + + period = _periodicity(g.args, symbol) + + elif isinstance(f, Mod): + a, n = f.args + + if a == symbol: + period = n + elif isinstance(a, TrigonometricFunction): + period = periodicity(a, symbol) + #check if 'f' is linear in 'symbol' + elif (a.is_polynomial(symbol) and degree(a, symbol) == 1 and + symbol not in n.free_symbols): + period = Abs(n / a.diff(symbol)) + + elif isinstance(f, Piecewise): + pass # not handling Piecewise yet as the return type is not favorable + + elif period is None: + from sympy.solvers.decompogen import compogen, decompogen + g_s = decompogen(f, symbol) + num_of_gs = len(g_s) + if num_of_gs > 1: + for index, g in enumerate(reversed(g_s)): + start_index = num_of_gs - 1 - index + g = compogen(g_s[start_index:], symbol) + if g not in (orig_f, f): # Fix for issue 12620 + period = periodicity(g, symbol) + if period is not None: + break + + if period is not None: + if check: + return _check(orig_f, period) + return period + + return None + + +def _periodicity(args, symbol): + """ + Helper for `periodicity` to find the period of a list of simpler + functions. + It uses the `lcim` method to find the least common period of + all the functions. + + Parameters + ========== + + args : Tuple of :py:class:`~.Symbol` + All the symbols present in a function. + + symbol : :py:class:`~.Symbol` + The symbol over which the function is to be evaluated. + + Returns + ======= + + period + The least common period of the function for all the symbols + of the function. + ``None`` if for at least one of the symbols the function is aperiodic. + + """ + periods = [] + for f in args: + period = periodicity(f, symbol) + if period is None: + return None + + if period is not S.Zero: + periods.append(period) + + if len(periods) > 1: + return lcim(periods) + + if periods: + return periods[0] + + +def lcim(numbers): + """Returns the least common integral multiple of a list of numbers. + + The numbers can be rational or irrational or a mixture of both. + `None` is returned for incommensurable numbers. + + Parameters + ========== + + numbers : list + Numbers (rational and/or irrational) for which lcim is to be found. + + Returns + ======= + + number + lcim if it exists, otherwise ``None`` for incommensurable numbers. + + Examples + ======== + + >>> from sympy.calculus.util import lcim + >>> from sympy import S, pi + >>> lcim([S(1)/2, S(3)/4, S(5)/6]) + 15/2 + >>> lcim([2*pi, 3*pi, pi, pi/2]) + 6*pi + >>> lcim([S(1), 2*pi]) + """ + result = None + if all(num.is_irrational for num in numbers): + factorized_nums = [num.factor() for num in numbers] + factors_num = [num.as_coeff_Mul() for num in factorized_nums] + term = factors_num[0][1] + if all(factor == term for coeff, factor in factors_num): + common_term = term + coeffs = [coeff for coeff, factor in factors_num] + result = lcm_list(coeffs) * common_term + + elif all(num.is_rational for num in numbers): + result = lcm_list(numbers) + + else: + pass + + return result + +def is_convex(f, *syms, domain=S.Reals): + r"""Determines the convexity of the function passed in the argument. + + Parameters + ========== + + f : :py:class:`~.Expr` + The concerned function. + syms : Tuple of :py:class:`~.Symbol` + The variables with respect to which the convexity is to be determined. + domain : :py:class:`~.Interval`, optional + The domain over which the convexity of the function has to be checked. + If unspecified, S.Reals will be the default domain. + + Returns + ======= + + bool + The method returns ``True`` if the function is convex otherwise it + returns ``False``. + + Raises + ====== + + NotImplementedError + The check for the convexity of multivariate functions is not implemented yet. + + Notes + ===== + + To determine concavity of a function pass `-f` as the concerned function. + To determine logarithmic convexity of a function pass `\log(f)` as + concerned function. + To determine logarithmic concavity of a function pass `-\log(f)` as + concerned function. + + Currently, convexity check of multivariate functions is not handled. + + Examples + ======== + + >>> from sympy import is_convex, symbols, exp, oo, Interval + >>> x = symbols('x') + >>> is_convex(exp(x), x) + True + >>> is_convex(x**3, x, domain = Interval(-1, oo)) + False + >>> is_convex(1/x**2, x, domain=Interval.open(0, oo)) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Convex_function + .. [2] http://www.ifp.illinois.edu/~angelia/L3_convfunc.pdf + .. [3] https://en.wikipedia.org/wiki/Logarithmically_convex_function + .. [4] https://en.wikipedia.org/wiki/Logarithmically_concave_function + .. [5] https://en.wikipedia.org/wiki/Concave_function + + """ + if len(syms) > 1 : + return hessian(f, syms).is_positive_semidefinite + from sympy.solvers.inequalities import solve_univariate_inequality + f = _sympify(f) + var = syms[0] + if any(s in domain for s in singularities(f, var)): + return False + condition = f.diff(var, 2) < 0 + if solve_univariate_inequality(condition, var, False, domain): + return False + return True + + +def stationary_points(f, symbol, domain=S.Reals): + """ + Returns the stationary points of a function (where derivative of the + function is 0) in the given domain. + + Parameters + ========== + + f : :py:class:`~.Expr` + The concerned function. + symbol : :py:class:`~.Symbol` + The variable for which the stationary points are to be determined. + domain : :py:class:`~.Interval` + The domain over which the stationary points have to be checked. + If unspecified, ``S.Reals`` will be the default domain. + + Returns + ======= + + Set + A set of stationary points for the function. If there are no + stationary point, an :py:class:`~.EmptySet` is returned. + + Examples + ======== + + >>> from sympy import Interval, Symbol, S, sin, pi, pprint, stationary_points + >>> x = Symbol('x') + + >>> stationary_points(1/x, x, S.Reals) + EmptySet + + >>> pprint(stationary_points(sin(x), x), use_unicode=False) + pi 3*pi + {2*n*pi + -- | n in Integers} U {2*n*pi + ---- | n in Integers} + 2 2 + + >>> stationary_points(sin(x),x, Interval(0, 4*pi)) + {pi/2, 3*pi/2, 5*pi/2, 7*pi/2} + + """ + from sympy.solvers.solveset import solveset + + if domain is S.EmptySet: + return S.EmptySet + + domain = continuous_domain(f, symbol, domain) + set = solveset(diff(f, symbol), symbol, domain) + + return set + + +def maximum(f, symbol, domain=S.Reals): + """ + Returns the maximum value of a function in the given domain. + + Parameters + ========== + + f : :py:class:`~.Expr` + The concerned function. + symbol : :py:class:`~.Symbol` + The variable for maximum value needs to be determined. + domain : :py:class:`~.Interval` + The domain over which the maximum have to be checked. + If unspecified, then the global maximum is returned. + + Returns + ======= + + number + Maximum value of the function in given domain. + + Examples + ======== + + >>> from sympy import Interval, Symbol, S, sin, cos, pi, maximum + >>> x = Symbol('x') + + >>> f = -x**2 + 2*x + 5 + >>> maximum(f, x, S.Reals) + 6 + + >>> maximum(sin(x), x, Interval(-pi, pi/4)) + sqrt(2)/2 + + >>> maximum(sin(x)*cos(x), x) + 1/2 + + """ + if isinstance(symbol, Symbol): + if domain is S.EmptySet: + raise ValueError("Maximum value not defined for empty domain.") + + return function_range(f, symbol, domain).sup + else: + raise ValueError("%s is not a valid symbol." % symbol) + + +def minimum(f, symbol, domain=S.Reals): + """ + Returns the minimum value of a function in the given domain. + + Parameters + ========== + + f : :py:class:`~.Expr` + The concerned function. + symbol : :py:class:`~.Symbol` + The variable for minimum value needs to be determined. + domain : :py:class:`~.Interval` + The domain over which the minimum have to be checked. + If unspecified, then the global minimum is returned. + + Returns + ======= + + number + Minimum value of the function in the given domain. + + Examples + ======== + + >>> from sympy import Interval, Symbol, S, sin, cos, minimum + >>> x = Symbol('x') + + >>> f = x**2 + 2*x + 5 + >>> minimum(f, x, S.Reals) + 4 + + >>> minimum(sin(x), x, Interval(2, 3)) + sin(3) + + >>> minimum(sin(x)*cos(x), x) + -1/2 + + """ + if isinstance(symbol, Symbol): + if domain is S.EmptySet: + raise ValueError("Minimum value not defined for empty domain.") + + return function_range(f, symbol, domain).inf + else: + raise ValueError("%s is not a valid symbol." % symbol) diff --git a/.venv/lib/python3.13/site-packages/sympy/categories/__init__.py b/.venv/lib/python3.13/site-packages/sympy/categories/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4c5007308a1b232e57f9ed164276862df0c5f265 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/categories/__init__.py @@ -0,0 +1,33 @@ +""" +Category Theory module. + +Provides some of the fundamental category-theory-related classes, +including categories, morphisms, diagrams. Functors are not +implemented yet. + +The general reference work this module tries to follow is + + [JoyOfCats] J. Adamek, H. Herrlich. G. E. Strecker: Abstract and + Concrete Categories. The Joy of Cats. + +The latest version of this book should be available for free download +from + + katmat.math.uni-bremen.de/acc/acc.pdf + +""" + +from .baseclasses import (Object, Morphism, IdentityMorphism, + NamedMorphism, CompositeMorphism, Category, + Diagram) + +from .diagram_drawing import (DiagramGrid, XypicDiagramDrawer, + xypic_draw_diagram, preview_diagram) + +__all__ = [ + 'Object', 'Morphism', 'IdentityMorphism', 'NamedMorphism', + 'CompositeMorphism', 'Category', 'Diagram', + + 'DiagramGrid', 'XypicDiagramDrawer', 'xypic_draw_diagram', + 'preview_diagram', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/categories/baseclasses.py b/.venv/lib/python3.13/site-packages/sympy/categories/baseclasses.py new file mode 100644 index 0000000000000000000000000000000000000000..e6ab5153ae4e95f193030864c8f32a52254f2458 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/categories/baseclasses.py @@ -0,0 +1,978 @@ +from sympy.core import S, Basic, Dict, Symbol, Tuple, sympify +from sympy.core.symbol import Str +from sympy.sets import Set, FiniteSet, EmptySet +from sympy.utilities.iterables import iterable + + +class Class(Set): + r""" + The base class for any kind of class in the set-theoretic sense. + + Explanation + =========== + + In axiomatic set theories, everything is a class. A class which + can be a member of another class is a set. A class which is not a + member of another class is a proper class. The class `\{1, 2\}` + is a set; the class of all sets is a proper class. + + This class is essentially a synonym for :class:`sympy.core.Set`. + The goal of this class is to assure easier migration to the + eventual proper implementation of set theory. + """ + is_proper = False + + +class Object(Symbol): + """ + The base class for any kind of object in an abstract category. + + Explanation + =========== + + While technically any instance of :class:`~.Basic` will do, this + class is the recommended way to create abstract objects in + abstract categories. + """ + + +class Morphism(Basic): + """ + The base class for any morphism in an abstract category. + + Explanation + =========== + + In abstract categories, a morphism is an arrow between two + category objects. The object where the arrow starts is called the + domain, while the object where the arrow ends is called the + codomain. + + Two morphisms between the same pair of objects are considered to + be the same morphisms. To distinguish between morphisms between + the same objects use :class:`NamedMorphism`. + + It is prohibited to instantiate this class. Use one of the + derived classes instead. + + See Also + ======== + + IdentityMorphism, NamedMorphism, CompositeMorphism + """ + def __new__(cls, domain, codomain): + raise(NotImplementedError( + "Cannot instantiate Morphism. Use derived classes instead.")) + + @property + def domain(self): + """ + Returns the domain of the morphism. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> A = Object("A") + >>> B = Object("B") + >>> f = NamedMorphism(A, B, "f") + >>> f.domain + Object("A") + + """ + return self.args[0] + + @property + def codomain(self): + """ + Returns the codomain of the morphism. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> A = Object("A") + >>> B = Object("B") + >>> f = NamedMorphism(A, B, "f") + >>> f.codomain + Object("B") + + """ + return self.args[1] + + def compose(self, other): + r""" + Composes self with the supplied morphism. + + The order of elements in the composition is the usual order, + i.e., to construct `g\circ f` use ``g.compose(f)``. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> g * f + CompositeMorphism((NamedMorphism(Object("A"), Object("B"), "f"), + NamedMorphism(Object("B"), Object("C"), "g"))) + >>> (g * f).domain + Object("A") + >>> (g * f).codomain + Object("C") + + """ + return CompositeMorphism(other, self) + + def __mul__(self, other): + r""" + Composes self with the supplied morphism. + + The semantics of this operation is given by the following + equation: ``g * f == g.compose(f)`` for composable morphisms + ``g`` and ``f``. + + See Also + ======== + + compose + """ + return self.compose(other) + + +class IdentityMorphism(Morphism): + """ + Represents an identity morphism. + + Explanation + =========== + + An identity morphism is a morphism with equal domain and codomain, + which acts as an identity with respect to composition. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism, IdentityMorphism + >>> A = Object("A") + >>> B = Object("B") + >>> f = NamedMorphism(A, B, "f") + >>> id_A = IdentityMorphism(A) + >>> id_B = IdentityMorphism(B) + >>> f * id_A == f + True + >>> id_B * f == f + True + + See Also + ======== + + Morphism + """ + def __new__(cls, domain): + return Basic.__new__(cls, domain) + + @property + def codomain(self): + return self.domain + + +class NamedMorphism(Morphism): + """ + Represents a morphism which has a name. + + Explanation + =========== + + Names are used to distinguish between morphisms which have the + same domain and codomain: two named morphisms are equal if they + have the same domains, codomains, and names. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> A = Object("A") + >>> B = Object("B") + >>> f = NamedMorphism(A, B, "f") + >>> f + NamedMorphism(Object("A"), Object("B"), "f") + >>> f.name + 'f' + + See Also + ======== + + Morphism + """ + def __new__(cls, domain, codomain, name): + if not name: + raise ValueError("Empty morphism names not allowed.") + + if not isinstance(name, Str): + name = Str(name) + + return Basic.__new__(cls, domain, codomain, name) + + @property + def name(self): + """ + Returns the name of the morphism. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> A = Object("A") + >>> B = Object("B") + >>> f = NamedMorphism(A, B, "f") + >>> f.name + 'f' + + """ + return self.args[2].name + + +class CompositeMorphism(Morphism): + r""" + Represents a morphism which is a composition of other morphisms. + + Explanation + =========== + + Two composite morphisms are equal if the morphisms they were + obtained from (components) are the same and were listed in the + same order. + + The arguments to the constructor for this class should be listed + in diagram order: to obtain the composition `g\circ f` from the + instances of :class:`Morphism` ``g`` and ``f`` use + ``CompositeMorphism(f, g)``. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism, CompositeMorphism + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> g * f + CompositeMorphism((NamedMorphism(Object("A"), Object("B"), "f"), + NamedMorphism(Object("B"), Object("C"), "g"))) + >>> CompositeMorphism(f, g) == g * f + True + + """ + @staticmethod + def _add_morphism(t, morphism): + """ + Intelligently adds ``morphism`` to tuple ``t``. + + Explanation + =========== + + If ``morphism`` is a composite morphism, its components are + added to the tuple. If ``morphism`` is an identity, nothing + is added to the tuple. + + No composability checks are performed. + """ + if isinstance(morphism, CompositeMorphism): + # ``morphism`` is a composite morphism; we have to + # denest its components. + return t + morphism.components + elif isinstance(morphism, IdentityMorphism): + # ``morphism`` is an identity. Nothing happens. + return t + else: + return t + Tuple(morphism) + + def __new__(cls, *components): + if components and not isinstance(components[0], Morphism): + # Maybe the user has explicitly supplied a list of + # morphisms. + return CompositeMorphism.__new__(cls, *components[0]) + + normalised_components = Tuple() + + for current, following in zip(components, components[1:]): + if not isinstance(current, Morphism) or \ + not isinstance(following, Morphism): + raise TypeError("All components must be morphisms.") + + if current.codomain != following.domain: + raise ValueError("Uncomposable morphisms.") + + normalised_components = CompositeMorphism._add_morphism( + normalised_components, current) + + # We haven't added the last morphism to the list of normalised + # components. Add it now. + normalised_components = CompositeMorphism._add_morphism( + normalised_components, components[-1]) + + if not normalised_components: + # If ``normalised_components`` is empty, only identities + # were supplied. Since they all were composable, they are + # all the same identities. + return components[0] + elif len(normalised_components) == 1: + # No sense to construct a whole CompositeMorphism. + return normalised_components[0] + + return Basic.__new__(cls, normalised_components) + + @property + def components(self): + """ + Returns the components of this composite morphism. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> (g * f).components + (NamedMorphism(Object("A"), Object("B"), "f"), + NamedMorphism(Object("B"), Object("C"), "g")) + + """ + return self.args[0] + + @property + def domain(self): + """ + Returns the domain of this composite morphism. + + The domain of the composite morphism is the domain of its + first component. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> (g * f).domain + Object("A") + + """ + return self.components[0].domain + + @property + def codomain(self): + """ + Returns the codomain of this composite morphism. + + The codomain of the composite morphism is the codomain of its + last component. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> (g * f).codomain + Object("C") + + """ + return self.components[-1].codomain + + def flatten(self, new_name): + """ + Forgets the composite structure of this morphism. + + Explanation + =========== + + If ``new_name`` is not empty, returns a :class:`NamedMorphism` + with the supplied name, otherwise returns a :class:`Morphism`. + In both cases the domain of the new morphism is the domain of + this composite morphism and the codomain of the new morphism + is the codomain of this composite morphism. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> (g * f).flatten("h") + NamedMorphism(Object("A"), Object("C"), "h") + + """ + return NamedMorphism(self.domain, self.codomain, new_name) + + +class Category(Basic): + r""" + An (abstract) category. + + Explanation + =========== + + A category [JoyOfCats] is a quadruple `\mbox{K} = (O, \hom, id, + \circ)` consisting of + + * a (set-theoretical) class `O`, whose members are called + `K`-objects, + + * for each pair `(A, B)` of `K`-objects, a set `\hom(A, B)` whose + members are called `K`-morphisms from `A` to `B`, + + * for a each `K`-object `A`, a morphism `id:A\rightarrow A`, + called the `K`-identity of `A`, + + * a composition law `\circ` associating with every `K`-morphisms + `f:A\rightarrow B` and `g:B\rightarrow C` a `K`-morphism `g\circ + f:A\rightarrow C`, called the composite of `f` and `g`. + + Composition is associative, `K`-identities are identities with + respect to composition, and the sets `\hom(A, B)` are pairwise + disjoint. + + This class knows nothing about its objects and morphisms. + Concrete cases of (abstract) categories should be implemented as + classes derived from this one. + + Certain instances of :class:`Diagram` can be asserted to be + commutative in a :class:`Category` by supplying the argument + ``commutative_diagrams`` in the constructor. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism, Diagram, Category + >>> from sympy import FiniteSet + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> d = Diagram([f, g]) + >>> K = Category("K", commutative_diagrams=[d]) + >>> K.commutative_diagrams == FiniteSet(d) + True + + See Also + ======== + + Diagram + """ + def __new__(cls, name, objects=EmptySet, commutative_diagrams=EmptySet): + if not name: + raise ValueError("A Category cannot have an empty name.") + + if not isinstance(name, Str): + name = Str(name) + + if not isinstance(objects, Class): + objects = Class(objects) + + new_category = Basic.__new__(cls, name, objects, + FiniteSet(*commutative_diagrams)) + return new_category + + @property + def name(self): + """ + Returns the name of this category. + + Examples + ======== + + >>> from sympy.categories import Category + >>> K = Category("K") + >>> K.name + 'K' + + """ + return self.args[0].name + + @property + def objects(self): + """ + Returns the class of objects of this category. + + Examples + ======== + + >>> from sympy.categories import Object, Category + >>> from sympy import FiniteSet + >>> A = Object("A") + >>> B = Object("B") + >>> K = Category("K", FiniteSet(A, B)) + >>> K.objects + Class({Object("A"), Object("B")}) + + """ + return self.args[1] + + @property + def commutative_diagrams(self): + """ + Returns the :class:`~.FiniteSet` of diagrams which are known to + be commutative in this category. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism, Diagram, Category + >>> from sympy import FiniteSet + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> d = Diagram([f, g]) + >>> K = Category("K", commutative_diagrams=[d]) + >>> K.commutative_diagrams == FiniteSet(d) + True + + """ + return self.args[2] + + def hom(self, A, B): + raise NotImplementedError( + "hom-sets are not implemented in Category.") + + def all_morphisms(self): + raise NotImplementedError( + "Obtaining the class of morphisms is not implemented in Category.") + + +class Diagram(Basic): + r""" + Represents a diagram in a certain category. + + Explanation + =========== + + Informally, a diagram is a collection of objects of a category and + certain morphisms between them. A diagram is still a monoid with + respect to morphism composition; i.e., identity morphisms, as well + as all composites of morphisms included in the diagram belong to + the diagram. For a more formal approach to this notion see + [Pare1970]. + + The components of composite morphisms are also added to the + diagram. No properties are assigned to such morphisms by default. + + A commutative diagram is often accompanied by a statement of the + following kind: "if such morphisms with such properties exist, + then such morphisms which such properties exist and the diagram is + commutative". To represent this, an instance of :class:`Diagram` + includes a collection of morphisms which are the premises and + another collection of conclusions. ``premises`` and + ``conclusions`` associate morphisms belonging to the corresponding + categories with the :class:`~.FiniteSet`'s of their properties. + + The set of properties of a composite morphism is the intersection + of the sets of properties of its components. The domain and + codomain of a conclusion morphism should be among the domains and + codomains of the morphisms listed as the premises of a diagram. + + No checks are carried out of whether the supplied object and + morphisms do belong to one and the same category. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism, Diagram + >>> from sympy import pprint, default_sort_key + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> d = Diagram([f, g]) + >>> premises_keys = sorted(d.premises.keys(), key=default_sort_key) + >>> pprint(premises_keys, use_unicode=False) + [g*f:A-->C, id:A-->A, id:B-->B, id:C-->C, f:A-->B, g:B-->C] + >>> pprint(d.premises, use_unicode=False) + {g*f:A-->C: EmptySet, id:A-->A: EmptySet, id:B-->B: EmptySet, + id:C-->C: EmptySet, f:A-->B: EmptySet, g:B-->C: EmptySet} + >>> d = Diagram([f, g], {g * f: "unique"}) + >>> pprint(d.conclusions,use_unicode=False) + {g*f:A-->C: {unique}} + + References + ========== + + [Pare1970] B. Pareigis: Categories and functors. Academic Press, 1970. + + """ + @staticmethod + def _set_dict_union(dictionary, key, value): + """ + If ``key`` is in ``dictionary``, set the new value of ``key`` + to be the union between the old value and ``value``. + Otherwise, set the value of ``key`` to ``value. + + Returns ``True`` if the key already was in the dictionary and + ``False`` otherwise. + """ + if key in dictionary: + dictionary[key] = dictionary[key] | value + return True + else: + dictionary[key] = value + return False + + @staticmethod + def _add_morphism_closure(morphisms, morphism, props, add_identities=True, + recurse_composites=True): + """ + Adds a morphism and its attributes to the supplied dictionary + ``morphisms``. If ``add_identities`` is True, also adds the + identity morphisms for the domain and the codomain of + ``morphism``. + """ + if not Diagram._set_dict_union(morphisms, morphism, props): + # We have just added a new morphism. + + if isinstance(morphism, IdentityMorphism): + if props: + # Properties for identity morphisms don't really + # make sense, because very much is known about + # identity morphisms already, so much that they + # are trivial. Having properties for identity + # morphisms would only be confusing. + raise ValueError( + "Instances of IdentityMorphism cannot have properties.") + return + + if add_identities: + empty = EmptySet + + id_dom = IdentityMorphism(morphism.domain) + id_cod = IdentityMorphism(morphism.codomain) + + Diagram._set_dict_union(morphisms, id_dom, empty) + Diagram._set_dict_union(morphisms, id_cod, empty) + + for existing_morphism, existing_props in list(morphisms.items()): + new_props = existing_props & props + if morphism.domain == existing_morphism.codomain: + left = morphism * existing_morphism + Diagram._set_dict_union(morphisms, left, new_props) + if morphism.codomain == existing_morphism.domain: + right = existing_morphism * morphism + Diagram._set_dict_union(morphisms, right, new_props) + + if isinstance(morphism, CompositeMorphism) and recurse_composites: + # This is a composite morphism, add its components as + # well. + empty = EmptySet + for component in morphism.components: + Diagram._add_morphism_closure(morphisms, component, empty, + add_identities) + + def __new__(cls, *args): + """ + Construct a new instance of Diagram. + + Explanation + =========== + + If no arguments are supplied, an empty diagram is created. + + If at least an argument is supplied, ``args[0]`` is + interpreted as the premises of the diagram. If ``args[0]`` is + a list, it is interpreted as a list of :class:`Morphism`'s, in + which each :class:`Morphism` has an empty set of properties. + If ``args[0]`` is a Python dictionary or a :class:`Dict`, it + is interpreted as a dictionary associating to some + :class:`Morphism`'s some properties. + + If at least two arguments are supplied ``args[1]`` is + interpreted as the conclusions of the diagram. The type of + ``args[1]`` is interpreted in exactly the same way as the type + of ``args[0]``. If only one argument is supplied, the diagram + has no conclusions. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> from sympy.categories import IdentityMorphism, Diagram + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> d = Diagram([f, g]) + >>> IdentityMorphism(A) in d.premises.keys() + True + >>> g * f in d.premises.keys() + True + >>> d = Diagram([f, g], {g * f: "unique"}) + >>> d.conclusions[g * f] + {unique} + + """ + premises = {} + conclusions = {} + + # Here we will keep track of the objects which appear in the + # premises. + objects = EmptySet + + if len(args) >= 1: + # We've got some premises in the arguments. + premises_arg = args[0] + + if isinstance(premises_arg, list): + # The user has supplied a list of morphisms, none of + # which have any attributes. + empty = EmptySet + + for morphism in premises_arg: + objects |= FiniteSet(morphism.domain, morphism.codomain) + Diagram._add_morphism_closure(premises, morphism, empty) + elif isinstance(premises_arg, (dict, Dict)): + # The user has supplied a dictionary of morphisms and + # their properties. + for morphism, props in premises_arg.items(): + objects |= FiniteSet(morphism.domain, morphism.codomain) + Diagram._add_morphism_closure( + premises, morphism, FiniteSet(*props) if iterable(props) else FiniteSet(props)) + + if len(args) >= 2: + # We also have some conclusions. + conclusions_arg = args[1] + + if isinstance(conclusions_arg, list): + # The user has supplied a list of morphisms, none of + # which have any attributes. + empty = EmptySet + + for morphism in conclusions_arg: + # Check that no new objects appear in conclusions. + if ((sympify(objects.contains(morphism.domain)) is S.true) and + (sympify(objects.contains(morphism.codomain)) is S.true)): + # No need to add identities and recurse + # composites this time. + Diagram._add_morphism_closure( + conclusions, morphism, empty, add_identities=False, + recurse_composites=False) + elif isinstance(conclusions_arg, (dict, Dict)): + # The user has supplied a dictionary of morphisms and + # their properties. + for morphism, props in conclusions_arg.items(): + # Check that no new objects appear in conclusions. + if (morphism.domain in objects) and \ + (morphism.codomain in objects): + # No need to add identities and recurse + # composites this time. + Diagram._add_morphism_closure( + conclusions, morphism, FiniteSet(*props) if iterable(props) else FiniteSet(props), + add_identities=False, recurse_composites=False) + + return Basic.__new__(cls, Dict(premises), Dict(conclusions), objects) + + @property + def premises(self): + """ + Returns the premises of this diagram. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> from sympy.categories import IdentityMorphism, Diagram + >>> from sympy import pretty + >>> A = Object("A") + >>> B = Object("B") + >>> f = NamedMorphism(A, B, "f") + >>> id_A = IdentityMorphism(A) + >>> id_B = IdentityMorphism(B) + >>> d = Diagram([f]) + >>> print(pretty(d.premises, use_unicode=False)) + {id:A-->A: EmptySet, id:B-->B: EmptySet, f:A-->B: EmptySet} + + """ + return self.args[0] + + @property + def conclusions(self): + """ + Returns the conclusions of this diagram. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> from sympy.categories import IdentityMorphism, Diagram + >>> from sympy import FiniteSet + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> d = Diagram([f, g]) + >>> IdentityMorphism(A) in d.premises.keys() + True + >>> g * f in d.premises.keys() + True + >>> d = Diagram([f, g], {g * f: "unique"}) + >>> d.conclusions[g * f] == FiniteSet("unique") + True + + """ + return self.args[1] + + @property + def objects(self): + """ + Returns the :class:`~.FiniteSet` of objects that appear in this + diagram. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism, Diagram + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> d = Diagram([f, g]) + >>> d.objects + {Object("A"), Object("B"), Object("C")} + + """ + return self.args[2] + + def hom(self, A, B): + """ + Returns a 2-tuple of sets of morphisms between objects ``A`` and + ``B``: one set of morphisms listed as premises, and the other set + of morphisms listed as conclusions. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism, Diagram + >>> from sympy import pretty + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> d = Diagram([f, g], {g * f: "unique"}) + >>> print(pretty(d.hom(A, C), use_unicode=False)) + ({g*f:A-->C}, {g*f:A-->C}) + + See Also + ======== + Object, Morphism + """ + premises = EmptySet + conclusions = EmptySet + + for morphism in self.premises.keys(): + if (morphism.domain == A) and (morphism.codomain == B): + premises |= FiniteSet(morphism) + for morphism in self.conclusions.keys(): + if (morphism.domain == A) and (morphism.codomain == B): + conclusions |= FiniteSet(morphism) + + return (premises, conclusions) + + def is_subdiagram(self, diagram): + """ + Checks whether ``diagram`` is a subdiagram of ``self``. + Diagram `D'` is a subdiagram of `D` if all premises + (conclusions) of `D'` are contained in the premises + (conclusions) of `D`. The morphisms contained + both in `D'` and `D` should have the same properties for `D'` + to be a subdiagram of `D`. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism, Diagram + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> d = Diagram([f, g], {g * f: "unique"}) + >>> d1 = Diagram([f]) + >>> d.is_subdiagram(d1) + True + >>> d1.is_subdiagram(d) + False + """ + premises = all((m in self.premises) and + (diagram.premises[m] == self.premises[m]) + for m in diagram.premises) + if not premises: + return False + + conclusions = all((m in self.conclusions) and + (diagram.conclusions[m] == self.conclusions[m]) + for m in diagram.conclusions) + + # Premises is surely ``True`` here. + return conclusions + + def subdiagram_from_objects(self, objects): + """ + If ``objects`` is a subset of the objects of ``self``, returns + a diagram which has as premises all those premises of ``self`` + which have a domains and codomains in ``objects``, likewise + for conclusions. Properties are preserved. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism, Diagram + >>> from sympy import FiniteSet + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> d = Diagram([f, g], {f: "unique", g*f: "veryunique"}) + >>> d1 = d.subdiagram_from_objects(FiniteSet(A, B)) + >>> d1 == Diagram([f], {f: "unique"}) + True + """ + if not objects.is_subset(self.objects): + raise ValueError( + "Supplied objects should all belong to the diagram.") + + new_premises = {} + for morphism, props in self.premises.items(): + if ((sympify(objects.contains(morphism.domain)) is S.true) and + (sympify(objects.contains(morphism.codomain)) is S.true)): + new_premises[morphism] = props + + new_conclusions = {} + for morphism, props in self.conclusions.items(): + if ((sympify(objects.contains(morphism.domain)) is S.true) and + (sympify(objects.contains(morphism.codomain)) is S.true)): + new_conclusions[morphism] = props + + return Diagram(new_premises, new_conclusions) diff --git a/.venv/lib/python3.13/site-packages/sympy/categories/diagram_drawing.py b/.venv/lib/python3.13/site-packages/sympy/categories/diagram_drawing.py new file mode 100644 index 0000000000000000000000000000000000000000..2a9b507cd86cf0e633b5abf7a0c9a353740af334 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/categories/diagram_drawing.py @@ -0,0 +1,2580 @@ +r""" +This module contains the functionality to arrange the nodes of a +diagram on an abstract grid, and then to produce a graphical +representation of the grid. + +The currently supported back-ends are Xy-pic [Xypic]. + +Layout Algorithm +================ + +This section provides an overview of the algorithms implemented in +:class:`DiagramGrid` to lay out diagrams. + +The first step of the algorithm is the removal composite and identity +morphisms which do not have properties in the supplied diagram. The +premises and conclusions of the diagram are then merged. + +The generic layout algorithm begins with the construction of the +"skeleton" of the diagram. The skeleton is an undirected graph which +has the objects of the diagram as vertices and has an (undirected) +edge between each pair of objects between which there exist morphisms. +The direction of the morphisms does not matter at this stage. The +skeleton also includes an edge between each pair of vertices `A` and +`C` such that there exists an object `B` which is connected via +a morphism to `A`, and via a morphism to `C`. + +The skeleton constructed in this way has the property that every +object is a vertex of a triangle formed by three edges of the +skeleton. This property lies at the base of the generic layout +algorithm. + +After the skeleton has been constructed, the algorithm lists all +triangles which can be formed. Note that some triangles will not have +all edges corresponding to morphisms which will actually be drawn. +Triangles which have only one edge or less which will actually be +drawn are immediately discarded. + +The list of triangles is sorted according to the number of edges which +correspond to morphisms, then the triangle with the least number of such +edges is selected. One of such edges is picked and the corresponding +objects are placed horizontally, on a grid. This edge is recorded to +be in the fringe. The algorithm then finds a "welding" of a triangle +to the fringe. A welding is an edge in the fringe where a triangle +could be attached. If the algorithm succeeds in finding such a +welding, it adds to the grid that vertex of the triangle which was not +yet included in any edge in the fringe and records the two new edges in +the fringe. This process continues iteratively until all objects of +the diagram has been placed or until no more weldings can be found. + +An edge is only removed from the fringe when a welding to this edge +has been found, and there is no room around this edge to place +another vertex. + +When no more weldings can be found, but there are still triangles +left, the algorithm searches for a possibility of attaching one of the +remaining triangles to the existing structure by a vertex. If such a +possibility is found, the corresponding edge of the found triangle is +placed in the found space and the iterative process of welding +triangles restarts. + +When logical groups are supplied, each of these groups is laid out +independently. Then a diagram is constructed in which groups are +objects and any two logical groups between which there exist morphisms +are connected via a morphism. This diagram is laid out. Finally, +the grid which includes all objects of the initial diagram is +constructed by replacing the cells which contain logical groups with +the corresponding laid out grids, and by correspondingly expanding the +rows and columns. + +The sequential layout algorithm begins by constructing the +underlying undirected graph defined by the morphisms obtained after +simplifying premises and conclusions and merging them (see above). +The vertex with the minimal degree is then picked up and depth-first +search is started from it. All objects which are located at distance +`n` from the root in the depth-first search tree, are positioned in +the `n`-th column of the resulting grid. The sequential layout will +therefore attempt to lay the objects out along a line. + +References +========== + +.. [Xypic] https://xy-pic.sourceforge.net/ + +""" +from sympy.categories import (CompositeMorphism, IdentityMorphism, + NamedMorphism, Diagram) +from sympy.core import Dict, Symbol, default_sort_key +from sympy.printing.latex import latex +from sympy.sets import FiniteSet +from sympy.utilities.iterables import iterable +from sympy.utilities.decorator import doctest_depends_on + +from itertools import chain + + +__doctest_requires__ = {('preview_diagram',): 'pyglet'} + + +class _GrowableGrid: + """ + Holds a growable grid of objects. + + Explanation + =========== + + It is possible to append or prepend a row or a column to the grid + using the corresponding methods. Prepending rows or columns has + the effect of changing the coordinates of the already existing + elements. + + This class currently represents a naive implementation of the + functionality with little attempt at optimisation. + """ + def __init__(self, width, height): + self._width = width + self._height = height + + self._array = [[None for j in range(width)] for i in range(height)] + + @property + def width(self): + return self._width + + @property + def height(self): + return self._height + + def __getitem__(self, i_j): + """ + Returns the element located at in the i-th line and j-th + column. + """ + i, j = i_j + return self._array[i][j] + + def __setitem__(self, i_j, newvalue): + """ + Sets the element located at in the i-th line and j-th + column. + """ + i, j = i_j + self._array[i][j] = newvalue + + def append_row(self): + """ + Appends an empty row to the grid. + """ + self._height += 1 + self._array.append([None for j in range(self._width)]) + + def append_column(self): + """ + Appends an empty column to the grid. + """ + self._width += 1 + for i in range(self._height): + self._array[i].append(None) + + def prepend_row(self): + """ + Prepends the grid with an empty row. + """ + self._height += 1 + self._array.insert(0, [None for j in range(self._width)]) + + def prepend_column(self): + """ + Prepends the grid with an empty column. + """ + self._width += 1 + for i in range(self._height): + self._array[i].insert(0, None) + + +class DiagramGrid: + r""" + Constructs and holds the fitting of the diagram into a grid. + + Explanation + =========== + + The mission of this class is to analyse the structure of the + supplied diagram and to place its objects on a grid such that, + when the objects and the morphisms are actually drawn, the diagram + would be "readable", in the sense that there will not be many + intersections of moprhisms. This class does not perform any + actual drawing. It does strive nevertheless to offer sufficient + metadata to draw a diagram. + + Consider the following simple diagram. + + >>> from sympy.categories import Object, NamedMorphism + >>> from sympy.categories import Diagram, DiagramGrid + >>> from sympy import pprint + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> diagram = Diagram([f, g]) + + The simplest way to have a diagram laid out is the following: + + >>> grid = DiagramGrid(diagram) + >>> (grid.width, grid.height) + (2, 2) + >>> pprint(grid) + A B + + C + + Sometimes one sees the diagram as consisting of logical groups. + One can advise ``DiagramGrid`` as to such groups by employing the + ``groups`` keyword argument. + + Consider the following diagram: + + >>> D = Object("D") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> h = NamedMorphism(D, A, "h") + >>> k = NamedMorphism(D, B, "k") + >>> diagram = Diagram([f, g, h, k]) + + Lay it out with generic layout: + + >>> grid = DiagramGrid(diagram) + >>> pprint(grid) + A B D + + C + + Now, we can group the objects `A` and `D` to have them near one + another: + + >>> grid = DiagramGrid(diagram, groups=[[A, D], B, C]) + >>> pprint(grid) + B C + + A D + + Note how the positioning of the other objects changes. + + Further indications can be supplied to the constructor of + :class:`DiagramGrid` using keyword arguments. The currently + supported hints are explained in the following paragraphs. + + :class:`DiagramGrid` does not automatically guess which layout + would suit the supplied diagram better. Consider, for example, + the following linear diagram: + + >>> E = Object("E") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> h = NamedMorphism(C, D, "h") + >>> i = NamedMorphism(D, E, "i") + >>> diagram = Diagram([f, g, h, i]) + + When laid out with the generic layout, it does not get to look + linear: + + >>> grid = DiagramGrid(diagram) + >>> pprint(grid) + A B + + C D + + E + + To get it laid out in a line, use ``layout="sequential"``: + + >>> grid = DiagramGrid(diagram, layout="sequential") + >>> pprint(grid) + A B C D E + + One may sometimes need to transpose the resulting layout. While + this can always be done by hand, :class:`DiagramGrid` provides a + hint for that purpose: + + >>> grid = DiagramGrid(diagram, layout="sequential", transpose=True) + >>> pprint(grid) + A + + B + + C + + D + + E + + Separate hints can also be provided for each group. For an + example, refer to ``tests/test_drawing.py``, and see the different + ways in which the five lemma [FiveLemma] can be laid out. + + See Also + ======== + + Diagram + + References + ========== + + .. [FiveLemma] https://en.wikipedia.org/wiki/Five_lemma + """ + @staticmethod + def _simplify_morphisms(morphisms): + """ + Given a dictionary mapping morphisms to their properties, + returns a new dictionary in which there are no morphisms which + do not have properties, and which are compositions of other + morphisms included in the dictionary. Identities are dropped + as well. + """ + newmorphisms = {} + for morphism, props in morphisms.items(): + if isinstance(morphism, CompositeMorphism) and not props: + continue + elif isinstance(morphism, IdentityMorphism): + continue + else: + newmorphisms[morphism] = props + return newmorphisms + + @staticmethod + def _merge_premises_conclusions(premises, conclusions): + """ + Given two dictionaries of morphisms and their properties, + produces a single dictionary which includes elements from both + dictionaries. If a morphism has some properties in premises + and also in conclusions, the properties in conclusions take + priority. + """ + return dict(chain(premises.items(), conclusions.items())) + + @staticmethod + def _juxtapose_edges(edge1, edge2): + """ + If ``edge1`` and ``edge2`` have precisely one common endpoint, + returns an edge which would form a triangle with ``edge1`` and + ``edge2``. + + If ``edge1`` and ``edge2`` do not have a common endpoint, + returns ``None``. + + If ``edge1`` and ``edge`` are the same edge, returns ``None``. + """ + intersection = edge1 & edge2 + if len(intersection) != 1: + # The edges either have no common points or are equal. + return None + + # The edges have a common endpoint. Extract the different + # endpoints and set up the new edge. + return (edge1 - intersection) | (edge2 - intersection) + + @staticmethod + def _add_edge_append(dictionary, edge, elem): + """ + If ``edge`` is not in ``dictionary``, adds ``edge`` to the + dictionary and sets its value to ``[elem]``. Otherwise + appends ``elem`` to the value of existing entry. + + Note that edges are undirected, thus `(A, B) = (B, A)`. + """ + if edge in dictionary: + dictionary[edge].append(elem) + else: + dictionary[edge] = [elem] + + @staticmethod + def _build_skeleton(morphisms): + """ + Creates a dictionary which maps edges to corresponding + morphisms. Thus for a morphism `f:A\rightarrow B`, the edge + `(A, B)` will be associated with `f`. This function also adds + to the list those edges which are formed by juxtaposition of + two edges already in the list. These new edges are not + associated with any morphism and are only added to assure that + the diagram can be decomposed into triangles. + """ + edges = {} + # Create edges for morphisms. + for morphism in morphisms: + DiagramGrid._add_edge_append( + edges, frozenset([morphism.domain, morphism.codomain]), morphism) + + # Create new edges by juxtaposing existing edges. + edges1 = dict(edges) + for w in edges1: + for v in edges1: + wv = DiagramGrid._juxtapose_edges(w, v) + if wv and wv not in edges: + edges[wv] = [] + + return edges + + @staticmethod + def _list_triangles(edges): + """ + Builds the set of triangles formed by the supplied edges. The + triangles are arbitrary and need not be commutative. A + triangle is a set that contains all three of its sides. + """ + triangles = set() + + for w in edges: + for v in edges: + wv = DiagramGrid._juxtapose_edges(w, v) + if wv and wv in edges: + triangles.add(frozenset([w, v, wv])) + + return triangles + + @staticmethod + def _drop_redundant_triangles(triangles, skeleton): + """ + Returns a list which contains only those triangles who have + morphisms associated with at least two edges. + """ + return [tri for tri in triangles + if len([e for e in tri if skeleton[e]]) >= 2] + + @staticmethod + def _morphism_length(morphism): + """ + Returns the length of a morphism. The length of a morphism is + the number of components it consists of. A non-composite + morphism is of length 1. + """ + if isinstance(morphism, CompositeMorphism): + return len(morphism.components) + else: + return 1 + + @staticmethod + def _compute_triangle_min_sizes(triangles, edges): + r""" + Returns a dictionary mapping triangles to their minimal sizes. + The minimal size of a triangle is the sum of maximal lengths + of morphisms associated to the sides of the triangle. The + length of a morphism is the number of components it consists + of. A non-composite morphism is of length 1. + + Sorting triangles by this metric attempts to address two + aspects of layout. For triangles with only simple morphisms + in the edge, this assures that triangles with all three edges + visible will get typeset after triangles with less visible + edges, which sometimes minimizes the necessity in diagonal + arrows. For triangles with composite morphisms in the edges, + this assures that objects connected with shorter morphisms + will be laid out first, resulting the visual proximity of + those objects which are connected by shorter morphisms. + """ + triangle_sizes = {} + for triangle in triangles: + size = 0 + for e in triangle: + morphisms = edges[e] + if morphisms: + size += max(DiagramGrid._morphism_length(m) + for m in morphisms) + triangle_sizes[triangle] = size + return triangle_sizes + + @staticmethod + def _triangle_objects(triangle): + """ + Given a triangle, returns the objects included in it. + """ + # A triangle is a frozenset of three two-element frozensets + # (the edges). This chains the three edges together and + # creates a frozenset from the iterator, thus producing a + # frozenset of objects of the triangle. + return frozenset(chain(*tuple(triangle))) + + @staticmethod + def _other_vertex(triangle, edge): + """ + Given a triangle and an edge of it, returns the vertex which + opposes the edge. + """ + # This gets the set of objects of the triangle and then + # subtracts the set of objects employed in ``edge`` to get the + # vertex opposite to ``edge``. + return list(DiagramGrid._triangle_objects(triangle) - set(edge))[0] + + @staticmethod + def _empty_point(pt, grid): + """ + Checks if the cell at coordinates ``pt`` is either empty or + out of the bounds of the grid. + """ + if (pt[0] < 0) or (pt[1] < 0) or \ + (pt[0] >= grid.height) or (pt[1] >= grid.width): + return True + return grid[pt] is None + + @staticmethod + def _put_object(coords, obj, grid, fringe): + """ + Places an object at the coordinate ``cords`` in ``grid``, + growing the grid and updating ``fringe``, if necessary. + Returns (0, 0) if no row or column has been prepended, (1, 0) + if a row was prepended, (0, 1) if a column was prepended and + (1, 1) if both a column and a row were prepended. + """ + (i, j) = coords + offset = (0, 0) + if i == -1: + grid.prepend_row() + i = 0 + offset = (1, 0) + for k in range(len(fringe)): + ((i1, j1), (i2, j2)) = fringe[k] + fringe[k] = ((i1 + 1, j1), (i2 + 1, j2)) + elif i == grid.height: + grid.append_row() + + if j == -1: + j = 0 + offset = (offset[0], 1) + grid.prepend_column() + for k in range(len(fringe)): + ((i1, j1), (i2, j2)) = fringe[k] + fringe[k] = ((i1, j1 + 1), (i2, j2 + 1)) + elif j == grid.width: + grid.append_column() + + grid[i, j] = obj + return offset + + @staticmethod + def _choose_target_cell(pt1, pt2, edge, obj, skeleton, grid): + """ + Given two points, ``pt1`` and ``pt2``, and the welding edge + ``edge``, chooses one of the two points to place the opposing + vertex ``obj`` of the triangle. If neither of this points + fits, returns ``None``. + """ + pt1_empty = DiagramGrid._empty_point(pt1, grid) + pt2_empty = DiagramGrid._empty_point(pt2, grid) + + if pt1_empty and pt2_empty: + # Both cells are empty. Of these two, choose that cell + # which will assure that a visible edge of the triangle + # will be drawn perpendicularly to the current welding + # edge. + + A = grid[edge[0]] + + if skeleton.get(frozenset([A, obj])): + return pt1 + else: + return pt2 + if pt1_empty: + return pt1 + elif pt2_empty: + return pt2 + else: + return None + + @staticmethod + def _find_triangle_to_weld(triangles, fringe, grid): + """ + Finds, if possible, a triangle and an edge in the ``fringe`` to + which the triangle could be attached. Returns the tuple + containing the triangle and the index of the corresponding + edge in the ``fringe``. + + This function relies on the fact that objects are unique in + the diagram. + """ + for triangle in triangles: + for (a, b) in fringe: + if frozenset([grid[a], grid[b]]) in triangle: + return (triangle, (a, b)) + return None + + @staticmethod + def _weld_triangle(tri, welding_edge, fringe, grid, skeleton): + """ + If possible, welds the triangle ``tri`` to ``fringe`` and + returns ``False``. If this method encounters a degenerate + situation in the fringe and corrects it such that a restart of + the search is required, it returns ``True`` (which means that + a restart in finding triangle weldings is required). + + A degenerate situation is a situation when an edge listed in + the fringe does not belong to the visual boundary of the + diagram. + """ + a, b = welding_edge + target_cell = None + + obj = DiagramGrid._other_vertex(tri, (grid[a], grid[b])) + + # We now have a triangle and an edge where it can be welded to + # the fringe. Decide where to place the other vertex of the + # triangle and check for degenerate situations en route. + + if (abs(a[0] - b[0]) == 1) and (abs(a[1] - b[1]) == 1): + # A diagonal edge. + target_cell = (a[0], b[1]) + if grid[target_cell]: + # That cell is already occupied. + target_cell = (b[0], a[1]) + + if grid[target_cell]: + # Degenerate situation, this edge is not + # on the actual fringe. Correct the + # fringe and go on. + fringe.remove((a, b)) + return True + elif a[0] == b[0]: + # A horizontal edge. We first attempt to build the + # triangle in the downward direction. + + down_left = a[0] + 1, a[1] + down_right = a[0] + 1, b[1] + + target_cell = DiagramGrid._choose_target_cell( + down_left, down_right, (a, b), obj, skeleton, grid) + + if not target_cell: + # No room below this edge. Check above. + up_left = a[0] - 1, a[1] + up_right = a[0] - 1, b[1] + + target_cell = DiagramGrid._choose_target_cell( + up_left, up_right, (a, b), obj, skeleton, grid) + + if not target_cell: + # This edge is not in the fringe, remove it + # and restart. + fringe.remove((a, b)) + return True + elif a[1] == b[1]: + # A vertical edge. We will attempt to place the other + # vertex of the triangle to the right of this edge. + right_up = a[0], a[1] + 1 + right_down = b[0], a[1] + 1 + + target_cell = DiagramGrid._choose_target_cell( + right_up, right_down, (a, b), obj, skeleton, grid) + + if not target_cell: + # No room to the left. See what's to the right. + left_up = a[0], a[1] - 1 + left_down = b[0], a[1] - 1 + + target_cell = DiagramGrid._choose_target_cell( + left_up, left_down, (a, b), obj, skeleton, grid) + + if not target_cell: + # This edge is not in the fringe, remove it + # and restart. + fringe.remove((a, b)) + return True + + # We now know where to place the other vertex of the + # triangle. + offset = DiagramGrid._put_object(target_cell, obj, grid, fringe) + + # Take care of the displacement of coordinates if a row or + # a column was prepended. + target_cell = (target_cell[0] + offset[0], + target_cell[1] + offset[1]) + a = (a[0] + offset[0], a[1] + offset[1]) + b = (b[0] + offset[0], b[1] + offset[1]) + + fringe.extend([(a, target_cell), (b, target_cell)]) + + # No restart is required. + return False + + @staticmethod + def _triangle_key(tri, triangle_sizes): + """ + Returns a key for the supplied triangle. It should be the + same independently of the hash randomisation. + """ + objects = sorted( + DiagramGrid._triangle_objects(tri), key=default_sort_key) + return (triangle_sizes[tri], default_sort_key(objects)) + + @staticmethod + def _pick_root_edge(tri, skeleton): + """ + For a given triangle always picks the same root edge. The + root edge is the edge that will be placed first on the grid. + """ + candidates = [sorted(e, key=default_sort_key) + for e in tri if skeleton[e]] + sorted_candidates = sorted(candidates, key=default_sort_key) + # Don't forget to assure the proper ordering of the vertices + # in this edge. + return tuple(sorted(sorted_candidates[0], key=default_sort_key)) + + @staticmethod + def _drop_irrelevant_triangles(triangles, placed_objects): + """ + Returns only those triangles whose set of objects is not + completely included in ``placed_objects``. + """ + return [tri for tri in triangles if not placed_objects.issuperset( + DiagramGrid._triangle_objects(tri))] + + @staticmethod + def _grow_pseudopod(triangles, fringe, grid, skeleton, placed_objects): + """ + Starting from an object in the existing structure on the ``grid``, + adds an edge to which a triangle from ``triangles`` could be + welded. If this method has found a way to do so, it returns + the object it has just added. + + This method should be applied when ``_weld_triangle`` cannot + find weldings any more. + """ + for i in range(grid.height): + for j in range(grid.width): + obj = grid[i, j] + if not obj: + continue + + # Here we need to choose a triangle which has only + # ``obj`` in common with the existing structure. The + # situations when this is not possible should be + # handled elsewhere. + + def good_triangle(tri): + objs = DiagramGrid._triangle_objects(tri) + return obj in objs and \ + placed_objects & (objs - {obj}) == set() + + tris = [tri for tri in triangles if good_triangle(tri)] + if not tris: + # This object is not interesting. + continue + + # Pick the "simplest" of the triangles which could be + # attached. Remember that the list of triangles is + # sorted according to their "simplicity" (see + # _compute_triangle_min_sizes for the metric). + # + # Note that ``tris`` are sequentially built from + # ``triangles``, so we don't have to worry about hash + # randomisation. + tri = tris[0] + + # We have found a triangle which could be attached to + # the existing structure by a vertex. + + candidates = sorted([e for e in tri if skeleton[e]], + key=lambda e: FiniteSet(*e).sort_key()) + edges = [e for e in candidates if obj in e] + + # Note that a meaningful edge (i.e., and edge that is + # associated with a morphism) containing ``obj`` + # always exists. That's because all triangles are + # guaranteed to have at least two meaningful edges. + # See _drop_redundant_triangles. + + # Get the object at the other end of the edge. + edge = edges[0] + other_obj = tuple(edge - frozenset([obj]))[0] + + # Now check for free directions. When checking for + # free directions, prefer the horizontal and vertical + # directions. + neighbours = [(i - 1, j), (i, j + 1), (i + 1, j), (i, j - 1), + (i - 1, j - 1), (i - 1, j + 1), (i + 1, j - 1), (i + 1, j + 1)] + + for pt in neighbours: + if DiagramGrid._empty_point(pt, grid): + # We have a found a place to grow the + # pseudopod into. + offset = DiagramGrid._put_object( + pt, other_obj, grid, fringe) + + i += offset[0] + j += offset[1] + pt = (pt[0] + offset[0], pt[1] + offset[1]) + fringe.append(((i, j), pt)) + + return other_obj + + # This diagram is actually cooler that I can handle. Fail cowardly. + return None + + @staticmethod + def _handle_groups(diagram, groups, merged_morphisms, hints): + """ + Given the slightly preprocessed morphisms of the diagram, + produces a grid laid out according to ``groups``. + + If a group has hints, it is laid out with those hints only, + without any influence from ``hints``. Otherwise, it is laid + out with ``hints``. + """ + def lay_out_group(group, local_hints): + """ + If ``group`` is a set of objects, uses a ``DiagramGrid`` + to lay it out and returns the grid. Otherwise returns the + object (i.e., ``group``). If ``local_hints`` is not + empty, it is supplied to ``DiagramGrid`` as the dictionary + of hints. Otherwise, the ``hints`` argument of + ``_handle_groups`` is used. + """ + if isinstance(group, FiniteSet): + # Set up the corresponding object-to-group + # mappings. + for obj in group: + obj_groups[obj] = group + + # Lay out the current group. + if local_hints: + groups_grids[group] = DiagramGrid( + diagram.subdiagram_from_objects(group), **local_hints) + else: + groups_grids[group] = DiagramGrid( + diagram.subdiagram_from_objects(group), **hints) + else: + obj_groups[group] = group + + def group_to_finiteset(group): + """ + Converts ``group`` to a :class:``FiniteSet`` if it is an + iterable. + """ + if iterable(group): + return FiniteSet(*group) + else: + return group + + obj_groups = {} + groups_grids = {} + + # We would like to support various containers to represent + # groups. To achieve that, before laying each group out, it + # should be converted to a FiniteSet, because that is what the + # following code expects. + + if isinstance(groups, (dict, Dict)): + finiteset_groups = {} + for group, local_hints in groups.items(): + finiteset_group = group_to_finiteset(group) + finiteset_groups[finiteset_group] = local_hints + lay_out_group(group, local_hints) + groups = finiteset_groups + else: + finiteset_groups = [] + for group in groups: + finiteset_group = group_to_finiteset(group) + finiteset_groups.append(finiteset_group) + lay_out_group(finiteset_group, None) + groups = finiteset_groups + + new_morphisms = [] + for morphism in merged_morphisms: + dom = obj_groups[morphism.domain] + cod = obj_groups[morphism.codomain] + # Note that we are not really interested in morphisms + # which do not employ two different groups, because + # these do not influence the layout. + if dom != cod: + # These are essentially unnamed morphisms; they are + # not going to mess in the final layout. By giving + # them the same names, we avoid unnecessary + # duplicates. + new_morphisms.append(NamedMorphism(dom, cod, "dummy")) + + # Lay out the new diagram. Since these are dummy morphisms, + # properties and conclusions are irrelevant. + top_grid = DiagramGrid(Diagram(new_morphisms)) + + # We now have to substitute the groups with the corresponding + # grids, laid out at the beginning of this function. Compute + # the size of each row and column in the grid, so that all + # nested grids fit. + + def group_size(group): + """ + For the supplied group (or object, eventually), returns + the size of the cell that will hold this group (object). + """ + if group in groups_grids: + grid = groups_grids[group] + return (grid.height, grid.width) + else: + return (1, 1) + + row_heights = [max(group_size(top_grid[i, j])[0] + for j in range(top_grid.width)) + for i in range(top_grid.height)] + + column_widths = [max(group_size(top_grid[i, j])[1] + for i in range(top_grid.height)) + for j in range(top_grid.width)] + + grid = _GrowableGrid(sum(column_widths), sum(row_heights)) + + real_row = 0 + real_column = 0 + for logical_row in range(top_grid.height): + for logical_column in range(top_grid.width): + obj = top_grid[logical_row, logical_column] + + if obj in groups_grids: + # This is a group. Copy the corresponding grid in + # place. + local_grid = groups_grids[obj] + for i in range(local_grid.height): + for j in range(local_grid.width): + grid[real_row + i, + real_column + j] = local_grid[i, j] + else: + # This is an object. Just put it there. + grid[real_row, real_column] = obj + + real_column += column_widths[logical_column] + real_column = 0 + real_row += row_heights[logical_row] + + return grid + + @staticmethod + def _generic_layout(diagram, merged_morphisms): + """ + Produces the generic layout for the supplied diagram. + """ + all_objects = set(diagram.objects) + if len(all_objects) == 1: + # There only one object in the diagram, just put in on 1x1 + # grid. + grid = _GrowableGrid(1, 1) + grid[0, 0] = tuple(all_objects)[0] + return grid + + skeleton = DiagramGrid._build_skeleton(merged_morphisms) + + grid = _GrowableGrid(2, 1) + + if len(skeleton) == 1: + # This diagram contains only one morphism. Draw it + # horizontally. + objects = sorted(all_objects, key=default_sort_key) + grid[0, 0] = objects[0] + grid[0, 1] = objects[1] + + return grid + + triangles = DiagramGrid._list_triangles(skeleton) + triangles = DiagramGrid._drop_redundant_triangles(triangles, skeleton) + triangle_sizes = DiagramGrid._compute_triangle_min_sizes( + triangles, skeleton) + + triangles = sorted(triangles, key=lambda tri: + DiagramGrid._triangle_key(tri, triangle_sizes)) + + # Place the first edge on the grid. + root_edge = DiagramGrid._pick_root_edge(triangles[0], skeleton) + grid[0, 0], grid[0, 1] = root_edge + fringe = [((0, 0), (0, 1))] + + # Record which objects we now have on the grid. + placed_objects = set(root_edge) + + while placed_objects != all_objects: + welding = DiagramGrid._find_triangle_to_weld( + triangles, fringe, grid) + + if welding: + (triangle, welding_edge) = welding + + restart_required = DiagramGrid._weld_triangle( + triangle, welding_edge, fringe, grid, skeleton) + if restart_required: + continue + + placed_objects.update( + DiagramGrid._triangle_objects(triangle)) + else: + # No more weldings found. Try to attach triangles by + # vertices. + new_obj = DiagramGrid._grow_pseudopod( + triangles, fringe, grid, skeleton, placed_objects) + + if not new_obj: + # No more triangles can be attached, not even by + # the edge. We will set up a new diagram out of + # what has been left, laid it out independently, + # and then attach it to this one. + + remaining_objects = all_objects - placed_objects + + remaining_diagram = diagram.subdiagram_from_objects( + FiniteSet(*remaining_objects)) + remaining_grid = DiagramGrid(remaining_diagram) + + # Now, let's glue ``remaining_grid`` to ``grid``. + final_width = grid.width + remaining_grid.width + final_height = max(grid.height, remaining_grid.height) + final_grid = _GrowableGrid(final_width, final_height) + + for i in range(grid.width): + for j in range(grid.height): + final_grid[i, j] = grid[i, j] + + start_j = grid.width + for i in range(remaining_grid.height): + for j in range(remaining_grid.width): + final_grid[i, start_j + j] = remaining_grid[i, j] + + return final_grid + + placed_objects.add(new_obj) + + triangles = DiagramGrid._drop_irrelevant_triangles( + triangles, placed_objects) + + return grid + + @staticmethod + def _get_undirected_graph(objects, merged_morphisms): + """ + Given the objects and the relevant morphisms of a diagram, + returns the adjacency lists of the underlying undirected + graph. + """ + adjlists = {obj: [] for obj in objects} + + for morphism in merged_morphisms: + adjlists[morphism.domain].append(morphism.codomain) + adjlists[morphism.codomain].append(morphism.domain) + + # Assure that the objects in the adjacency list are always in + # the same order. + for obj in adjlists.keys(): + adjlists[obj].sort(key=default_sort_key) + + return adjlists + + @staticmethod + def _sequential_layout(diagram, merged_morphisms): + r""" + Lays out the diagram in "sequential" layout. This method + will attempt to produce a result as close to a line as + possible. For linear diagrams, the result will actually be a + line. + """ + objects = diagram.objects + sorted_objects = sorted(objects, key=default_sort_key) + + # Set up the adjacency lists of the underlying undirected + # graph of ``merged_morphisms``. + adjlists = DiagramGrid._get_undirected_graph(objects, merged_morphisms) + + root = min(sorted_objects, key=lambda x: len(adjlists[x])) + grid = _GrowableGrid(1, 1) + grid[0, 0] = root + + placed_objects = {root} + + def place_objects(pt, placed_objects): + """ + Does depth-first search in the underlying graph of the + diagram and places the objects en route. + """ + # We will start placing new objects from here. + new_pt = (pt[0], pt[1] + 1) + + for adjacent_obj in adjlists[grid[pt]]: + if adjacent_obj in placed_objects: + # This object has already been placed. + continue + + DiagramGrid._put_object(new_pt, adjacent_obj, grid, []) + placed_objects.add(adjacent_obj) + placed_objects.update(place_objects(new_pt, placed_objects)) + + new_pt = (new_pt[0] + 1, new_pt[1]) + + return placed_objects + + place_objects((0, 0), placed_objects) + + return grid + + @staticmethod + def _drop_inessential_morphisms(merged_morphisms): + r""" + Removes those morphisms which should appear in the diagram, + but which have no relevance to object layout. + + Currently this removes "loop" morphisms: the non-identity + morphisms with the same domains and codomains. + """ + morphisms = [m for m in merged_morphisms if m.domain != m.codomain] + return morphisms + + @staticmethod + def _get_connected_components(objects, merged_morphisms): + """ + Given a container of morphisms, returns a list of connected + components formed by these morphisms. A connected component + is represented by a diagram consisting of the corresponding + morphisms. + """ + component_index = {} + for o in objects: + component_index[o] = None + + # Get the underlying undirected graph of the diagram. + adjlist = DiagramGrid._get_undirected_graph(objects, merged_morphisms) + + def traverse_component(object, current_index): + """ + Does a depth-first search traversal of the component + containing ``object``. + """ + component_index[object] = current_index + for o in adjlist[object]: + if component_index[o] is None: + traverse_component(o, current_index) + + # Traverse all components. + current_index = 0 + for o in adjlist: + if component_index[o] is None: + traverse_component(o, current_index) + current_index += 1 + + # List the objects of the components. + component_objects = [[] for i in range(current_index)] + for o, idx in component_index.items(): + component_objects[idx].append(o) + + # Finally, list the morphisms belonging to each component. + # + # Note: If some objects are isolated, they will not get any + # morphisms at this stage, and since the layout algorithm + # relies, we are essentially going to lose this object. + # Therefore, check if there are isolated objects and, for each + # of them, provide the trivial identity morphism. It will get + # discarded later, but the object will be there. + + component_morphisms = [] + for component in component_objects: + current_morphisms = {} + for m in merged_morphisms: + if (m.domain in component) and (m.codomain in component): + current_morphisms[m] = merged_morphisms[m] + + if len(component) == 1: + # Let's add an identity morphism, for the sake of + # surely having morphisms in this component. + current_morphisms[IdentityMorphism(component[0])] = FiniteSet() + + component_morphisms.append(Diagram(current_morphisms)) + + return component_morphisms + + def __init__(self, diagram, groups=None, **hints): + premises = DiagramGrid._simplify_morphisms(diagram.premises) + conclusions = DiagramGrid._simplify_morphisms(diagram.conclusions) + all_merged_morphisms = DiagramGrid._merge_premises_conclusions( + premises, conclusions) + merged_morphisms = DiagramGrid._drop_inessential_morphisms( + all_merged_morphisms) + + # Store the merged morphisms for later use. + self._morphisms = all_merged_morphisms + + components = DiagramGrid._get_connected_components( + diagram.objects, all_merged_morphisms) + + if groups and (groups != diagram.objects): + # Lay out the diagram according to the groups. + self._grid = DiagramGrid._handle_groups( + diagram, groups, merged_morphisms, hints) + elif len(components) > 1: + # Note that we check for connectedness _before_ checking + # the layout hints because the layout strategies don't + # know how to deal with disconnected diagrams. + + # The diagram is disconnected. Lay out the components + # independently. + grids = [] + + # Sort the components to eventually get the grids arranged + # in a fixed, hash-independent order. + components = sorted(components, key=default_sort_key) + + for component in components: + grid = DiagramGrid(component, **hints) + grids.append(grid) + + # Throw the grids together, in a line. + total_width = sum(g.width for g in grids) + total_height = max(g.height for g in grids) + + grid = _GrowableGrid(total_width, total_height) + start_j = 0 + for g in grids: + for i in range(g.height): + for j in range(g.width): + grid[i, start_j + j] = g[i, j] + + start_j += g.width + + self._grid = grid + elif "layout" in hints: + if hints["layout"] == "sequential": + self._grid = DiagramGrid._sequential_layout( + diagram, merged_morphisms) + else: + self._grid = DiagramGrid._generic_layout(diagram, merged_morphisms) + + if hints.get("transpose"): + # Transpose the resulting grid. + grid = _GrowableGrid(self._grid.height, self._grid.width) + for i in range(self._grid.height): + for j in range(self._grid.width): + grid[j, i] = self._grid[i, j] + self._grid = grid + + @property + def width(self): + """ + Returns the number of columns in this diagram layout. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> from sympy.categories import Diagram, DiagramGrid + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> diagram = Diagram([f, g]) + >>> grid = DiagramGrid(diagram) + >>> grid.width + 2 + + """ + return self._grid.width + + @property + def height(self): + """ + Returns the number of rows in this diagram layout. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> from sympy.categories import Diagram, DiagramGrid + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> diagram = Diagram([f, g]) + >>> grid = DiagramGrid(diagram) + >>> grid.height + 2 + + """ + return self._grid.height + + def __getitem__(self, i_j): + """ + Returns the object placed in the row ``i`` and column ``j``. + The indices are 0-based. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> from sympy.categories import Diagram, DiagramGrid + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> diagram = Diagram([f, g]) + >>> grid = DiagramGrid(diagram) + >>> (grid[0, 0], grid[0, 1]) + (Object("A"), Object("B")) + >>> (grid[1, 0], grid[1, 1]) + (None, Object("C")) + + """ + i, j = i_j + return self._grid[i, j] + + @property + def morphisms(self): + """ + Returns those morphisms (and their properties) which are + sufficiently meaningful to be drawn. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> from sympy.categories import Diagram, DiagramGrid + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> diagram = Diagram([f, g]) + >>> grid = DiagramGrid(diagram) + >>> grid.morphisms + {NamedMorphism(Object("A"), Object("B"), "f"): EmptySet, + NamedMorphism(Object("B"), Object("C"), "g"): EmptySet} + + """ + return self._morphisms + + def __str__(self): + """ + Produces a string representation of this class. + + This method returns a string representation of the underlying + list of lists of objects. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism + >>> from sympy.categories import Diagram, DiagramGrid + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> diagram = Diagram([f, g]) + >>> grid = DiagramGrid(diagram) + >>> print(grid) + [[Object("A"), Object("B")], + [None, Object("C")]] + + """ + return repr(self._grid._array) + + +class ArrowStringDescription: + r""" + Stores the information necessary for producing an Xy-pic + description of an arrow. + + The principal goal of this class is to abstract away the string + representation of an arrow and to also provide the functionality + to produce the actual Xy-pic string. + + ``unit`` sets the unit which will be used to specify the amount of + curving and other distances. ``horizontal_direction`` should be a + string of ``"r"`` or ``"l"`` specifying the horizontal offset of the + target cell of the arrow relatively to the current one. + ``vertical_direction`` should specify the vertical offset using a + series of either ``"d"`` or ``"u"``. ``label_position`` should be + either ``"^"``, ``"_"``, or ``"|"`` to specify that the label should + be positioned above the arrow, below the arrow or just over the arrow, + in a break. Note that the notions "above" and "below" are relative + to arrow direction. ``label`` stores the morphism label. + + This works as follows (disregard the yet unexplained arguments): + + >>> from sympy.categories.diagram_drawing import ArrowStringDescription + >>> astr = ArrowStringDescription( + ... unit="mm", curving=None, curving_amount=None, + ... looping_start=None, looping_end=None, horizontal_direction="d", + ... vertical_direction="r", label_position="_", label="f") + >>> print(str(astr)) + \ar[dr]_{f} + + ``curving`` should be one of ``"^"``, ``"_"`` to specify in which + direction the arrow is going to curve. ``curving_amount`` is a number + describing how many ``unit``'s the morphism is going to curve: + + >>> astr = ArrowStringDescription( + ... unit="mm", curving="^", curving_amount=12, + ... looping_start=None, looping_end=None, horizontal_direction="d", + ... vertical_direction="r", label_position="_", label="f") + >>> print(str(astr)) + \ar@/^12mm/[dr]_{f} + + ``looping_start`` and ``looping_end`` are currently only used for + loop morphisms, those which have the same domain and codomain. + These two attributes should store a valid Xy-pic direction and + specify, correspondingly, the direction the arrow gets out into + and the direction the arrow gets back from: + + >>> astr = ArrowStringDescription( + ... unit="mm", curving=None, curving_amount=None, + ... looping_start="u", looping_end="l", horizontal_direction="", + ... vertical_direction="", label_position="_", label="f") + >>> print(str(astr)) + \ar@(u,l)[]_{f} + + ``label_displacement`` controls how far the arrow label is from + the ends of the arrow. For example, to position the arrow label + near the arrow head, use ">": + + >>> astr = ArrowStringDescription( + ... unit="mm", curving="^", curving_amount=12, + ... looping_start=None, looping_end=None, horizontal_direction="d", + ... vertical_direction="r", label_position="_", label="f") + >>> astr.label_displacement = ">" + >>> print(str(astr)) + \ar@/^12mm/[dr]_>{f} + + Finally, ``arrow_style`` is used to specify the arrow style. To + get a dashed arrow, for example, use "{-->}" as arrow style: + + >>> astr = ArrowStringDescription( + ... unit="mm", curving="^", curving_amount=12, + ... looping_start=None, looping_end=None, horizontal_direction="d", + ... vertical_direction="r", label_position="_", label="f") + >>> astr.arrow_style = "{-->}" + >>> print(str(astr)) + \ar@/^12mm/@{-->}[dr]_{f} + + Notes + ===== + + Instances of :class:`ArrowStringDescription` will be constructed + by :class:`XypicDiagramDrawer` and provided for further use in + formatters. The user is not expected to construct instances of + :class:`ArrowStringDescription` themselves. + + To be able to properly utilise this class, the reader is encouraged + to checkout the Xy-pic user guide, available at [Xypic]. + + See Also + ======== + + XypicDiagramDrawer + + References + ========== + + .. [Xypic] https://xy-pic.sourceforge.net/ + """ + def __init__(self, unit, curving, curving_amount, looping_start, + looping_end, horizontal_direction, vertical_direction, + label_position, label): + self.unit = unit + self.curving = curving + self.curving_amount = curving_amount + self.looping_start = looping_start + self.looping_end = looping_end + self.horizontal_direction = horizontal_direction + self.vertical_direction = vertical_direction + self.label_position = label_position + self.label = label + + self.label_displacement = "" + self.arrow_style = "" + + # This flag shows that the position of the label of this + # morphism was set while typesetting a curved morphism and + # should not be modified later. + self.forced_label_position = False + + def __str__(self): + if self.curving: + curving_str = "@/%s%d%s/" % (self.curving, self.curving_amount, + self.unit) + else: + curving_str = "" + + if self.looping_start and self.looping_end: + looping_str = "@(%s,%s)" % (self.looping_start, self.looping_end) + else: + looping_str = "" + + if self.arrow_style: + + style_str = "@" + self.arrow_style + else: + style_str = "" + + return "\\ar%s%s%s[%s%s]%s%s{%s}" % \ + (curving_str, looping_str, style_str, self.horizontal_direction, + self.vertical_direction, self.label_position, + self.label_displacement, self.label) + + +class XypicDiagramDrawer: + r""" + Given a :class:`~.Diagram` and the corresponding + :class:`DiagramGrid`, produces the Xy-pic representation of the + diagram. + + The most important method in this class is ``draw``. Consider the + following triangle diagram: + + >>> from sympy.categories import Object, NamedMorphism, Diagram + >>> from sympy.categories import DiagramGrid, XypicDiagramDrawer + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> diagram = Diagram([f, g], {g * f: "unique"}) + + To draw this diagram, its objects need to be laid out with a + :class:`DiagramGrid`:: + + >>> grid = DiagramGrid(diagram) + + Finally, the drawing: + + >>> drawer = XypicDiagramDrawer() + >>> print(drawer.draw(diagram, grid)) + \xymatrix{ + A \ar[d]_{g\circ f} \ar[r]^{f} & B \ar[ld]^{g} \\ + C & + } + + For further details see the docstring of this method. + + To control the appearance of the arrows, formatters are used. The + dictionary ``arrow_formatters`` maps morphisms to formatter + functions. A formatter is accepts an + :class:`ArrowStringDescription` and is allowed to modify any of + the arrow properties exposed thereby. For example, to have all + morphisms with the property ``unique`` appear as dashed arrows, + and to have their names prepended with `\exists !`, the following + should be done: + + >>> def formatter(astr): + ... astr.label = r"\exists !" + astr.label + ... astr.arrow_style = "{-->}" + >>> drawer.arrow_formatters["unique"] = formatter + >>> print(drawer.draw(diagram, grid)) + \xymatrix{ + A \ar@{-->}[d]_{\exists !g\circ f} \ar[r]^{f} & B \ar[ld]^{g} \\ + C & + } + + To modify the appearance of all arrows in the diagram, set + ``default_arrow_formatter``. For example, to place all morphism + labels a little bit farther from the arrow head so that they look + more centred, do as follows: + + >>> def default_formatter(astr): + ... astr.label_displacement = "(0.45)" + >>> drawer.default_arrow_formatter = default_formatter + >>> print(drawer.draw(diagram, grid)) + \xymatrix{ + A \ar@{-->}[d]_(0.45){\exists !g\circ f} \ar[r]^(0.45){f} & B \ar[ld]^(0.45){g} \\ + C & + } + + In some diagrams some morphisms are drawn as curved arrows. + Consider the following diagram: + + >>> D = Object("D") + >>> E = Object("E") + >>> h = NamedMorphism(D, A, "h") + >>> k = NamedMorphism(D, B, "k") + >>> diagram = Diagram([f, g, h, k]) + >>> grid = DiagramGrid(diagram) + >>> drawer = XypicDiagramDrawer() + >>> print(drawer.draw(diagram, grid)) + \xymatrix{ + A \ar[r]_{f} & B \ar[d]^{g} & D \ar[l]^{k} \ar@/_3mm/[ll]_{h} \\ + & C & + } + + To control how far the morphisms are curved by default, one can + use the ``unit`` and ``default_curving_amount`` attributes: + + >>> drawer.unit = "cm" + >>> drawer.default_curving_amount = 1 + >>> print(drawer.draw(diagram, grid)) + \xymatrix{ + A \ar[r]_{f} & B \ar[d]^{g} & D \ar[l]^{k} \ar@/_1cm/[ll]_{h} \\ + & C & + } + + In some diagrams, there are multiple curved morphisms between the + same two objects. To control by how much the curving changes + between two such successive morphisms, use + ``default_curving_step``: + + >>> drawer.default_curving_step = 1 + >>> h1 = NamedMorphism(A, D, "h1") + >>> diagram = Diagram([f, g, h, k, h1]) + >>> grid = DiagramGrid(diagram) + >>> print(drawer.draw(diagram, grid)) + \xymatrix{ + A \ar[r]_{f} \ar@/^1cm/[rr]^{h_{1}} & B \ar[d]^{g} & D \ar[l]^{k} \ar@/_2cm/[ll]_{h} \\ + & C & + } + + The default value of ``default_curving_step`` is 4 units. + + See Also + ======== + + draw, ArrowStringDescription + """ + def __init__(self): + self.unit = "mm" + self.default_curving_amount = 3 + self.default_curving_step = 4 + + # This dictionary maps properties to the corresponding arrow + # formatters. + self.arrow_formatters = {} + + # This is the default arrow formatter which will be applied to + # each arrow independently of its properties. + self.default_arrow_formatter = None + + @staticmethod + def _process_loop_morphism(i, j, grid, morphisms_str_info, object_coords): + """ + Produces the information required for constructing the string + representation of a loop morphism. This function is invoked + from ``_process_morphism``. + + See Also + ======== + + _process_morphism + """ + curving = "" + label_pos = "^" + looping_start = "" + looping_end = "" + + # This is a loop morphism. Count how many morphisms stick + # in each of the four quadrants. Note that straight + # vertical and horizontal morphisms count in two quadrants + # at the same time (i.e., a morphism going up counts both + # in the first and the second quadrants). + + # The usual numbering (counterclockwise) of quadrants + # applies. + quadrant = [0, 0, 0, 0] + + obj = grid[i, j] + + for m, m_str_info in morphisms_str_info.items(): + if (m.domain == obj) and (m.codomain == obj): + # That's another loop morphism. Check how it + # loops and mark the corresponding quadrants as + # busy. + (l_s, l_e) = (m_str_info.looping_start, m_str_info.looping_end) + + if (l_s, l_e) == ("r", "u"): + quadrant[0] += 1 + elif (l_s, l_e) == ("u", "l"): + quadrant[1] += 1 + elif (l_s, l_e) == ("l", "d"): + quadrant[2] += 1 + elif (l_s, l_e) == ("d", "r"): + quadrant[3] += 1 + + continue + if m.domain == obj: + (end_i, end_j) = object_coords[m.codomain] + goes_out = True + elif m.codomain == obj: + (end_i, end_j) = object_coords[m.domain] + goes_out = False + else: + continue + + d_i = end_i - i + d_j = end_j - j + m_curving = m_str_info.curving + + if (d_i != 0) and (d_j != 0): + # This is really a diagonal morphism. Detect the + # quadrant. + if (d_i > 0) and (d_j > 0): + quadrant[0] += 1 + elif (d_i > 0) and (d_j < 0): + quadrant[1] += 1 + elif (d_i < 0) and (d_j < 0): + quadrant[2] += 1 + elif (d_i < 0) and (d_j > 0): + quadrant[3] += 1 + elif d_i == 0: + # Knowing where the other end of the morphism is + # and which way it goes, we now have to decide + # which quadrant is now the upper one and which is + # the lower one. + if d_j > 0: + if goes_out: + upper_quadrant = 0 + lower_quadrant = 3 + else: + upper_quadrant = 3 + lower_quadrant = 0 + else: + if goes_out: + upper_quadrant = 2 + lower_quadrant = 1 + else: + upper_quadrant = 1 + lower_quadrant = 2 + + if m_curving: + if m_curving == "^": + quadrant[upper_quadrant] += 1 + elif m_curving == "_": + quadrant[lower_quadrant] += 1 + else: + # This morphism counts in both upper and lower + # quadrants. + quadrant[upper_quadrant] += 1 + quadrant[lower_quadrant] += 1 + elif d_j == 0: + # Knowing where the other end of the morphism is + # and which way it goes, we now have to decide + # which quadrant is now the left one and which is + # the right one. + if d_i < 0: + if goes_out: + left_quadrant = 1 + right_quadrant = 0 + else: + left_quadrant = 0 + right_quadrant = 1 + else: + if goes_out: + left_quadrant = 3 + right_quadrant = 2 + else: + left_quadrant = 2 + right_quadrant = 3 + + if m_curving: + if m_curving == "^": + quadrant[left_quadrant] += 1 + elif m_curving == "_": + quadrant[right_quadrant] += 1 + else: + # This morphism counts in both upper and lower + # quadrants. + quadrant[left_quadrant] += 1 + quadrant[right_quadrant] += 1 + + # Pick the freest quadrant to curve our morphism into. + freest_quadrant = 0 + for i in range(4): + if quadrant[i] < quadrant[freest_quadrant]: + freest_quadrant = i + + # Now set up proper looping. + (looping_start, looping_end) = [("r", "u"), ("u", "l"), ("l", "d"), + ("d", "r")][freest_quadrant] + + return (curving, label_pos, looping_start, looping_end) + + @staticmethod + def _process_horizontal_morphism(i, j, target_j, grid, morphisms_str_info, + object_coords): + """ + Produces the information required for constructing the string + representation of a horizontal morphism. This function is + invoked from ``_process_morphism``. + + See Also + ======== + + _process_morphism + """ + # The arrow is horizontal. Check if it goes from left to + # right (``backwards == False``) or from right to left + # (``backwards == True``). + backwards = False + start = j + end = target_j + if end < start: + (start, end) = (end, start) + backwards = True + + # Let's see which objects are there between ``start`` and + # ``end``, and then count how many morphisms stick out + # upwards, and how many stick out downwards. + # + # For example, consider the situation: + # + # B1 C1 + # | | + # A--B--C--D + # | + # B2 + # + # Between the objects `A` and `D` there are two objects: + # `B` and `C`. Further, there are two morphisms which + # stick out upward (the ones between `B1` and `B` and + # between `C` and `C1`) and one morphism which sticks out + # downward (the one between `B and `B2`). + # + # We need this information to decide how to curve the + # arrow between `A` and `D`. First of all, since there + # are two objects between `A` and `D``, we must curve the + # arrow. Then, we will have it curve downward, because + # there is more space (less morphisms stick out downward + # than upward). + up = [] + down = [] + straight_horizontal = [] + for k in range(start + 1, end): + obj = grid[i, k] + if not obj: + continue + + for m in morphisms_str_info: + if m.domain == obj: + (end_i, end_j) = object_coords[m.codomain] + elif m.codomain == obj: + (end_i, end_j) = object_coords[m.domain] + else: + continue + + if end_i > i: + down.append(m) + elif end_i < i: + up.append(m) + elif not morphisms_str_info[m].curving: + # This is a straight horizontal morphism, + # because it has no curving. + straight_horizontal.append(m) + + if len(up) < len(down): + # More morphisms stick out downward than upward, let's + # curve the morphism up. + if backwards: + curving = "_" + label_pos = "_" + else: + curving = "^" + label_pos = "^" + + # Assure that the straight horizontal morphisms have + # their labels on the lower side of the arrow. + for m in straight_horizontal: + (i1, j1) = object_coords[m.domain] + (i2, j2) = object_coords[m.codomain] + + m_str_info = morphisms_str_info[m] + if j1 < j2: + m_str_info.label_position = "_" + else: + m_str_info.label_position = "^" + + # Don't allow any further modifications of the + # position of this label. + m_str_info.forced_label_position = True + else: + # More morphisms stick out downward than upward, let's + # curve the morphism up. + if backwards: + curving = "^" + label_pos = "^" + else: + curving = "_" + label_pos = "_" + + # Assure that the straight horizontal morphisms have + # their labels on the upper side of the arrow. + for m in straight_horizontal: + (i1, j1) = object_coords[m.domain] + (i2, j2) = object_coords[m.codomain] + + m_str_info = morphisms_str_info[m] + if j1 < j2: + m_str_info.label_position = "^" + else: + m_str_info.label_position = "_" + + # Don't allow any further modifications of the + # position of this label. + m_str_info.forced_label_position = True + + return (curving, label_pos) + + @staticmethod + def _process_vertical_morphism(i, j, target_i, grid, morphisms_str_info, + object_coords): + """ + Produces the information required for constructing the string + representation of a vertical morphism. This function is + invoked from ``_process_morphism``. + + See Also + ======== + + _process_morphism + """ + # This arrow is vertical. Check if it goes from top to + # bottom (``backwards == False``) or from bottom to top + # (``backwards == True``). + backwards = False + start = i + end = target_i + if end < start: + (start, end) = (end, start) + backwards = True + + # Let's see which objects are there between ``start`` and + # ``end``, and then count how many morphisms stick out to + # the left, and how many stick out to the right. + # + # See the corresponding comment in the previous branch of + # this if-statement for more details. + left = [] + right = [] + straight_vertical = [] + for k in range(start + 1, end): + obj = grid[k, j] + if not obj: + continue + + for m in morphisms_str_info: + if m.domain == obj: + (end_i, end_j) = object_coords[m.codomain] + elif m.codomain == obj: + (end_i, end_j) = object_coords[m.domain] + else: + continue + + if end_j > j: + right.append(m) + elif end_j < j: + left.append(m) + elif not morphisms_str_info[m].curving: + # This is a straight vertical morphism, + # because it has no curving. + straight_vertical.append(m) + + if len(left) < len(right): + # More morphisms stick out to the left than to the + # right, let's curve the morphism to the right. + if backwards: + curving = "^" + label_pos = "^" + else: + curving = "_" + label_pos = "_" + + # Assure that the straight vertical morphisms have + # their labels on the left side of the arrow. + for m in straight_vertical: + (i1, j1) = object_coords[m.domain] + (i2, j2) = object_coords[m.codomain] + + m_str_info = morphisms_str_info[m] + if i1 < i2: + m_str_info.label_position = "^" + else: + m_str_info.label_position = "_" + + # Don't allow any further modifications of the + # position of this label. + m_str_info.forced_label_position = True + else: + # More morphisms stick out to the right than to the + # left, let's curve the morphism to the left. + if backwards: + curving = "_" + label_pos = "_" + else: + curving = "^" + label_pos = "^" + + # Assure that the straight vertical morphisms have + # their labels on the right side of the arrow. + for m in straight_vertical: + (i1, j1) = object_coords[m.domain] + (i2, j2) = object_coords[m.codomain] + + m_str_info = morphisms_str_info[m] + if i1 < i2: + m_str_info.label_position = "_" + else: + m_str_info.label_position = "^" + + # Don't allow any further modifications of the + # position of this label. + m_str_info.forced_label_position = True + + return (curving, label_pos) + + def _process_morphism(self, diagram, grid, morphism, object_coords, + morphisms, morphisms_str_info): + """ + Given the required information, produces the string + representation of ``morphism``. + """ + def repeat_string_cond(times, str_gt, str_lt): + """ + If ``times > 0``, repeats ``str_gt`` ``times`` times. + Otherwise, repeats ``str_lt`` ``-times`` times. + """ + if times > 0: + return str_gt * times + else: + return str_lt * (-times) + + def count_morphisms_undirected(A, B): + """ + Counts how many processed morphisms there are between the + two supplied objects. + """ + return len([m for m in morphisms_str_info + if {m.domain, m.codomain} == {A, B}]) + + def count_morphisms_filtered(dom, cod, curving): + """ + Counts the processed morphisms which go out of ``dom`` + into ``cod`` with curving ``curving``. + """ + return len([m for m, m_str_info in morphisms_str_info.items() + if (m.domain, m.codomain) == (dom, cod) and + (m_str_info.curving == curving)]) + + (i, j) = object_coords[morphism.domain] + (target_i, target_j) = object_coords[morphism.codomain] + + # We now need to determine the direction of + # the arrow. + delta_i = target_i - i + delta_j = target_j - j + vertical_direction = repeat_string_cond(delta_i, + "d", "u") + horizontal_direction = repeat_string_cond(delta_j, + "r", "l") + + curving = "" + label_pos = "^" + looping_start = "" + looping_end = "" + + if (delta_i == 0) and (delta_j == 0): + # This is a loop morphism. + (curving, label_pos, looping_start, + looping_end) = XypicDiagramDrawer._process_loop_morphism( + i, j, grid, morphisms_str_info, object_coords) + elif (delta_i == 0) and (abs(j - target_j) > 1): + # This is a horizontal morphism. + (curving, label_pos) = XypicDiagramDrawer._process_horizontal_morphism( + i, j, target_j, grid, morphisms_str_info, object_coords) + elif (delta_j == 0) and (abs(i - target_i) > 1): + # This is a vertical morphism. + (curving, label_pos) = XypicDiagramDrawer._process_vertical_morphism( + i, j, target_i, grid, morphisms_str_info, object_coords) + + count = count_morphisms_undirected(morphism.domain, morphism.codomain) + curving_amount = "" + if curving: + # This morphisms should be curved anyway. + curving_amount = self.default_curving_amount + count * \ + self.default_curving_step + elif count: + # There are no objects between the domain and codomain of + # the current morphism, but this is not there already are + # some morphisms with the same domain and codomain, so we + # have to curve this one. + curving = "^" + filtered_morphisms = count_morphisms_filtered( + morphism.domain, morphism.codomain, curving) + curving_amount = self.default_curving_amount + \ + filtered_morphisms * \ + self.default_curving_step + + # Let's now get the name of the morphism. + morphism_name = "" + if isinstance(morphism, IdentityMorphism): + morphism_name = "id_{%s}" + latex(grid[i, j]) + elif isinstance(morphism, CompositeMorphism): + component_names = [latex(Symbol(component.name)) for + component in morphism.components] + component_names.reverse() + morphism_name = "\\circ ".join(component_names) + elif isinstance(morphism, NamedMorphism): + morphism_name = latex(Symbol(morphism.name)) + + return ArrowStringDescription( + self.unit, curving, curving_amount, looping_start, + looping_end, horizontal_direction, vertical_direction, + label_pos, morphism_name) + + @staticmethod + def _check_free_space_horizontal(dom_i, dom_j, cod_j, grid): + """ + For a horizontal morphism, checks whether there is free space + (i.e., space not occupied by any objects) above the morphism + or below it. + """ + if dom_j < cod_j: + (start, end) = (dom_j, cod_j) + backwards = False + else: + (start, end) = (cod_j, dom_j) + backwards = True + + # Check for free space above. + if dom_i == 0: + free_up = True + else: + free_up = all(grid[dom_i - 1, j] for j in + range(start, end + 1)) + + # Check for free space below. + if dom_i == grid.height - 1: + free_down = True + else: + free_down = not any(grid[dom_i + 1, j] for j in + range(start, end + 1)) + + return (free_up, free_down, backwards) + + @staticmethod + def _check_free_space_vertical(dom_i, cod_i, dom_j, grid): + """ + For a vertical morphism, checks whether there is free space + (i.e., space not occupied by any objects) to the left of the + morphism or to the right of it. + """ + if dom_i < cod_i: + (start, end) = (dom_i, cod_i) + backwards = False + else: + (start, end) = (cod_i, dom_i) + backwards = True + + # Check if there's space to the left. + if dom_j == 0: + free_left = True + else: + free_left = not any(grid[i, dom_j - 1] for i in + range(start, end + 1)) + + if dom_j == grid.width - 1: + free_right = True + else: + free_right = not any(grid[i, dom_j + 1] for i in + range(start, end + 1)) + + return (free_left, free_right, backwards) + + @staticmethod + def _check_free_space_diagonal(dom_i, cod_i, dom_j, cod_j, grid): + """ + For a diagonal morphism, checks whether there is free space + (i.e., space not occupied by any objects) above the morphism + or below it. + """ + def abs_xrange(start, end): + if start < end: + return range(start, end + 1) + else: + return range(end, start + 1) + + if dom_i < cod_i and dom_j < cod_j: + # This morphism goes from top-left to + # bottom-right. + (start_i, start_j) = (dom_i, dom_j) + (end_i, end_j) = (cod_i, cod_j) + backwards = False + elif dom_i > cod_i and dom_j > cod_j: + # This morphism goes from bottom-right to + # top-left. + (start_i, start_j) = (cod_i, cod_j) + (end_i, end_j) = (dom_i, dom_j) + backwards = True + if dom_i < cod_i and dom_j > cod_j: + # This morphism goes from top-right to + # bottom-left. + (start_i, start_j) = (dom_i, dom_j) + (end_i, end_j) = (cod_i, cod_j) + backwards = True + elif dom_i > cod_i and dom_j < cod_j: + # This morphism goes from bottom-left to + # top-right. + (start_i, start_j) = (cod_i, cod_j) + (end_i, end_j) = (dom_i, dom_j) + backwards = False + + # This is an attempt at a fast and furious strategy to + # decide where there is free space on the two sides of + # a diagonal morphism. For a diagonal morphism + # starting at ``(start_i, start_j)`` and ending at + # ``(end_i, end_j)`` the rectangle defined by these + # two points is considered. The slope of the diagonal + # ``alpha`` is then computed. Then, for every cell + # ``(i, j)`` within the rectangle, the slope + # ``alpha1`` of the line through ``(start_i, + # start_j)`` and ``(i, j)`` is considered. If + # ``alpha1`` is between 0 and ``alpha``, the point + # ``(i, j)`` is above the diagonal, if ``alpha1`` is + # between ``alpha`` and infinity, the point is below + # the diagonal. Also note that, with some beforehand + # precautions, this trick works for both the main and + # the secondary diagonals of the rectangle. + + # I have considered the possibility to only follow the + # shorter diagonals immediately above and below the + # main (or secondary) diagonal. This, however, + # wouldn't have resulted in much performance gain or + # better detection of outer edges, because of + # relatively small sizes of diagram grids, while the + # code would have become harder to understand. + + alpha = float(end_i - start_i)/(end_j - start_j) + free_up = True + free_down = True + for i in abs_xrange(start_i, end_i): + if not free_up and not free_down: + break + + for j in abs_xrange(start_j, end_j): + if not free_up and not free_down: + break + + if (i, j) == (start_i, start_j): + continue + + if j == start_j: + alpha1 = "inf" + else: + alpha1 = float(i - start_i)/(j - start_j) + + if grid[i, j]: + if (alpha1 == "inf") or (abs(alpha1) > abs(alpha)): + free_down = False + elif abs(alpha1) < abs(alpha): + free_up = False + + return (free_up, free_down, backwards) + + def _push_labels_out(self, morphisms_str_info, grid, object_coords): + """ + For all straight morphisms which form the visual boundary of + the laid out diagram, puts their labels on their outer sides. + """ + def set_label_position(free1, free2, pos1, pos2, backwards, m_str_info): + """ + Given the information about room available to one side and + to the other side of a morphism (``free1`` and ``free2``), + sets the position of the morphism label in such a way that + it is on the freer side. This latter operations involves + choice between ``pos1`` and ``pos2``, taking ``backwards`` + in consideration. + + Thus this function will do nothing if either both ``free1 + == True`` and ``free2 == True`` or both ``free1 == False`` + and ``free2 == False``. In either case, choosing one side + over the other presents no advantage. + """ + if backwards: + (pos1, pos2) = (pos2, pos1) + + if free1 and not free2: + m_str_info.label_position = pos1 + elif free2 and not free1: + m_str_info.label_position = pos2 + + for m, m_str_info in morphisms_str_info.items(): + if m_str_info.curving or m_str_info.forced_label_position: + # This is either a curved morphism, and curved + # morphisms have other magic, or the position of this + # label has already been fixed. + continue + + if m.domain == m.codomain: + # This is a loop morphism, their labels, again have a + # different magic. + continue + + (dom_i, dom_j) = object_coords[m.domain] + (cod_i, cod_j) = object_coords[m.codomain] + + if dom_i == cod_i: + # Horizontal morphism. + (free_up, free_down, + backwards) = XypicDiagramDrawer._check_free_space_horizontal( + dom_i, dom_j, cod_j, grid) + + set_label_position(free_up, free_down, "^", "_", + backwards, m_str_info) + elif dom_j == cod_j: + # Vertical morphism. + (free_left, free_right, + backwards) = XypicDiagramDrawer._check_free_space_vertical( + dom_i, cod_i, dom_j, grid) + + set_label_position(free_left, free_right, "_", "^", + backwards, m_str_info) + else: + # A diagonal morphism. + (free_up, free_down, + backwards) = XypicDiagramDrawer._check_free_space_diagonal( + dom_i, cod_i, dom_j, cod_j, grid) + + set_label_position(free_up, free_down, "^", "_", + backwards, m_str_info) + + @staticmethod + def _morphism_sort_key(morphism, object_coords): + """ + Provides a morphism sorting key such that horizontal or + vertical morphisms between neighbouring objects come + first, then horizontal or vertical morphisms between more + far away objects, and finally, all other morphisms. + """ + (i, j) = object_coords[morphism.domain] + (target_i, target_j) = object_coords[morphism.codomain] + + if morphism.domain == morphism.codomain: + # Loop morphisms should get after diagonal morphisms + # so that the proper direction in which to curve the + # loop can be determined. + return (3, 0, default_sort_key(morphism)) + + if target_i == i: + return (1, abs(target_j - j), default_sort_key(morphism)) + + if target_j == j: + return (1, abs(target_i - i), default_sort_key(morphism)) + + # Diagonal morphism. + return (2, 0, default_sort_key(morphism)) + + @staticmethod + def _build_xypic_string(diagram, grid, morphisms, + morphisms_str_info, diagram_format): + """ + Given a collection of :class:`ArrowStringDescription` + describing the morphisms of a diagram and the object layout + information of a diagram, produces the final Xy-pic picture. + """ + # Build the mapping between objects and morphisms which have + # them as domains. + object_morphisms = {} + for obj in diagram.objects: + object_morphisms[obj] = [] + for morphism in morphisms: + object_morphisms[morphism.domain].append(morphism) + + result = "\\xymatrix%s{\n" % diagram_format + + for i in range(grid.height): + for j in range(grid.width): + obj = grid[i, j] + if obj: + result += latex(obj) + " " + + morphisms_to_draw = object_morphisms[obj] + for morphism in morphisms_to_draw: + result += str(morphisms_str_info[morphism]) + " " + + # Don't put the & after the last column. + if j < grid.width - 1: + result += "& " + + # Don't put the line break after the last row. + if i < grid.height - 1: + result += "\\\\" + result += "\n" + + result += "}\n" + + return result + + def draw(self, diagram, grid, masked=None, diagram_format=""): + r""" + Returns the Xy-pic representation of ``diagram`` laid out in + ``grid``. + + Consider the following simple triangle diagram. + + >>> from sympy.categories import Object, NamedMorphism, Diagram + >>> from sympy.categories import DiagramGrid, XypicDiagramDrawer + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> diagram = Diagram([f, g], {g * f: "unique"}) + + To draw this diagram, its objects need to be laid out with a + :class:`DiagramGrid`:: + + >>> grid = DiagramGrid(diagram) + + Finally, the drawing: + + >>> drawer = XypicDiagramDrawer() + >>> print(drawer.draw(diagram, grid)) + \xymatrix{ + A \ar[d]_{g\circ f} \ar[r]^{f} & B \ar[ld]^{g} \\ + C & + } + + The argument ``masked`` can be used to skip morphisms in the + presentation of the diagram: + + >>> print(drawer.draw(diagram, grid, masked=[g * f])) + \xymatrix{ + A \ar[r]^{f} & B \ar[ld]^{g} \\ + C & + } + + Finally, the ``diagram_format`` argument can be used to + specify the format string of the diagram. For example, to + increase the spacing by 1 cm, proceeding as follows: + + >>> print(drawer.draw(diagram, grid, diagram_format="@+1cm")) + \xymatrix@+1cm{ + A \ar[d]_{g\circ f} \ar[r]^{f} & B \ar[ld]^{g} \\ + C & + } + + """ + # This method works in several steps. It starts by removing + # the masked morphisms, if necessary, and then maps objects to + # their positions in the grid (coordinate tuples). Remember + # that objects are unique in ``Diagram`` and in the layout + # produced by ``DiagramGrid``, so every object is mapped to a + # single coordinate pair. + # + # The next step is the central step and is concerned with + # analysing the morphisms of the diagram and deciding how to + # draw them. For example, how to curve the arrows is decided + # at this step. The bulk of the analysis is implemented in + # ``_process_morphism``, to the result of which the + # appropriate formatters are applied. + # + # The result of the previous step is a list of + # ``ArrowStringDescription``. After the analysis and + # application of formatters, some extra logic tries to assure + # better positioning of morphism labels (for example, an + # attempt is made to avoid the situations when arrows cross + # labels). This functionality constitutes the next step and + # is implemented in ``_push_labels_out``. Note that label + # positions which have been set via a formatter are not + # affected in this step. + # + # Finally, at the closing step, the array of + # ``ArrowStringDescription`` and the layout information + # incorporated in ``DiagramGrid`` are combined to produce the + # resulting Xy-pic picture. This part of code lies in + # ``_build_xypic_string``. + + if not masked: + morphisms_props = grid.morphisms + else: + morphisms_props = {} + for m, props in grid.morphisms.items(): + if m in masked: + continue + morphisms_props[m] = props + + # Build the mapping between objects and their position in the + # grid. + object_coords = {} + for i in range(grid.height): + for j in range(grid.width): + if grid[i, j]: + object_coords[grid[i, j]] = (i, j) + + morphisms = sorted(morphisms_props, + key=lambda m: XypicDiagramDrawer._morphism_sort_key( + m, object_coords)) + + # Build the tuples defining the string representations of + # morphisms. + morphisms_str_info = {} + for morphism in morphisms: + string_description = self._process_morphism( + diagram, grid, morphism, object_coords, morphisms, + morphisms_str_info) + + if self.default_arrow_formatter: + self.default_arrow_formatter(string_description) + + for prop in morphisms_props[morphism]: + # prop is a Symbol. TODO: Find out why. + if prop.name in self.arrow_formatters: + formatter = self.arrow_formatters[prop.name] + formatter(string_description) + + morphisms_str_info[morphism] = string_description + + # Reposition the labels a bit. + self._push_labels_out(morphisms_str_info, grid, object_coords) + + return XypicDiagramDrawer._build_xypic_string( + diagram, grid, morphisms, morphisms_str_info, diagram_format) + + +def xypic_draw_diagram(diagram, masked=None, diagram_format="", + groups=None, **hints): + r""" + Provides a shortcut combining :class:`DiagramGrid` and + :class:`XypicDiagramDrawer`. Returns an Xy-pic presentation of + ``diagram``. The argument ``masked`` is a list of morphisms which + will be not be drawn. The argument ``diagram_format`` is the + format string inserted after "\xymatrix". ``groups`` should be a + set of logical groups. The ``hints`` will be passed directly to + the constructor of :class:`DiagramGrid`. + + For more information about the arguments, see the docstrings of + :class:`DiagramGrid` and ``XypicDiagramDrawer.draw``. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism, Diagram + >>> from sympy.categories import xypic_draw_diagram + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> diagram = Diagram([f, g], {g * f: "unique"}) + >>> print(xypic_draw_diagram(diagram)) + \xymatrix{ + A \ar[d]_{g\circ f} \ar[r]^{f} & B \ar[ld]^{g} \\ + C & + } + + See Also + ======== + + XypicDiagramDrawer, DiagramGrid + """ + grid = DiagramGrid(diagram, groups, **hints) + drawer = XypicDiagramDrawer() + return drawer.draw(diagram, grid, masked, diagram_format) + + +@doctest_depends_on(exe=('latex', 'dvipng'), modules=('pyglet',)) +def preview_diagram(diagram, masked=None, diagram_format="", groups=None, + output='png', viewer=None, euler=True, **hints): + """ + Combines the functionality of ``xypic_draw_diagram`` and + ``sympy.printing.preview``. The arguments ``masked``, + ``diagram_format``, ``groups``, and ``hints`` are passed to + ``xypic_draw_diagram``, while ``output``, ``viewer, and ``euler`` + are passed to ``preview``. + + Examples + ======== + + >>> from sympy.categories import Object, NamedMorphism, Diagram + >>> from sympy.categories import preview_diagram + >>> A = Object("A") + >>> B = Object("B") + >>> C = Object("C") + >>> f = NamedMorphism(A, B, "f") + >>> g = NamedMorphism(B, C, "g") + >>> d = Diagram([f, g], {g * f: "unique"}) + >>> preview_diagram(d) + + See Also + ======== + + XypicDiagramDrawer + """ + from sympy.printing import preview + latex_output = xypic_draw_diagram(diagram, masked, diagram_format, + groups, **hints) + preview(latex_output, output, viewer, euler, ("xypic",)) diff --git a/.venv/lib/python3.13/site-packages/sympy/categories/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/categories/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/categories/tests/test_baseclasses.py b/.venv/lib/python3.13/site-packages/sympy/categories/tests/test_baseclasses.py new file mode 100644 index 0000000000000000000000000000000000000000..cfac32229768fb5903b23b11ffb236912c0b931e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/categories/tests/test_baseclasses.py @@ -0,0 +1,209 @@ +from sympy.categories import (Object, Morphism, IdentityMorphism, + NamedMorphism, CompositeMorphism, + Diagram, Category) +from sympy.categories.baseclasses import Class +from sympy.testing.pytest import raises +from sympy.core.containers import (Dict, Tuple) +from sympy.sets import EmptySet +from sympy.sets.sets import FiniteSet + + +def test_morphisms(): + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + + # Test the base morphism. + f = NamedMorphism(A, B, "f") + assert f.domain == A + assert f.codomain == B + assert f == NamedMorphism(A, B, "f") + + # Test identities. + id_A = IdentityMorphism(A) + id_B = IdentityMorphism(B) + assert id_A.domain == A + assert id_A.codomain == A + assert id_A == IdentityMorphism(A) + assert id_A != id_B + + # Test named morphisms. + g = NamedMorphism(B, C, "g") + assert g.name == "g" + assert g != f + assert g == NamedMorphism(B, C, "g") + assert g != NamedMorphism(B, C, "f") + + # Test composite morphisms. + assert f == CompositeMorphism(f) + + k = g.compose(f) + assert k.domain == A + assert k.codomain == C + assert k.components == Tuple(f, g) + assert g * f == k + assert CompositeMorphism(f, g) == k + + assert CompositeMorphism(g * f) == g * f + + # Test the associativity of composition. + h = NamedMorphism(C, D, "h") + + p = h * g + u = h * g * f + + assert h * k == u + assert p * f == u + assert CompositeMorphism(f, g, h) == u + + # Test flattening. + u2 = u.flatten("u") + assert isinstance(u2, NamedMorphism) + assert u2.name == "u" + assert u2.domain == A + assert u2.codomain == D + + # Test identities. + assert f * id_A == f + assert id_B * f == f + assert id_A * id_A == id_A + assert CompositeMorphism(id_A) == id_A + + # Test bad compositions. + raises(ValueError, lambda: f * g) + + raises(TypeError, lambda: f.compose(None)) + raises(TypeError, lambda: id_A.compose(None)) + raises(TypeError, lambda: f * None) + raises(TypeError, lambda: id_A * None) + + raises(TypeError, lambda: CompositeMorphism(f, None, 1)) + + raises(ValueError, lambda: NamedMorphism(A, B, "")) + raises(NotImplementedError, lambda: Morphism(A, B)) + + +def test_diagram(): + A = Object("A") + B = Object("B") + C = Object("C") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + id_A = IdentityMorphism(A) + id_B = IdentityMorphism(B) + + empty = EmptySet + + # Test the addition of identities. + d1 = Diagram([f]) + + assert d1.objects == FiniteSet(A, B) + assert d1.hom(A, B) == (FiniteSet(f), empty) + assert d1.hom(A, A) == (FiniteSet(id_A), empty) + assert d1.hom(B, B) == (FiniteSet(id_B), empty) + + assert d1 == Diagram([id_A, f]) + assert d1 == Diagram([f, f]) + + # Test the addition of composites. + d2 = Diagram([f, g]) + homAC = d2.hom(A, C)[0] + + assert d2.objects == FiniteSet(A, B, C) + assert g * f in d2.premises.keys() + assert homAC == FiniteSet(g * f) + + # Test equality, inequality and hash. + d11 = Diagram([f]) + + assert d1 == d11 + assert d1 != d2 + assert hash(d1) == hash(d11) + + d11 = Diagram({f: "unique"}) + assert d1 != d11 + + # Make sure that (re-)adding composites (with new properties) + # works as expected. + d = Diagram([f, g], {g * f: "unique"}) + assert d.conclusions == Dict({g * f: FiniteSet("unique")}) + + # Check the hom-sets when there are premises and conclusions. + assert d.hom(A, C) == (FiniteSet(g * f), FiniteSet(g * f)) + d = Diagram([f, g], [g * f]) + assert d.hom(A, C) == (FiniteSet(g * f), FiniteSet(g * f)) + + # Check how the properties of composite morphisms are computed. + d = Diagram({f: ["unique", "isomorphism"], g: "unique"}) + assert d.premises[g * f] == FiniteSet("unique") + + # Check that conclusion morphisms with new objects are not allowed. + d = Diagram([f], [g]) + assert d.conclusions == Dict({}) + + # Test an empty diagram. + d = Diagram() + assert d.premises == Dict({}) + assert d.conclusions == Dict({}) + assert d.objects == empty + + # Check a SymPy Dict object. + d = Diagram(Dict({f: FiniteSet("unique", "isomorphism"), g: "unique"})) + assert d.premises[g * f] == FiniteSet("unique") + + # Check the addition of components of composite morphisms. + d = Diagram([g * f]) + assert f in d.premises + assert g in d.premises + + # Check subdiagrams. + d = Diagram([f, g], {g * f: "unique"}) + + d1 = Diagram([f]) + assert d.is_subdiagram(d1) + assert not d1.is_subdiagram(d) + + d = Diagram([NamedMorphism(B, A, "f'")]) + assert not d.is_subdiagram(d1) + assert not d1.is_subdiagram(d) + + d1 = Diagram([f, g], {g * f: ["unique", "something"]}) + assert not d.is_subdiagram(d1) + assert not d1.is_subdiagram(d) + + d = Diagram({f: "blooh"}) + d1 = Diagram({f: "bleeh"}) + assert not d.is_subdiagram(d1) + assert not d1.is_subdiagram(d) + + d = Diagram([f, g], {f: "unique", g * f: "veryunique"}) + d1 = d.subdiagram_from_objects(FiniteSet(A, B)) + assert d1 == Diagram([f], {f: "unique"}) + raises(ValueError, lambda: d.subdiagram_from_objects(FiniteSet(A, + Object("D")))) + + raises(ValueError, lambda: Diagram({IdentityMorphism(A): "unique"})) + + +def test_category(): + A = Object("A") + B = Object("B") + C = Object("C") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + + d1 = Diagram([f, g]) + d2 = Diagram([f]) + + objects = d1.objects | d2.objects + + K = Category("K", objects, commutative_diagrams=[d1, d2]) + + assert K.name == "K" + assert K.objects == Class(objects) + assert K.commutative_diagrams == FiniteSet(d1, d2) + + raises(ValueError, lambda: Category("")) diff --git a/.venv/lib/python3.13/site-packages/sympy/categories/tests/test_drawing.py b/.venv/lib/python3.13/site-packages/sympy/categories/tests/test_drawing.py new file mode 100644 index 0000000000000000000000000000000000000000..63a13266cd6b58f6a85aad4af0813b395acbb5e1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/categories/tests/test_drawing.py @@ -0,0 +1,919 @@ +from sympy.categories.diagram_drawing import _GrowableGrid, ArrowStringDescription +from sympy.categories import (DiagramGrid, Object, NamedMorphism, + Diagram, XypicDiagramDrawer, xypic_draw_diagram) +from sympy.sets.sets import FiniteSet + + +def test_GrowableGrid(): + grid = _GrowableGrid(1, 2) + + # Check dimensions. + assert grid.width == 1 + assert grid.height == 2 + + # Check initialization of elements. + assert grid[0, 0] is None + assert grid[1, 0] is None + + # Check assignment to elements. + grid[0, 0] = 1 + grid[1, 0] = "two" + + assert grid[0, 0] == 1 + assert grid[1, 0] == "two" + + # Check appending a row. + grid.append_row() + + assert grid.width == 1 + assert grid.height == 3 + + assert grid[0, 0] == 1 + assert grid[1, 0] == "two" + assert grid[2, 0] is None + + # Check appending a column. + grid.append_column() + assert grid.width == 2 + assert grid.height == 3 + + assert grid[0, 0] == 1 + assert grid[1, 0] == "two" + assert grid[2, 0] is None + + assert grid[0, 1] is None + assert grid[1, 1] is None + assert grid[2, 1] is None + + grid = _GrowableGrid(1, 2) + grid[0, 0] = 1 + grid[1, 0] = "two" + + # Check prepending a row. + grid.prepend_row() + assert grid.width == 1 + assert grid.height == 3 + + assert grid[0, 0] is None + assert grid[1, 0] == 1 + assert grid[2, 0] == "two" + + # Check prepending a column. + grid.prepend_column() + assert grid.width == 2 + assert grid.height == 3 + + assert grid[0, 0] is None + assert grid[1, 0] is None + assert grid[2, 0] is None + + assert grid[0, 1] is None + assert grid[1, 1] == 1 + assert grid[2, 1] == "two" + + +def test_DiagramGrid(): + # Set up some objects and morphisms. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + E = Object("E") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(D, A, "h") + k = NamedMorphism(D, B, "k") + + # A one-morphism diagram. + d = Diagram([f]) + grid = DiagramGrid(d) + + assert grid.width == 2 + assert grid.height == 1 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid.morphisms == {f: FiniteSet()} + + # A triangle. + d = Diagram([f, g], {g * f: "unique"}) + grid = DiagramGrid(d) + + assert grid.width == 2 + assert grid.height == 2 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[1, 0] == C + assert grid[1, 1] is None + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), + g * f: FiniteSet("unique")} + + # A triangle with a "loop" morphism. + l_A = NamedMorphism(A, A, "l_A") + d = Diagram([f, g, l_A]) + grid = DiagramGrid(d) + + assert grid.width == 2 + assert grid.height == 2 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[1, 0] is None + assert grid[1, 1] == C + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), l_A: FiniteSet()} + + # A simple diagram. + d = Diagram([f, g, h, k]) + grid = DiagramGrid(d) + + assert grid.width == 3 + assert grid.height == 2 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] == D + assert grid[1, 0] is None + assert grid[1, 1] == C + assert grid[1, 2] is None + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(), + k: FiniteSet()} + + assert str(grid) == '[[Object("A"), Object("B"), Object("D")], ' \ + '[None, Object("C"), None]]' + + # A chain of morphisms. + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(C, D, "h") + k = NamedMorphism(D, E, "k") + d = Diagram([f, g, h, k]) + grid = DiagramGrid(d) + + assert grid.width == 3 + assert grid.height == 3 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] is None + assert grid[1, 0] is None + assert grid[1, 1] == C + assert grid[1, 2] == D + assert grid[2, 0] is None + assert grid[2, 1] is None + assert grid[2, 2] == E + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(), + k: FiniteSet()} + + # A square. + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, D, "g") + h = NamedMorphism(A, C, "h") + k = NamedMorphism(C, D, "k") + d = Diagram([f, g, h, k]) + grid = DiagramGrid(d) + + assert grid.width == 2 + assert grid.height == 2 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[1, 0] == C + assert grid[1, 1] == D + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(), + k: FiniteSet()} + + # A strange diagram which resulted from a typo when creating a + # test for five lemma, but which allowed to stop one extra problem + # in the algorithm. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + E = Object("E") + A_ = Object("A'") + B_ = Object("B'") + C_ = Object("C'") + D_ = Object("D'") + E_ = Object("E'") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(C, D, "h") + i = NamedMorphism(D, E, "i") + + # These 4 morphisms should be between primed objects. + j = NamedMorphism(A, B, "j") + k = NamedMorphism(B, C, "k") + l = NamedMorphism(C, D, "l") + m = NamedMorphism(D, E, "m") + + o = NamedMorphism(A, A_, "o") + p = NamedMorphism(B, B_, "p") + q = NamedMorphism(C, C_, "q") + r = NamedMorphism(D, D_, "r") + s = NamedMorphism(E, E_, "s") + + d = Diagram([f, g, h, i, j, k, l, m, o, p, q, r, s]) + grid = DiagramGrid(d) + + assert grid.width == 3 + assert grid.height == 4 + assert grid[0, 0] is None + assert grid[0, 1] == A + assert grid[0, 2] == A_ + assert grid[1, 0] == C + assert grid[1, 1] == B + assert grid[1, 2] == B_ + assert grid[2, 0] == C_ + assert grid[2, 1] == D + assert grid[2, 2] == D_ + assert grid[3, 0] is None + assert grid[3, 1] == E + assert grid[3, 2] == E_ + + morphisms = {} + for m in [f, g, h, i, j, k, l, m, o, p, q, r, s]: + morphisms[m] = FiniteSet() + assert grid.morphisms == morphisms + + # A cube. + A1 = Object("A1") + A2 = Object("A2") + A3 = Object("A3") + A4 = Object("A4") + A5 = Object("A5") + A6 = Object("A6") + A7 = Object("A7") + A8 = Object("A8") + + # The top face of the cube. + f1 = NamedMorphism(A1, A2, "f1") + f2 = NamedMorphism(A1, A3, "f2") + f3 = NamedMorphism(A2, A4, "f3") + f4 = NamedMorphism(A3, A4, "f3") + + # The bottom face of the cube. + f5 = NamedMorphism(A5, A6, "f5") + f6 = NamedMorphism(A5, A7, "f6") + f7 = NamedMorphism(A6, A8, "f7") + f8 = NamedMorphism(A7, A8, "f8") + + # The remaining morphisms. + f9 = NamedMorphism(A1, A5, "f9") + f10 = NamedMorphism(A2, A6, "f10") + f11 = NamedMorphism(A3, A7, "f11") + f12 = NamedMorphism(A4, A8, "f11") + + d = Diagram([f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12]) + grid = DiagramGrid(d) + + assert grid.width == 4 + assert grid.height == 3 + assert grid[0, 0] is None + assert grid[0, 1] == A5 + assert grid[0, 2] == A6 + assert grid[0, 3] is None + assert grid[1, 0] is None + assert grid[1, 1] == A1 + assert grid[1, 2] == A2 + assert grid[1, 3] is None + assert grid[2, 0] == A7 + assert grid[2, 1] == A3 + assert grid[2, 2] == A4 + assert grid[2, 3] == A8 + + morphisms = {} + for m in [f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12]: + morphisms[m] = FiniteSet() + assert grid.morphisms == morphisms + + # A line diagram. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + E = Object("E") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(C, D, "h") + i = NamedMorphism(D, E, "i") + d = Diagram([f, g, h, i]) + grid = DiagramGrid(d, layout="sequential") + + assert grid.width == 5 + assert grid.height == 1 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] == C + assert grid[0, 3] == D + assert grid[0, 4] == E + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(), + i: FiniteSet()} + + # Test the transposed version. + grid = DiagramGrid(d, layout="sequential", transpose=True) + + assert grid.width == 1 + assert grid.height == 5 + assert grid[0, 0] == A + assert grid[1, 0] == B + assert grid[2, 0] == C + assert grid[3, 0] == D + assert grid[4, 0] == E + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(), + i: FiniteSet()} + + # A pullback. + m1 = NamedMorphism(A, B, "m1") + m2 = NamedMorphism(A, C, "m2") + s1 = NamedMorphism(B, D, "s1") + s2 = NamedMorphism(C, D, "s2") + f1 = NamedMorphism(E, B, "f1") + f2 = NamedMorphism(E, C, "f2") + g = NamedMorphism(E, A, "g") + + d = Diagram([m1, m2, s1, s2, f1, f2], {g: "unique"}) + grid = DiagramGrid(d) + + assert grid.width == 3 + assert grid.height == 2 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] == E + assert grid[1, 0] == C + assert grid[1, 1] == D + assert grid[1, 2] is None + + morphisms = {g: FiniteSet("unique")} + for m in [m1, m2, s1, s2, f1, f2]: + morphisms[m] = FiniteSet() + assert grid.morphisms == morphisms + + # Test the pullback with sequential layout, just for stress + # testing. + grid = DiagramGrid(d, layout="sequential") + + assert grid.width == 5 + assert grid.height == 1 + assert grid[0, 0] == D + assert grid[0, 1] == B + assert grid[0, 2] == A + assert grid[0, 3] == C + assert grid[0, 4] == E + assert grid.morphisms == morphisms + + # Test a pullback with object grouping. + grid = DiagramGrid(d, groups=FiniteSet(E, FiniteSet(A, B, C, D))) + + assert grid.width == 3 + assert grid.height == 2 + assert grid[0, 0] == E + assert grid[0, 1] == A + assert grid[0, 2] == B + assert grid[1, 0] is None + assert grid[1, 1] == C + assert grid[1, 2] == D + assert grid.morphisms == morphisms + + # Five lemma, actually. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + E = Object("E") + A_ = Object("A'") + B_ = Object("B'") + C_ = Object("C'") + D_ = Object("D'") + E_ = Object("E'") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(C, D, "h") + i = NamedMorphism(D, E, "i") + + j = NamedMorphism(A_, B_, "j") + k = NamedMorphism(B_, C_, "k") + l = NamedMorphism(C_, D_, "l") + m = NamedMorphism(D_, E_, "m") + + o = NamedMorphism(A, A_, "o") + p = NamedMorphism(B, B_, "p") + q = NamedMorphism(C, C_, "q") + r = NamedMorphism(D, D_, "r") + s = NamedMorphism(E, E_, "s") + + d = Diagram([f, g, h, i, j, k, l, m, o, p, q, r, s]) + grid = DiagramGrid(d) + + assert grid.width == 5 + assert grid.height == 3 + assert grid[0, 0] is None + assert grid[0, 1] == A + assert grid[0, 2] == A_ + assert grid[0, 3] is None + assert grid[0, 4] is None + assert grid[1, 0] == C + assert grid[1, 1] == B + assert grid[1, 2] == B_ + assert grid[1, 3] == C_ + assert grid[1, 4] is None + assert grid[2, 0] == D + assert grid[2, 1] == E + assert grid[2, 2] is None + assert grid[2, 3] == D_ + assert grid[2, 4] == E_ + + morphisms = {} + for m in [f, g, h, i, j, k, l, m, o, p, q, r, s]: + morphisms[m] = FiniteSet() + assert grid.morphisms == morphisms + + # Test the five lemma with object grouping. + grid = DiagramGrid(d, FiniteSet( + FiniteSet(A, B, C, D, E), FiniteSet(A_, B_, C_, D_, E_))) + + assert grid.width == 6 + assert grid.height == 3 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] is None + assert grid[0, 3] == A_ + assert grid[0, 4] == B_ + assert grid[0, 5] is None + assert grid[1, 0] is None + assert grid[1, 1] == C + assert grid[1, 2] == D + assert grid[1, 3] is None + assert grid[1, 4] == C_ + assert grid[1, 5] == D_ + assert grid[2, 0] is None + assert grid[2, 1] is None + assert grid[2, 2] == E + assert grid[2, 3] is None + assert grid[2, 4] is None + assert grid[2, 5] == E_ + assert grid.morphisms == morphisms + + # Test the five lemma with object grouping, but mixing containers + # to represent groups. + grid = DiagramGrid(d, [(A, B, C, D, E), {A_, B_, C_, D_, E_}]) + + assert grid.width == 6 + assert grid.height == 3 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] is None + assert grid[0, 3] == A_ + assert grid[0, 4] == B_ + assert grid[0, 5] is None + assert grid[1, 0] is None + assert grid[1, 1] == C + assert grid[1, 2] == D + assert grid[1, 3] is None + assert grid[1, 4] == C_ + assert grid[1, 5] == D_ + assert grid[2, 0] is None + assert grid[2, 1] is None + assert grid[2, 2] == E + assert grid[2, 3] is None + assert grid[2, 4] is None + assert grid[2, 5] == E_ + assert grid.morphisms == morphisms + + # Test the five lemma with object grouping and hints. + grid = DiagramGrid(d, { + FiniteSet(A, B, C, D, E): {"layout": "sequential", + "transpose": True}, + FiniteSet(A_, B_, C_, D_, E_): {"layout": "sequential", + "transpose": True}}, + transpose=True) + + assert grid.width == 5 + assert grid.height == 2 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] == C + assert grid[0, 3] == D + assert grid[0, 4] == E + assert grid[1, 0] == A_ + assert grid[1, 1] == B_ + assert grid[1, 2] == C_ + assert grid[1, 3] == D_ + assert grid[1, 4] == E_ + assert grid.morphisms == morphisms + + # A two-triangle disconnected diagram. + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + f_ = NamedMorphism(A_, B_, "f") + g_ = NamedMorphism(B_, C_, "g") + d = Diagram([f, g, f_, g_], {g * f: "unique", g_ * f_: "unique"}) + grid = DiagramGrid(d) + + assert grid.width == 4 + assert grid.height == 2 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] == A_ + assert grid[0, 3] == B_ + assert grid[1, 0] == C + assert grid[1, 1] is None + assert grid[1, 2] == C_ + assert grid[1, 3] is None + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), f_: FiniteSet(), + g_: FiniteSet(), g * f: FiniteSet("unique"), + g_ * f_: FiniteSet("unique")} + + # A two-morphism disconnected diagram. + f = NamedMorphism(A, B, "f") + g = NamedMorphism(C, D, "g") + d = Diagram([f, g]) + grid = DiagramGrid(d) + + assert grid.width == 4 + assert grid.height == 1 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] == C + assert grid[0, 3] == D + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet()} + + # Test a one-object diagram. + f = NamedMorphism(A, A, "f") + d = Diagram([f]) + grid = DiagramGrid(d) + + assert grid.width == 1 + assert grid.height == 1 + assert grid[0, 0] == A + + # Test a two-object disconnected diagram. + g = NamedMorphism(B, B, "g") + d = Diagram([f, g]) + grid = DiagramGrid(d) + + assert grid.width == 2 + assert grid.height == 1 + assert grid[0, 0] == A + assert grid[0, 1] == B + + +def test_DiagramGrid_pseudopod(): + # Test a diagram in which even growing a pseudopod does not + # eventually help. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + E = Object("E") + F = Object("F") + A_ = Object("A'") + B_ = Object("B'") + C_ = Object("C'") + D_ = Object("D'") + E_ = Object("E'") + + f1 = NamedMorphism(A, B, "f1") + f2 = NamedMorphism(A, C, "f2") + f3 = NamedMorphism(A, D, "f3") + f4 = NamedMorphism(A, E, "f4") + f5 = NamedMorphism(A, A_, "f5") + f6 = NamedMorphism(A, B_, "f6") + f7 = NamedMorphism(A, C_, "f7") + f8 = NamedMorphism(A, D_, "f8") + f9 = NamedMorphism(A, E_, "f9") + f10 = NamedMorphism(A, F, "f10") + d = Diagram([f1, f2, f3, f4, f5, f6, f7, f8, f9, f10]) + grid = DiagramGrid(d) + + assert grid.width == 5 + assert grid.height == 3 + assert grid[0, 0] == E + assert grid[0, 1] == C + assert grid[0, 2] == C_ + assert grid[0, 3] == E_ + assert grid[0, 4] == F + assert grid[1, 0] == D + assert grid[1, 1] == A + assert grid[1, 2] == A_ + assert grid[1, 3] is None + assert grid[1, 4] is None + assert grid[2, 0] == D_ + assert grid[2, 1] == B + assert grid[2, 2] == B_ + assert grid[2, 3] is None + assert grid[2, 4] is None + + morphisms = {} + for f in [f1, f2, f3, f4, f5, f6, f7, f8, f9, f10]: + morphisms[f] = FiniteSet() + assert grid.morphisms == morphisms + + +def test_ArrowStringDescription(): + astr = ArrowStringDescription("cm", "", None, "", "", "d", "r", "_", "f") + assert str(astr) == "\\ar[dr]_{f}" + + astr = ArrowStringDescription("cm", "", 12, "", "", "d", "r", "_", "f") + assert str(astr) == "\\ar[dr]_{f}" + + astr = ArrowStringDescription("cm", "^", 12, "", "", "d", "r", "_", "f") + assert str(astr) == "\\ar@/^12cm/[dr]_{f}" + + astr = ArrowStringDescription("cm", "", 12, "r", "", "d", "r", "_", "f") + assert str(astr) == "\\ar[dr]_{f}" + + astr = ArrowStringDescription("cm", "", 12, "r", "u", "d", "r", "_", "f") + assert str(astr) == "\\ar@(r,u)[dr]_{f}" + + astr = ArrowStringDescription("cm", "", 12, "r", "u", "d", "r", "_", "f") + assert str(astr) == "\\ar@(r,u)[dr]_{f}" + + astr = ArrowStringDescription("cm", "", 12, "r", "u", "d", "r", "_", "f") + astr.arrow_style = "{-->}" + assert str(astr) == "\\ar@(r,u)@{-->}[dr]_{f}" + + astr = ArrowStringDescription("cm", "_", 12, "", "", "d", "r", "_", "f") + astr.arrow_style = "{-->}" + assert str(astr) == "\\ar@/_12cm/@{-->}[dr]_{f}" + + +def test_XypicDiagramDrawer_line(): + # A linear diagram. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + E = Object("E") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(C, D, "h") + i = NamedMorphism(D, E, "i") + d = Diagram([f, g, h, i]) + grid = DiagramGrid(d, layout="sequential") + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[r]^{f} & B \\ar[r]^{g} & C \\ar[r]^{h} & D \\ar[r]^{i} & E \n" \ + "}\n" + + # The same diagram, transposed. + grid = DiagramGrid(d, layout="sequential", transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[d]^{f} \\\\\n" \ + "B \\ar[d]^{g} \\\\\n" \ + "C \\ar[d]^{h} \\\\\n" \ + "D \\ar[d]^{i} \\\\\n" \ + "E \n" \ + "}\n" + + +def test_XypicDiagramDrawer_triangle(): + # A triangle diagram. + A = Object("A") + B = Object("B") + C = Object("C") + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + + d = Diagram([f, g], {g * f: "unique"}) + grid = DiagramGrid(d) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[d]_{g\\circ f} \\ar[r]^{f} & B \\ar[ld]^{g} \\\\\n" \ + "C & \n" \ + "}\n" + + # The same diagram, transposed. + grid = DiagramGrid(d, transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[r]^{g\\circ f} \\ar[d]_{f} & C \\\\\n" \ + "B \\ar[ru]_{g} & \n" \ + "}\n" + + # The same diagram, with a masked morphism. + assert drawer.draw(d, grid, masked=[g]) == "\\xymatrix{\n" \ + "A \\ar[r]^{g\\circ f} \\ar[d]_{f} & C \\\\\n" \ + "B & \n" \ + "}\n" + + # The same diagram with a formatter for "unique". + def formatter(astr): + astr.label = "\\exists !" + astr.label + astr.arrow_style = "{-->}" + + drawer.arrow_formatters["unique"] = formatter + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar@{-->}[r]^{\\exists !g\\circ f} \\ar[d]_{f} & C \\\\\n" \ + "B \\ar[ru]_{g} & \n" \ + "}\n" + + # The same diagram with a default formatter. + def default_formatter(astr): + astr.label_displacement = "(0.45)" + + drawer.default_arrow_formatter = default_formatter + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar@{-->}[r]^(0.45){\\exists !g\\circ f} \\ar[d]_(0.45){f} & C \\\\\n" \ + "B \\ar[ru]_(0.45){g} & \n" \ + "}\n" + + # A triangle diagram with a lot of morphisms between the same + # objects. + f1 = NamedMorphism(B, A, "f1") + f2 = NamedMorphism(A, B, "f2") + g1 = NamedMorphism(C, B, "g1") + g2 = NamedMorphism(B, C, "g2") + d = Diagram([f, f1, f2, g, g1, g2], {f1 * g1: "unique", g2 * f2: "unique"}) + + grid = DiagramGrid(d, transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid, masked=[f1*g1*g2*f2, g2*f2*f1*g1]) == \ + "\\xymatrix{\n" \ + "A \\ar[r]^{g_{2}\\circ f_{2}} \\ar[d]_{f} \\ar@/^3mm/[d]^{f_{2}} " \ + "& C \\ar@/^3mm/[l]^{f_{1}\\circ g_{1}} \\ar@/^3mm/[ld]^{g_{1}} \\\\\n" \ + "B \\ar@/^3mm/[u]^{f_{1}} \\ar[ru]_{g} \\ar@/^3mm/[ru]^{g_{2}} & \n" \ + "}\n" + + +def test_XypicDiagramDrawer_cube(): + # A cube diagram. + A1 = Object("A1") + A2 = Object("A2") + A3 = Object("A3") + A4 = Object("A4") + A5 = Object("A5") + A6 = Object("A6") + A7 = Object("A7") + A8 = Object("A8") + + # The top face of the cube. + f1 = NamedMorphism(A1, A2, "f1") + f2 = NamedMorphism(A1, A3, "f2") + f3 = NamedMorphism(A2, A4, "f3") + f4 = NamedMorphism(A3, A4, "f3") + + # The bottom face of the cube. + f5 = NamedMorphism(A5, A6, "f5") + f6 = NamedMorphism(A5, A7, "f6") + f7 = NamedMorphism(A6, A8, "f7") + f8 = NamedMorphism(A7, A8, "f8") + + # The remaining morphisms. + f9 = NamedMorphism(A1, A5, "f9") + f10 = NamedMorphism(A2, A6, "f10") + f11 = NamedMorphism(A3, A7, "f11") + f12 = NamedMorphism(A4, A8, "f11") + + d = Diagram([f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12]) + grid = DiagramGrid(d) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "& A_{5} \\ar[r]^{f_{5}} \\ar[ldd]_{f_{6}} & A_{6} \\ar[rdd]^{f_{7}} " \ + "& \\\\\n" \ + "& A_{1} \\ar[r]^{f_{1}} \\ar[d]^{f_{2}} \\ar[u]^{f_{9}} & A_{2} " \ + "\\ar[d]^{f_{3}} \\ar[u]_{f_{10}} & \\\\\n" \ + "A_{7} \\ar@/_3mm/[rrr]_{f_{8}} & A_{3} \\ar[r]^{f_{3}} \\ar[l]_{f_{11}} " \ + "& A_{4} \\ar[r]^{f_{11}} & A_{8} \n" \ + "}\n" + + # The same diagram, transposed. + grid = DiagramGrid(d, transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "& & A_{7} \\ar@/^3mm/[ddd]^{f_{8}} \\\\\n" \ + "A_{5} \\ar[d]_{f_{5}} \\ar[rru]^{f_{6}} & A_{1} \\ar[d]^{f_{1}} " \ + "\\ar[r]^{f_{2}} \\ar[l]^{f_{9}} & A_{3} \\ar[d]_{f_{3}} " \ + "\\ar[u]^{f_{11}} \\\\\n" \ + "A_{6} \\ar[rrd]_{f_{7}} & A_{2} \\ar[r]^{f_{3}} \\ar[l]^{f_{10}} " \ + "& A_{4} \\ar[d]_{f_{11}} \\\\\n" \ + "& & A_{8} \n" \ + "}\n" + + +def test_XypicDiagramDrawer_curved_and_loops(): + # A simple diagram, with a curved arrow. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(D, A, "h") + k = NamedMorphism(D, B, "k") + d = Diagram([f, g, h, k]) + grid = DiagramGrid(d) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[r]_{f} & B \\ar[d]^{g} & D \\ar[l]^{k} \\ar@/_3mm/[ll]_{h} \\\\\n" \ + "& C & \n" \ + "}\n" + + # The same diagram, transposed. + grid = DiagramGrid(d, transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[d]^{f} & \\\\\n" \ + "B \\ar[r]^{g} & C \\\\\n" \ + "D \\ar[u]_{k} \\ar@/^3mm/[uu]^{h} & \n" \ + "}\n" + + # The same diagram, larger and rotated. + assert drawer.draw(d, grid, diagram_format="@+1cm@dr") == \ + "\\xymatrix@+1cm@dr{\n" \ + "A \\ar[d]^{f} & \\\\\n" \ + "B \\ar[r]^{g} & C \\\\\n" \ + "D \\ar[u]_{k} \\ar@/^3mm/[uu]^{h} & \n" \ + "}\n" + + # A simple diagram with three curved arrows. + h1 = NamedMorphism(D, A, "h1") + h2 = NamedMorphism(A, D, "h2") + k = NamedMorphism(D, B, "k") + d = Diagram([f, g, h, k, h1, h2]) + grid = DiagramGrid(d) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[r]_{f} \\ar@/^3mm/[rr]^{h_{2}} & B \\ar[d]^{g} & D \\ar[l]^{k} " \ + "\\ar@/_7mm/[ll]_{h} \\ar@/_11mm/[ll]_{h_{1}} \\\\\n" \ + "& C & \n" \ + "}\n" + + # The same diagram, transposed. + grid = DiagramGrid(d, transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[d]^{f} \\ar@/_3mm/[dd]_{h_{2}} & \\\\\n" \ + "B \\ar[r]^{g} & C \\\\\n" \ + "D \\ar[u]_{k} \\ar@/^7mm/[uu]^{h} \\ar@/^11mm/[uu]^{h_{1}} & \n" \ + "}\n" + + # The same diagram, with "loop" morphisms. + l_A = NamedMorphism(A, A, "l_A") + l_D = NamedMorphism(D, D, "l_D") + l_C = NamedMorphism(C, C, "l_C") + d = Diagram([f, g, h, k, h1, h2, l_A, l_D, l_C]) + grid = DiagramGrid(d) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[r]_{f} \\ar@/^3mm/[rr]^{h_{2}} \\ar@(u,l)[]^{l_{A}} " \ + "& B \\ar[d]^{g} & D \\ar[l]^{k} \\ar@/_7mm/[ll]_{h} " \ + "\\ar@/_11mm/[ll]_{h_{1}} \\ar@(r,u)[]^{l_{D}} \\\\\n" \ + "& C \\ar@(l,d)[]^{l_{C}} & \n" \ + "}\n" + + # The same diagram with "loop" morphisms, transposed. + grid = DiagramGrid(d, transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[d]^{f} \\ar@/_3mm/[dd]_{h_{2}} \\ar@(r,u)[]^{l_{A}} & \\\\\n" \ + "B \\ar[r]^{g} & C \\ar@(r,u)[]^{l_{C}} \\\\\n" \ + "D \\ar[u]_{k} \\ar@/^7mm/[uu]^{h} \\ar@/^11mm/[uu]^{h_{1}} " \ + "\\ar@(l,d)[]^{l_{D}} & \n" \ + "}\n" + + # The same diagram with two "loop" morphisms per object. + l_A_ = NamedMorphism(A, A, "n_A") + l_D_ = NamedMorphism(D, D, "n_D") + l_C_ = NamedMorphism(C, C, "n_C") + d = Diagram([f, g, h, k, h1, h2, l_A, l_D, l_C, l_A_, l_D_, l_C_]) + grid = DiagramGrid(d) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[r]_{f} \\ar@/^3mm/[rr]^{h_{2}} \\ar@(u,l)[]^{l_{A}} " \ + "\\ar@/^3mm/@(l,d)[]^{n_{A}} & B \\ar[d]^{g} & D \\ar[l]^{k} " \ + "\\ar@/_7mm/[ll]_{h} \\ar@/_11mm/[ll]_{h_{1}} \\ar@(r,u)[]^{l_{D}} " \ + "\\ar@/^3mm/@(d,r)[]^{n_{D}} \\\\\n" \ + "& C \\ar@(l,d)[]^{l_{C}} \\ar@/^3mm/@(d,r)[]^{n_{C}} & \n" \ + "}\n" + + # The same diagram with two "loop" morphisms per object, transposed. + grid = DiagramGrid(d, transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[d]^{f} \\ar@/_3mm/[dd]_{h_{2}} \\ar@(r,u)[]^{l_{A}} " \ + "\\ar@/^3mm/@(u,l)[]^{n_{A}} & \\\\\n" \ + "B \\ar[r]^{g} & C \\ar@(r,u)[]^{l_{C}} \\ar@/^3mm/@(d,r)[]^{n_{C}} \\\\\n" \ + "D \\ar[u]_{k} \\ar@/^7mm/[uu]^{h} \\ar@/^11mm/[uu]^{h_{1}} " \ + "\\ar@(l,d)[]^{l_{D}} \\ar@/^3mm/@(d,r)[]^{n_{D}} & \n" \ + "}\n" + + +def test_xypic_draw_diagram(): + # A linear diagram. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + E = Object("E") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(C, D, "h") + i = NamedMorphism(D, E, "i") + d = Diagram([f, g, h, i]) + + grid = DiagramGrid(d, layout="sequential") + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == xypic_draw_diagram(d, layout="sequential") diff --git a/.venv/lib/python3.13/site-packages/sympy/diffgeom/__init__.py b/.venv/lib/python3.13/site-packages/sympy/diffgeom/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8846a99510601c9675103e21ef5a0a1e839fdd11 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/diffgeom/__init__.py @@ -0,0 +1,19 @@ +from .diffgeom import ( + BaseCovarDerivativeOp, BaseScalarField, BaseVectorField, Commutator, + contravariant_order, CoordSystem, CoordinateSymbol, + CovarDerivativeOp, covariant_order, Differential, intcurve_diffequ, + intcurve_series, LieDerivative, Manifold, metric_to_Christoffel_1st, + metric_to_Christoffel_2nd, metric_to_Ricci_components, + metric_to_Riemann_components, Patch, Point, TensorProduct, twoform_to_matrix, + vectors_in_basis, WedgeProduct, +) + +__all__ = [ + 'BaseCovarDerivativeOp', 'BaseScalarField', 'BaseVectorField', 'Commutator', + 'contravariant_order', 'CoordSystem', 'CoordinateSymbol', + 'CovarDerivativeOp', 'covariant_order', 'Differential', 'intcurve_diffequ', + 'intcurve_series', 'LieDerivative', 'Manifold', 'metric_to_Christoffel_1st', + 'metric_to_Christoffel_2nd', 'metric_to_Ricci_components', + 'metric_to_Riemann_components', 'Patch', 'Point', 'TensorProduct', + 'twoform_to_matrix', 'vectors_in_basis', 'WedgeProduct', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/diffgeom/diffgeom.py b/.venv/lib/python3.13/site-packages/sympy/diffgeom/diffgeom.py new file mode 100644 index 0000000000000000000000000000000000000000..a95f83122d6de0b7015b9a3ad0573cbfd97a7ef3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/diffgeom/diffgeom.py @@ -0,0 +1,2270 @@ +from __future__ import annotations +from typing import Any + +from functools import reduce +from itertools import permutations + +from sympy.combinatorics import Permutation +from sympy.core import ( + Basic, Expr, Function, diff, + Pow, Mul, Add, Lambda, S, Tuple, Dict +) +from sympy.core.cache import cacheit + +from sympy.core.symbol import Symbol, Dummy +from sympy.core.symbol import Str +from sympy.core.sympify import _sympify +from sympy.functions import factorial +from sympy.matrices import ImmutableDenseMatrix as Matrix +from sympy.solvers import solve + +from sympy.utilities.exceptions import (sympy_deprecation_warning, + SymPyDeprecationWarning, + ignore_warnings) + + +# TODO you are a bit excessive in the use of Dummies +# TODO dummy point, literal field +# TODO too often one needs to call doit or simplify on the output, check the +# tests and find out why +from sympy.tensor.array import ImmutableDenseNDimArray + + +class Manifold(Basic): + """ + A mathematical manifold. + + Explanation + =========== + + A manifold is a topological space that locally resembles + Euclidean space near each point [1]. + This class does not provide any means to study the topological + characteristics of the manifold that it represents, though. + + Parameters + ========== + + name : str + The name of the manifold. + + dim : int + The dimension of the manifold. + + Examples + ======== + + >>> from sympy.diffgeom import Manifold + >>> m = Manifold('M', 2) + >>> m + M + >>> m.dim + 2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Manifold + """ + + def __new__(cls, name, dim, **kwargs): + if not isinstance(name, Str): + name = Str(name) + dim = _sympify(dim) + obj = super().__new__(cls, name, dim) + + obj.patches = _deprecated_list( + """ + Manifold.patches is deprecated. The Manifold object is now + immutable. Instead use a separate list to keep track of the + patches. + """, []) + return obj + + @property + def name(self): + return self.args[0] + + @property + def dim(self): + return self.args[1] + + +class Patch(Basic): + """ + A patch on a manifold. + + Explanation + =========== + + Coordinate patch, or patch in short, is a simply-connected open set around + a point in the manifold [1]. On a manifold one can have many patches that + do not always include the whole manifold. On these patches coordinate + charts can be defined that permit the parameterization of any point on the + patch in terms of a tuple of real numbers (the coordinates). + + This class does not provide any means to study the topological + characteristics of the patch that it represents. + + Parameters + ========== + + name : str + The name of the patch. + + manifold : Manifold + The manifold on which the patch is defined. + + Examples + ======== + + >>> from sympy.diffgeom import Manifold, Patch + >>> m = Manifold('M', 2) + >>> p = Patch('P', m) + >>> p + P + >>> p.dim + 2 + + References + ========== + + .. [1] G. Sussman, J. Wisdom, W. Farr, Functional Differential Geometry + (2013) + + """ + def __new__(cls, name, manifold, **kwargs): + if not isinstance(name, Str): + name = Str(name) + obj = super().__new__(cls, name, manifold) + + obj.manifold.patches.append(obj) # deprecated + obj.coord_systems = _deprecated_list( + """ + Patch.coord_systms is deprecated. The Patch class is now + immutable. Instead use a separate list to keep track of coordinate + systems. + """, []) + return obj + + @property + def name(self): + return self.args[0] + + @property + def manifold(self): + return self.args[1] + + @property + def dim(self): + return self.manifold.dim + + +class CoordSystem(Basic): + """ + A coordinate system defined on the patch. + + Explanation + =========== + + Coordinate system is a system that uses one or more coordinates to uniquely + determine the position of the points or other geometric elements on a + manifold [1]. + + By passing ``Symbols`` to *symbols* parameter, user can define the name and + assumptions of coordinate symbols of the coordinate system. If not passed, + these symbols are generated automatically and are assumed to be real valued. + + By passing *relations* parameter, user can define the transform relations of + coordinate systems. Inverse transformation and indirect transformation can + be found automatically. If this parameter is not passed, coordinate + transformation cannot be done. + + Parameters + ========== + + name : str + The name of the coordinate system. + + patch : Patch + The patch where the coordinate system is defined. + + symbols : list of Symbols, optional + Defines the names and assumptions of coordinate symbols. + + relations : dict, optional + Key is a tuple of two strings, who are the names of the systems where + the coordinates transform from and transform to. + Value is a tuple of the symbols before transformation and a tuple of + the expressions after transformation. + + Examples + ======== + + We define two-dimensional Cartesian coordinate system and polar coordinate + system. + + >>> from sympy import symbols, pi, sqrt, atan2, cos, sin + >>> from sympy.diffgeom import Manifold, Patch, CoordSystem + >>> m = Manifold('M', 2) + >>> p = Patch('P', m) + >>> x, y = symbols('x y', real=True) + >>> r, theta = symbols('r theta', nonnegative=True) + >>> relation_dict = { + ... ('Car2D', 'Pol'): [(x, y), (sqrt(x**2 + y**2), atan2(y, x))], + ... ('Pol', 'Car2D'): [(r, theta), (r*cos(theta), r*sin(theta))] + ... } + >>> Car2D = CoordSystem('Car2D', p, (x, y), relation_dict) + >>> Pol = CoordSystem('Pol', p, (r, theta), relation_dict) + + ``symbols`` property returns ``CoordinateSymbol`` instances. These symbols + are not same with the symbols used to construct the coordinate system. + + >>> Car2D + Car2D + >>> Car2D.dim + 2 + >>> Car2D.symbols + (x, y) + >>> _[0].func + + + ``transformation()`` method returns the transformation function from + one coordinate system to another. ``transform()`` method returns the + transformed coordinates. + + >>> Car2D.transformation(Pol) + Lambda((x, y), Matrix([ + [sqrt(x**2 + y**2)], + [ atan2(y, x)]])) + >>> Car2D.transform(Pol) + Matrix([ + [sqrt(x**2 + y**2)], + [ atan2(y, x)]]) + >>> Car2D.transform(Pol, [1, 2]) + Matrix([ + [sqrt(5)], + [atan(2)]]) + + ``jacobian()`` method returns the Jacobian matrix of coordinate + transformation between two systems. ``jacobian_determinant()`` method + returns the Jacobian determinant of coordinate transformation between two + systems. + + >>> Pol.jacobian(Car2D) + Matrix([ + [cos(theta), -r*sin(theta)], + [sin(theta), r*cos(theta)]]) + >>> Pol.jacobian(Car2D, [1, pi/2]) + Matrix([ + [0, -1], + [1, 0]]) + >>> Car2D.jacobian_determinant(Pol) + 1/sqrt(x**2 + y**2) + >>> Car2D.jacobian_determinant(Pol, [1,0]) + 1 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Coordinate_system + + """ + def __new__(cls, name, patch, symbols=None, relations={}, **kwargs): + if not isinstance(name, Str): + name = Str(name) + + # canonicallize the symbols + if symbols is None: + names = kwargs.get('names', None) + if names is None: + symbols = Tuple( + *[Symbol('%s_%s' % (name.name, i), real=True) + for i in range(patch.dim)] + ) + else: + sympy_deprecation_warning( + f""" +The 'names' argument to CoordSystem is deprecated. Use 'symbols' instead. That +is, replace + + CoordSystem(..., names={names}) + +with + + CoordSystem(..., symbols=[{', '.join(["Symbol(" + repr(n) + ", real=True)" for n in names])}]) + """, + deprecated_since_version="1.7", + active_deprecations_target="deprecated-diffgeom-mutable", + ) + symbols = Tuple( + *[Symbol(n, real=True) for n in names] + ) + else: + syms = [] + for s in symbols: + if isinstance(s, Symbol): + syms.append(Symbol(s.name, **s._assumptions.generator)) + elif isinstance(s, str): + sympy_deprecation_warning( + f""" + +Passing a string as the coordinate symbol name to CoordSystem is deprecated. +Pass a Symbol with the appropriate name and assumptions instead. + +That is, replace {s} with Symbol({s!r}, real=True). + """, + + deprecated_since_version="1.7", + active_deprecations_target="deprecated-diffgeom-mutable", + ) + syms.append(Symbol(s, real=True)) + symbols = Tuple(*syms) + + # canonicallize the relations + rel_temp = {} + for k,v in relations.items(): + s1, s2 = k + if not isinstance(s1, Str): + s1 = Str(s1) + if not isinstance(s2, Str): + s2 = Str(s2) + key = Tuple(s1, s2) + + # Old version used Lambda as a value. + if isinstance(v, Lambda): + v = (tuple(v.signature), tuple(v.expr)) + else: + v = (tuple(v[0]), tuple(v[1])) + rel_temp[key] = v + relations = Dict(rel_temp) + + # construct the object + obj = super().__new__(cls, name, patch, symbols, relations) + + # Add deprecated attributes + obj.transforms = _deprecated_dict( + """ + CoordSystem.transforms is deprecated. The CoordSystem class is now + immutable. Use the 'relations' keyword argument to the + CoordSystems() constructor to specify relations. + """, {}) + obj._names = [str(n) for n in symbols] + obj.patch.coord_systems.append(obj) # deprecated + obj._dummies = [Dummy(str(n)) for n in symbols] # deprecated + obj._dummy = Dummy() + + return obj + + @property + def name(self): + return self.args[0] + + @property + def patch(self): + return self.args[1] + + @property + def manifold(self): + return self.patch.manifold + + @property + def symbols(self): + return tuple(CoordinateSymbol(self, i, **s._assumptions.generator) + for i,s in enumerate(self.args[2])) + + @property + def relations(self): + return self.args[3] + + @property + def dim(self): + return self.patch.dim + + ########################################################################## + # Finding transformation relation + ########################################################################## + + def transformation(self, sys): + """ + Return coordinate transformation function from *self* to *sys*. + + Parameters + ========== + + sys : CoordSystem + + Returns + ======= + + sympy.Lambda + + Examples + ======== + + >>> from sympy.diffgeom.rn import R2_r, R2_p + >>> R2_r.transformation(R2_p) + Lambda((x, y), Matrix([ + [sqrt(x**2 + y**2)], + [ atan2(y, x)]])) + + """ + signature = self.args[2] + + key = Tuple(self.name, sys.name) + if self == sys: + expr = Matrix(self.symbols) + elif key in self.relations: + expr = Matrix(self.relations[key][1]) + elif key[::-1] in self.relations: + expr = Matrix(self._inverse_transformation(sys, self)) + else: + expr = Matrix(self._indirect_transformation(self, sys)) + return Lambda(signature, expr) + + @staticmethod + def _solve_inverse(sym1, sym2, exprs, sys1_name, sys2_name): + ret = solve( + [t[0] - t[1] for t in zip(sym2, exprs)], + list(sym1), dict=True) + + if len(ret) == 0: + temp = "Cannot solve inverse relation from {} to {}." + raise NotImplementedError(temp.format(sys1_name, sys2_name)) + elif len(ret) > 1: + temp = "Obtained multiple inverse relation from {} to {}." + raise ValueError(temp.format(sys1_name, sys2_name)) + + return ret[0] + + @classmethod + def _inverse_transformation(cls, sys1, sys2): + # Find the transformation relation from sys2 to sys1 + forward = sys1.transform(sys2) + inv_results = cls._solve_inverse(sys1.symbols, sys2.symbols, forward, + sys1.name, sys2.name) + signature = tuple(sys1.symbols) + return [inv_results[s] for s in signature] + + @classmethod + @cacheit + def _indirect_transformation(cls, sys1, sys2): + # Find the transformation relation between two indirectly connected + # coordinate systems + rel = sys1.relations + path = cls._dijkstra(sys1, sys2) + + transforms = [] + for s1, s2 in zip(path, path[1:]): + if (s1, s2) in rel: + transforms.append(rel[(s1, s2)]) + else: + sym2, inv_exprs = rel[(s2, s1)] + sym1 = tuple(Dummy() for i in sym2) + ret = cls._solve_inverse(sym2, sym1, inv_exprs, s2, s1) + ret = tuple(ret[s] for s in sym2) + transforms.append((sym1, ret)) + syms = sys1.args[2] + exprs = syms + for newsyms, newexprs in transforms: + exprs = tuple(e.subs(zip(newsyms, exprs)) for e in newexprs) + return exprs + + @staticmethod + def _dijkstra(sys1, sys2): + # Use Dijkstra algorithm to find the shortest path between two indirectly-connected + # coordinate systems + # return value is the list of the names of the systems. + relations = sys1.relations + graph = {} + for s1, s2 in relations.keys(): + if s1 not in graph: + graph[s1] = {s2} + else: + graph[s1].add(s2) + if s2 not in graph: + graph[s2] = {s1} + else: + graph[s2].add(s1) + + path_dict = {sys:[0, [], 0] for sys in graph} # minimum distance, path, times of visited + + def visit(sys): + path_dict[sys][2] = 1 + for newsys in graph[sys]: + distance = path_dict[sys][0] + 1 + if path_dict[newsys][0] >= distance or not path_dict[newsys][1]: + path_dict[newsys][0] = distance + path_dict[newsys][1] = list(path_dict[sys][1]) + path_dict[newsys][1].append(sys) + + visit(sys1.name) + + while True: + min_distance = max(path_dict.values(), key=lambda x:x[0])[0] + newsys = None + for sys, lst in path_dict.items(): + if 0 < lst[0] <= min_distance and not lst[2]: + min_distance = lst[0] + newsys = sys + if newsys is None: + break + visit(newsys) + + result = path_dict[sys2.name][1] + result.append(sys2.name) + + if result == [sys2.name]: + raise KeyError("Two coordinate systems are not connected.") + return result + + def connect_to(self, to_sys, from_coords, to_exprs, inverse=True, fill_in_gaps=False): + sympy_deprecation_warning( + """ + The CoordSystem.connect_to() method is deprecated. Instead, + generate a new instance of CoordSystem with the 'relations' + keyword argument (CoordSystem classes are now immutable). + """, + deprecated_since_version="1.7", + active_deprecations_target="deprecated-diffgeom-mutable", + ) + + from_coords, to_exprs = dummyfy(from_coords, to_exprs) + self.transforms[to_sys] = Matrix(from_coords), Matrix(to_exprs) + + if inverse: + to_sys.transforms[self] = self._inv_transf(from_coords, to_exprs) + + if fill_in_gaps: + self._fill_gaps_in_transformations() + + @staticmethod + def _inv_transf(from_coords, to_exprs): + # Will be removed when connect_to is removed + inv_from = [i.as_dummy() for i in from_coords] + inv_to = solve( + [t[0] - t[1] for t in zip(inv_from, to_exprs)], + list(from_coords), dict=True)[0] + inv_to = [inv_to[fc] for fc in from_coords] + return Matrix(inv_from), Matrix(inv_to) + + @staticmethod + def _fill_gaps_in_transformations(): + # Will be removed when connect_to is removed + raise NotImplementedError + + ########################################################################## + # Coordinate transformations + ########################################################################## + + def transform(self, sys, coordinates=None): + """ + Return the result of coordinate transformation from *self* to *sys*. + If coordinates are not given, coordinate symbols of *self* are used. + + Parameters + ========== + + sys : CoordSystem + + coordinates : Any iterable, optional. + + Returns + ======= + + sympy.ImmutableDenseMatrix containing CoordinateSymbol + + Examples + ======== + + >>> from sympy.diffgeom.rn import R2_r, R2_p + >>> R2_r.transform(R2_p) + Matrix([ + [sqrt(x**2 + y**2)], + [ atan2(y, x)]]) + >>> R2_r.transform(R2_p, [0, 1]) + Matrix([ + [ 1], + [pi/2]]) + + """ + if coordinates is None: + coordinates = self.symbols + if self != sys: + transf = self.transformation(sys) + coordinates = transf(*coordinates) + else: + coordinates = Matrix(coordinates) + return coordinates + + def coord_tuple_transform_to(self, to_sys, coords): + """Transform ``coords`` to coord system ``to_sys``.""" + sympy_deprecation_warning( + """ + The CoordSystem.coord_tuple_transform_to() method is deprecated. + Use the CoordSystem.transform() method instead. + """, + deprecated_since_version="1.7", + active_deprecations_target="deprecated-diffgeom-mutable", + ) + + coords = Matrix(coords) + if self != to_sys: + with ignore_warnings(SymPyDeprecationWarning): + transf = self.transforms[to_sys] + coords = transf[1].subs(list(zip(transf[0], coords))) + return coords + + def jacobian(self, sys, coordinates=None): + """ + Return the jacobian matrix of a transformation on given coordinates. + If coordinates are not given, coordinate symbols of *self* are used. + + Parameters + ========== + + sys : CoordSystem + + coordinates : Any iterable, optional. + + Returns + ======= + + sympy.ImmutableDenseMatrix + + Examples + ======== + + >>> from sympy.diffgeom.rn import R2_r, R2_p + >>> R2_p.jacobian(R2_r) + Matrix([ + [cos(theta), -rho*sin(theta)], + [sin(theta), rho*cos(theta)]]) + >>> R2_p.jacobian(R2_r, [1, 0]) + Matrix([ + [1, 0], + [0, 1]]) + + """ + result = self.transform(sys).jacobian(self.symbols) + if coordinates is not None: + result = result.subs(list(zip(self.symbols, coordinates))) + return result + jacobian_matrix = jacobian + + def jacobian_determinant(self, sys, coordinates=None): + """ + Return the jacobian determinant of a transformation on given + coordinates. If coordinates are not given, coordinate symbols of *self* + are used. + + Parameters + ========== + + sys : CoordSystem + + coordinates : Any iterable, optional. + + Returns + ======= + + sympy.Expr + + Examples + ======== + + >>> from sympy.diffgeom.rn import R2_r, R2_p + >>> R2_r.jacobian_determinant(R2_p) + 1/sqrt(x**2 + y**2) + >>> R2_r.jacobian_determinant(R2_p, [1, 0]) + 1 + + """ + return self.jacobian(sys, coordinates).det() + + + ########################################################################## + # Points + ########################################################################## + + def point(self, coords): + """Create a ``Point`` with coordinates given in this coord system.""" + return Point(self, coords) + + def point_to_coords(self, point): + """Calculate the coordinates of a point in this coord system.""" + return point.coords(self) + + ########################################################################## + # Base fields. + ########################################################################## + + def base_scalar(self, coord_index): + """Return ``BaseScalarField`` that takes a point and returns one of the coordinates.""" + return BaseScalarField(self, coord_index) + coord_function = base_scalar + + def base_scalars(self): + """Returns a list of all coordinate functions. + For more details see the ``base_scalar`` method of this class.""" + return [self.base_scalar(i) for i in range(self.dim)] + coord_functions = base_scalars + + def base_vector(self, coord_index): + """Return a basis vector field. + The basis vector field for this coordinate system. It is also an + operator on scalar fields.""" + return BaseVectorField(self, coord_index) + + def base_vectors(self): + """Returns a list of all base vectors. + For more details see the ``base_vector`` method of this class.""" + return [self.base_vector(i) for i in range(self.dim)] + + def base_oneform(self, coord_index): + """Return a basis 1-form field. + The basis one-form field for this coordinate system. It is also an + operator on vector fields.""" + return Differential(self.coord_function(coord_index)) + + def base_oneforms(self): + """Returns a list of all base oneforms. + For more details see the ``base_oneform`` method of this class.""" + return [self.base_oneform(i) for i in range(self.dim)] + + +class CoordinateSymbol(Symbol): + """A symbol which denotes an abstract value of i-th coordinate of + the coordinate system with given context. + + Explanation + =========== + + Each coordinates in coordinate system are represented by unique symbol, + such as x, y, z in Cartesian coordinate system. + + You may not construct this class directly. Instead, use `symbols` method + of CoordSystem. + + Parameters + ========== + + coord_sys : CoordSystem + + index : integer + + Examples + ======== + + >>> from sympy import symbols, Lambda, Matrix, sqrt, atan2, cos, sin + >>> from sympy.diffgeom import Manifold, Patch, CoordSystem + >>> m = Manifold('M', 2) + >>> p = Patch('P', m) + >>> x, y = symbols('x y', real=True) + >>> r, theta = symbols('r theta', nonnegative=True) + >>> relation_dict = { + ... ('Car2D', 'Pol'): Lambda((x, y), Matrix([sqrt(x**2 + y**2), atan2(y, x)])), + ... ('Pol', 'Car2D'): Lambda((r, theta), Matrix([r*cos(theta), r*sin(theta)])) + ... } + >>> Car2D = CoordSystem('Car2D', p, [x, y], relation_dict) + >>> Pol = CoordSystem('Pol', p, [r, theta], relation_dict) + >>> x, y = Car2D.symbols + + ``CoordinateSymbol`` contains its coordinate symbol and index. + + >>> x.name + 'x' + >>> x.coord_sys == Car2D + True + >>> x.index + 0 + >>> x.is_real + True + + You can transform ``CoordinateSymbol`` into other coordinate system using + ``rewrite()`` method. + + >>> x.rewrite(Pol) + r*cos(theta) + >>> sqrt(x**2 + y**2).rewrite(Pol).simplify() + r + + """ + def __new__(cls, coord_sys, index, **assumptions): + name = coord_sys.args[2][index].name + obj = super().__new__(cls, name, **assumptions) + obj.coord_sys = coord_sys + obj.index = index + return obj + + def __getnewargs__(self): + return (self.coord_sys, self.index) + + def _hashable_content(self): + return ( + self.coord_sys, self.index + ) + tuple(sorted(self.assumptions0.items())) + + def _eval_rewrite(self, rule, args, **hints): + if isinstance(rule, CoordSystem): + return rule.transform(self.coord_sys)[self.index] + return super()._eval_rewrite(rule, args, **hints) + + +class Point(Basic): + """Point defined in a coordinate system. + + Explanation + =========== + + Mathematically, point is defined in the manifold and does not have any coordinates + by itself. Coordinate system is what imbues the coordinates to the point by coordinate + chart. However, due to the difficulty of realizing such logic, you must supply + a coordinate system and coordinates to define a Point here. + + The usage of this object after its definition is independent of the + coordinate system that was used in order to define it, however due to + limitations in the simplification routines you can arrive at complicated + expressions if you use inappropriate coordinate systems. + + Parameters + ========== + + coord_sys : CoordSystem + + coords : list + The coordinates of the point. + + Examples + ======== + + >>> from sympy import pi + >>> from sympy.diffgeom import Point + >>> from sympy.diffgeom.rn import R2, R2_r, R2_p + >>> rho, theta = R2_p.symbols + + >>> p = Point(R2_p, [rho, 3*pi/4]) + + >>> p.manifold == R2 + True + + >>> p.coords() + Matrix([ + [ rho], + [3*pi/4]]) + >>> p.coords(R2_r) + Matrix([ + [-sqrt(2)*rho/2], + [ sqrt(2)*rho/2]]) + + """ + + def __new__(cls, coord_sys, coords, **kwargs): + coords = Matrix(coords) + obj = super().__new__(cls, coord_sys, coords) + obj._coord_sys = coord_sys + obj._coords = coords + return obj + + @property + def patch(self): + return self._coord_sys.patch + + @property + def manifold(self): + return self._coord_sys.manifold + + @property + def dim(self): + return self.manifold.dim + + def coords(self, sys=None): + """ + Coordinates of the point in given coordinate system. If coordinate system + is not passed, it returns the coordinates in the coordinate system in which + the point was defined. + """ + if sys is None: + return self._coords + else: + return self._coord_sys.transform(sys, self._coords) + + @property + def free_symbols(self): + return self._coords.free_symbols + + +class BaseScalarField(Expr): + """Base scalar field over a manifold for a given coordinate system. + + Explanation + =========== + + A scalar field takes a point as an argument and returns a scalar. + A base scalar field of a coordinate system takes a point and returns one of + the coordinates of that point in the coordinate system in question. + + To define a scalar field you need to choose the coordinate system and the + index of the coordinate. + + The use of the scalar field after its definition is independent of the + coordinate system in which it was defined, however due to limitations in + the simplification routines you may arrive at more complicated + expression if you use unappropriate coordinate systems. + You can build complicated scalar fields by just building up SymPy + expressions containing ``BaseScalarField`` instances. + + Parameters + ========== + + coord_sys : CoordSystem + + index : integer + + Examples + ======== + + >>> from sympy import Function, pi + >>> from sympy.diffgeom import BaseScalarField + >>> from sympy.diffgeom.rn import R2_r, R2_p + >>> rho, _ = R2_p.symbols + >>> point = R2_p.point([rho, 0]) + >>> fx, fy = R2_r.base_scalars() + >>> ftheta = BaseScalarField(R2_r, 1) + + >>> fx(point) + rho + >>> fy(point) + 0 + + >>> (fx**2+fy**2).rcall(point) + rho**2 + + >>> g = Function('g') + >>> fg = g(ftheta-pi) + >>> fg.rcall(point) + g(-pi) + + """ + + is_commutative = True + + def __new__(cls, coord_sys, index, **kwargs): + index = _sympify(index) + obj = super().__new__(cls, coord_sys, index) + obj._coord_sys = coord_sys + obj._index = index + return obj + + @property + def coord_sys(self): + return self.args[0] + + @property + def index(self): + return self.args[1] + + @property + def patch(self): + return self.coord_sys.patch + + @property + def manifold(self): + return self.coord_sys.manifold + + @property + def dim(self): + return self.manifold.dim + + def __call__(self, *args): + """Evaluating the field at a point or doing nothing. + If the argument is a ``Point`` instance, the field is evaluated at that + point. The field is returned itself if the argument is any other + object. It is so in order to have working recursive calling mechanics + for all fields (check the ``__call__`` method of ``Expr``). + """ + point = args[0] + if len(args) != 1 or not isinstance(point, Point): + return self + coords = point.coords(self._coord_sys) + # XXX Calling doit is necessary with all the Subs expressions + # XXX Calling simplify is necessary with all the trig expressions + return simplify(coords[self._index]).doit() + + # XXX Workaround for limitations on the content of args + free_symbols: set[Any] = set() + + +class BaseVectorField(Expr): + r"""Base vector field over a manifold for a given coordinate system. + + Explanation + =========== + + A vector field is an operator taking a scalar field and returning a + directional derivative (which is also a scalar field). + A base vector field is the same type of operator, however the derivation is + specifically done with respect to a chosen coordinate. + + To define a base vector field you need to choose the coordinate system and + the index of the coordinate. + + The use of the vector field after its definition is independent of the + coordinate system in which it was defined, however due to limitations in the + simplification routines you may arrive at more complicated expression if you + use unappropriate coordinate systems. + + Parameters + ========== + coord_sys : CoordSystem + + index : integer + + Examples + ======== + + >>> from sympy import Function + >>> from sympy.diffgeom.rn import R2_p, R2_r + >>> from sympy.diffgeom import BaseVectorField + >>> from sympy import pprint + + >>> x, y = R2_r.symbols + >>> rho, theta = R2_p.symbols + >>> fx, fy = R2_r.base_scalars() + >>> point_p = R2_p.point([rho, theta]) + >>> point_r = R2_r.point([x, y]) + + >>> g = Function('g') + >>> s_field = g(fx, fy) + + >>> v = BaseVectorField(R2_r, 1) + >>> pprint(v(s_field)) + / d \| + |---(g(x, xi))|| + \dxi /|xi=y + >>> pprint(v(s_field).rcall(point_r).doit()) + d + --(g(x, y)) + dy + >>> pprint(v(s_field).rcall(point_p)) + / d \| + |---(g(rho*cos(theta), xi))|| + \dxi /|xi=rho*sin(theta) + + """ + + is_commutative = False + + def __new__(cls, coord_sys, index, **kwargs): + index = _sympify(index) + obj = super().__new__(cls, coord_sys, index) + obj._coord_sys = coord_sys + obj._index = index + return obj + + @property + def coord_sys(self): + return self.args[0] + + @property + def index(self): + return self.args[1] + + @property + def patch(self): + return self.coord_sys.patch + + @property + def manifold(self): + return self.coord_sys.manifold + + @property + def dim(self): + return self.manifold.dim + + def __call__(self, scalar_field): + """Apply on a scalar field. + The action of a vector field on a scalar field is a directional + differentiation. + If the argument is not a scalar field an error is raised. + """ + if covariant_order(scalar_field) or contravariant_order(scalar_field): + raise ValueError('Only scalar fields can be supplied as arguments to vector fields.') + + if scalar_field is None: + return self + + base_scalars = list(scalar_field.atoms(BaseScalarField)) + + # First step: e_x(x+r**2) -> e_x(x) + 2*r*e_x(r) + d_var = self._coord_sys._dummy + # TODO: you need a real dummy function for the next line + d_funcs = [Function('_#_%s' % i)(d_var) for i, + b in enumerate(base_scalars)] + d_result = scalar_field.subs(list(zip(base_scalars, d_funcs))) + d_result = d_result.diff(d_var) + + # Second step: e_x(x) -> 1 and e_x(r) -> cos(atan2(x, y)) + coords = self._coord_sys.symbols + d_funcs_deriv = [f.diff(d_var) for f in d_funcs] + d_funcs_deriv_sub = [] + for b in base_scalars: + jac = self._coord_sys.jacobian(b._coord_sys, coords) + d_funcs_deriv_sub.append(jac[b._index, self._index]) + d_result = d_result.subs(list(zip(d_funcs_deriv, d_funcs_deriv_sub))) + + # Remove the dummies + result = d_result.subs(list(zip(d_funcs, base_scalars))) + result = result.subs(list(zip(coords, self._coord_sys.coord_functions()))) + return result.doit() + + +def _find_coords(expr): + # Finds CoordinateSystems existing in expr + fields = expr.atoms(BaseScalarField, BaseVectorField) + return {f._coord_sys for f in fields} + + +class Commutator(Expr): + r"""Commutator of two vector fields. + + Explanation + =========== + + The commutator of two vector fields `v_1` and `v_2` is defined as the + vector field `[v_1, v_2]` that evaluated on each scalar field `f` is equal + to `v_1(v_2(f)) - v_2(v_1(f))`. + + Examples + ======== + + + >>> from sympy.diffgeom.rn import R2_p, R2_r + >>> from sympy.diffgeom import Commutator + >>> from sympy import simplify + + >>> fx, fy = R2_r.base_scalars() + >>> e_x, e_y = R2_r.base_vectors() + >>> e_r = R2_p.base_vector(0) + + >>> c_xy = Commutator(e_x, e_y) + >>> c_xr = Commutator(e_x, e_r) + >>> c_xy + 0 + + Unfortunately, the current code is not able to compute everything: + + >>> c_xr + Commutator(e_x, e_rho) + >>> simplify(c_xr(fy**2)) + -2*cos(theta)*y**2/(x**2 + y**2) + + """ + def __new__(cls, v1, v2): + if (covariant_order(v1) or contravariant_order(v1) != 1 + or covariant_order(v2) or contravariant_order(v2) != 1): + raise ValueError( + 'Only commutators of vector fields are supported.') + if v1 == v2: + return S.Zero + coord_sys = set().union(*[_find_coords(v) for v in (v1, v2)]) + if len(coord_sys) == 1: + # Only one coordinate systems is used, hence it is easy enough to + # actually evaluate the commutator. + if all(isinstance(v, BaseVectorField) for v in (v1, v2)): + return S.Zero + bases_1, bases_2 = [list(v.atoms(BaseVectorField)) + for v in (v1, v2)] + coeffs_1 = [v1.expand().coeff(b) for b in bases_1] + coeffs_2 = [v2.expand().coeff(b) for b in bases_2] + res = 0 + for c1, b1 in zip(coeffs_1, bases_1): + for c2, b2 in zip(coeffs_2, bases_2): + res += c1*b1(c2)*b2 - c2*b2(c1)*b1 + return res + else: + obj = super().__new__(cls, v1, v2) + obj._v1 = v1 # deprecated assignment + obj._v2 = v2 # deprecated assignment + return obj + + @property + def v1(self): + return self.args[0] + + @property + def v2(self): + return self.args[1] + + def __call__(self, scalar_field): + """Apply on a scalar field. + If the argument is not a scalar field an error is raised. + """ + return self.v1(self.v2(scalar_field)) - self.v2(self.v1(scalar_field)) + + +class Differential(Expr): + r"""Return the differential (exterior derivative) of a form field. + + Explanation + =========== + + The differential of a form (i.e. the exterior derivative) has a complicated + definition in the general case. + The differential `df` of the 0-form `f` is defined for any vector field `v` + as `df(v) = v(f)`. + + Examples + ======== + + >>> from sympy import Function + >>> from sympy.diffgeom.rn import R2_r + >>> from sympy.diffgeom import Differential + >>> from sympy import pprint + + >>> fx, fy = R2_r.base_scalars() + >>> e_x, e_y = R2_r.base_vectors() + >>> g = Function('g') + >>> s_field = g(fx, fy) + >>> dg = Differential(s_field) + + >>> dg + d(g(x, y)) + >>> pprint(dg(e_x)) + / d \| + |---(g(xi, y))|| + \dxi /|xi=x + >>> pprint(dg(e_y)) + / d \| + |---(g(x, xi))|| + \dxi /|xi=y + + Applying the exterior derivative operator twice always results in: + + >>> Differential(dg) + 0 + """ + + is_commutative = False + + def __new__(cls, form_field): + if contravariant_order(form_field): + raise ValueError( + 'A vector field was supplied as an argument to Differential.') + if isinstance(form_field, Differential): + return S.Zero + else: + obj = super().__new__(cls, form_field) + obj._form_field = form_field # deprecated assignment + return obj + + @property + def form_field(self): + return self.args[0] + + def __call__(self, *vector_fields): + """Apply on a list of vector_fields. + + Explanation + =========== + + If the number of vector fields supplied is not equal to 1 + the order of + the form field inside the differential the result is undefined. + + For 1-forms (i.e. differentials of scalar fields) the evaluation is + done as `df(v)=v(f)`. However if `v` is ``None`` instead of a vector + field, the differential is returned unchanged. This is done in order to + permit partial contractions for higher forms. + + In the general case the evaluation is done by applying the form field + inside the differential on a list with one less elements than the number + of elements in the original list. Lowering the number of vector fields + is achieved through replacing each pair of fields by their + commutator. + + If the arguments are not vectors or ``None``s an error is raised. + """ + if any((contravariant_order(a) != 1 or covariant_order(a)) and a is not None + for a in vector_fields): + raise ValueError('The arguments supplied to Differential should be vector fields or Nones.') + k = len(vector_fields) + if k == 1: + if vector_fields[0]: + return vector_fields[0].rcall(self._form_field) + return self + else: + # For higher form it is more complicated: + # Invariant formula: + # https://en.wikipedia.org/wiki/Exterior_derivative#Invariant_formula + # df(v1, ... vn) = +/- vi(f(v1..no i..vn)) + # +/- f([vi,vj],v1..no i, no j..vn) + f = self._form_field + v = vector_fields + ret = 0 + for i in range(k): + t = v[i].rcall(f.rcall(*v[:i] + v[i + 1:])) + ret += (-1)**i*t + for j in range(i + 1, k): + c = Commutator(v[i], v[j]) + if c: # TODO this is ugly - the Commutator can be Zero and + # this causes the next line to fail + t = f.rcall(*(c,) + v[:i] + v[i + 1:j] + v[j + 1:]) + ret += (-1)**(i + j)*t + return ret + + +class TensorProduct(Expr): + """Tensor product of forms. + + Explanation + =========== + + The tensor product permits the creation of multilinear functionals (i.e. + higher order tensors) out of lower order fields (e.g. 1-forms and vector + fields). However, the higher tensors thus created lack the interesting + features provided by the other type of product, the wedge product, namely + they are not antisymmetric and hence are not form fields. + + Examples + ======== + + >>> from sympy.diffgeom.rn import R2_r + >>> from sympy.diffgeom import TensorProduct + + >>> fx, fy = R2_r.base_scalars() + >>> e_x, e_y = R2_r.base_vectors() + >>> dx, dy = R2_r.base_oneforms() + + >>> TensorProduct(dx, dy)(e_x, e_y) + 1 + >>> TensorProduct(dx, dy)(e_y, e_x) + 0 + >>> TensorProduct(dx, fx*dy)(fx*e_x, e_y) + x**2 + >>> TensorProduct(e_x, e_y)(fx**2, fy**2) + 4*x*y + >>> TensorProduct(e_y, dx)(fy) + dx + + You can nest tensor products. + + >>> tp1 = TensorProduct(dx, dy) + >>> TensorProduct(tp1, dx)(e_x, e_y, e_x) + 1 + + You can make partial contraction for instance when 'raising an index'. + Putting ``None`` in the second argument of ``rcall`` means that the + respective position in the tensor product is left as it is. + + >>> TP = TensorProduct + >>> metric = TP(dx, dx) + 3*TP(dy, dy) + >>> metric.rcall(e_y, None) + 3*dy + + Or automatically pad the args with ``None`` without specifying them. + + >>> metric.rcall(e_y) + 3*dy + + """ + def __new__(cls, *args): + scalar = Mul(*[m for m in args if covariant_order(m) + contravariant_order(m) == 0]) + multifields = [m for m in args if covariant_order(m) + contravariant_order(m)] + if multifields: + if len(multifields) == 1: + return scalar*multifields[0] + return scalar*super().__new__(cls, *multifields) + else: + return scalar + + def __call__(self, *fields): + """Apply on a list of fields. + + If the number of input fields supplied is not equal to the order of + the tensor product field, the list of arguments is padded with ``None``'s. + + The list of arguments is divided in sublists depending on the order of + the forms inside the tensor product. The sublists are provided as + arguments to these forms and the resulting expressions are given to the + constructor of ``TensorProduct``. + + """ + tot_order = covariant_order(self) + contravariant_order(self) + tot_args = len(fields) + if tot_args != tot_order: + fields = list(fields) + [None]*(tot_order - tot_args) + orders = [covariant_order(f) + contravariant_order(f) for f in self._args] + indices = [sum(orders[:i + 1]) for i in range(len(orders) - 1)] + fields = [fields[i:j] for i, j in zip([0] + indices, indices + [None])] + multipliers = [t[0].rcall(*t[1]) for t in zip(self._args, fields)] + return TensorProduct(*multipliers) + + +class WedgeProduct(TensorProduct): + """Wedge product of forms. + + Explanation + =========== + + In the context of integration only completely antisymmetric forms make + sense. The wedge product permits the creation of such forms. + + Examples + ======== + + >>> from sympy.diffgeom.rn import R2_r + >>> from sympy.diffgeom import WedgeProduct + + >>> fx, fy = R2_r.base_scalars() + >>> e_x, e_y = R2_r.base_vectors() + >>> dx, dy = R2_r.base_oneforms() + + >>> WedgeProduct(dx, dy)(e_x, e_y) + 1 + >>> WedgeProduct(dx, dy)(e_y, e_x) + -1 + >>> WedgeProduct(dx, fx*dy)(fx*e_x, e_y) + x**2 + >>> WedgeProduct(e_x, e_y)(fy, None) + -e_x + + You can nest wedge products. + + >>> wp1 = WedgeProduct(dx, dy) + >>> WedgeProduct(wp1, dx)(e_x, e_y, e_x) + 0 + + """ + # TODO the calculation of signatures is slow + # TODO you do not need all these permutations (neither the prefactor) + def __call__(self, *fields): + """Apply on a list of vector_fields. + The expression is rewritten internally in terms of tensor products and evaluated.""" + orders = (covariant_order(e) + contravariant_order(e) for e in self.args) + mul = 1/Mul(*(factorial(o) for o in orders)) + perms = permutations(fields) + perms_par = (Permutation( + p).signature() for p in permutations(range(len(fields)))) + tensor_prod = TensorProduct(*self.args) + return mul*Add(*[tensor_prod(*p[0])*p[1] for p in zip(perms, perms_par)]) + + +class LieDerivative(Expr): + """Lie derivative with respect to a vector field. + + Explanation + =========== + + The transport operator that defines the Lie derivative is the pushforward of + the field to be derived along the integral curve of the field with respect + to which one derives. + + Examples + ======== + + >>> from sympy.diffgeom.rn import R2_r, R2_p + >>> from sympy.diffgeom import (LieDerivative, TensorProduct) + + >>> fx, fy = R2_r.base_scalars() + >>> e_x, e_y = R2_r.base_vectors() + >>> e_rho, e_theta = R2_p.base_vectors() + >>> dx, dy = R2_r.base_oneforms() + + >>> LieDerivative(e_x, fy) + 0 + >>> LieDerivative(e_x, fx) + 1 + >>> LieDerivative(e_x, e_x) + 0 + + The Lie derivative of a tensor field by another tensor field is equal to + their commutator: + + >>> LieDerivative(e_x, e_rho) + Commutator(e_x, e_rho) + >>> LieDerivative(e_x + e_y, fx) + 1 + + >>> tp = TensorProduct(dx, dy) + >>> LieDerivative(e_x, tp) + LieDerivative(e_x, TensorProduct(dx, dy)) + >>> LieDerivative(e_x, tp) + LieDerivative(e_x, TensorProduct(dx, dy)) + + """ + def __new__(cls, v_field, expr): + expr_form_ord = covariant_order(expr) + if contravariant_order(v_field) != 1 or covariant_order(v_field): + raise ValueError('Lie derivatives are defined only with respect to' + ' vector fields. The supplied argument was not a ' + 'vector field.') + if expr_form_ord > 0: + obj = super().__new__(cls, v_field, expr) + # deprecated assignments + obj._v_field = v_field + obj._expr = expr + return obj + if expr.atoms(BaseVectorField): + return Commutator(v_field, expr) + else: + return v_field.rcall(expr) + + @property + def v_field(self): + return self.args[0] + + @property + def expr(self): + return self.args[1] + + def __call__(self, *args): + v = self.v_field + expr = self.expr + lead_term = v(expr(*args)) + rest = Add(*[Mul(*args[:i] + (Commutator(v, args[i]),) + args[i + 1:]) + for i in range(len(args))]) + return lead_term - rest + + +class BaseCovarDerivativeOp(Expr): + """Covariant derivative operator with respect to a base vector. + + Examples + ======== + + >>> from sympy.diffgeom.rn import R2_r + >>> from sympy.diffgeom import BaseCovarDerivativeOp + >>> from sympy.diffgeom import metric_to_Christoffel_2nd, TensorProduct + + >>> TP = TensorProduct + >>> fx, fy = R2_r.base_scalars() + >>> e_x, e_y = R2_r.base_vectors() + >>> dx, dy = R2_r.base_oneforms() + + >>> ch = metric_to_Christoffel_2nd(TP(dx, dx) + TP(dy, dy)) + >>> ch + [[[0, 0], [0, 0]], [[0, 0], [0, 0]]] + >>> cvd = BaseCovarDerivativeOp(R2_r, 0, ch) + >>> cvd(fx) + 1 + >>> cvd(fx*e_x) + e_x + """ + + def __new__(cls, coord_sys, index, christoffel): + index = _sympify(index) + christoffel = ImmutableDenseNDimArray(christoffel) + obj = super().__new__(cls, coord_sys, index, christoffel) + # deprecated assignments + obj._coord_sys = coord_sys + obj._index = index + obj._christoffel = christoffel + return obj + + @property + def coord_sys(self): + return self.args[0] + + @property + def index(self): + return self.args[1] + + @property + def christoffel(self): + return self.args[2] + + def __call__(self, field): + """Apply on a scalar field. + + The action of a vector field on a scalar field is a directional + differentiation. + If the argument is not a scalar field the behaviour is undefined. + """ + if covariant_order(field) != 0: + raise NotImplementedError() + + field = vectors_in_basis(field, self._coord_sys) + + wrt_vector = self._coord_sys.base_vector(self._index) + wrt_scalar = self._coord_sys.coord_function(self._index) + vectors = list(field.atoms(BaseVectorField)) + + # First step: replace all vectors with something susceptible to + # derivation and do the derivation + # TODO: you need a real dummy function for the next line + d_funcs = [Function('_#_%s' % i)(wrt_scalar) for i, + b in enumerate(vectors)] + d_result = field.subs(list(zip(vectors, d_funcs))) + d_result = wrt_vector(d_result) + + # Second step: backsubstitute the vectors in + d_result = d_result.subs(list(zip(d_funcs, vectors))) + + # Third step: evaluate the derivatives of the vectors + derivs = [] + for v in vectors: + d = Add(*[(self._christoffel[k, wrt_vector._index, v._index] + *v._coord_sys.base_vector(k)) + for k in range(v._coord_sys.dim)]) + derivs.append(d) + to_subs = [wrt_vector(d) for d in d_funcs] + # XXX: This substitution can fail when there are Dummy symbols and the + # cache is disabled: https://github.com/sympy/sympy/issues/17794 + result = d_result.subs(list(zip(to_subs, derivs))) + + # Remove the dummies + result = result.subs(list(zip(d_funcs, vectors))) + return result.doit() + + +class CovarDerivativeOp(Expr): + """Covariant derivative operator. + + Examples + ======== + + >>> from sympy.diffgeom.rn import R2_r + >>> from sympy.diffgeom import CovarDerivativeOp + >>> from sympy.diffgeom import metric_to_Christoffel_2nd, TensorProduct + >>> TP = TensorProduct + >>> fx, fy = R2_r.base_scalars() + >>> e_x, e_y = R2_r.base_vectors() + >>> dx, dy = R2_r.base_oneforms() + >>> ch = metric_to_Christoffel_2nd(TP(dx, dx) + TP(dy, dy)) + + >>> ch + [[[0, 0], [0, 0]], [[0, 0], [0, 0]]] + >>> cvd = CovarDerivativeOp(fx*e_x, ch) + >>> cvd(fx) + x + >>> cvd(fx*e_x) + x*e_x + + """ + + def __new__(cls, wrt, christoffel): + if len({v._coord_sys for v in wrt.atoms(BaseVectorField)}) > 1: + raise NotImplementedError() + if contravariant_order(wrt) != 1 or covariant_order(wrt): + raise ValueError('Covariant derivatives are defined only with ' + 'respect to vector fields. The supplied argument ' + 'was not a vector field.') + christoffel = ImmutableDenseNDimArray(christoffel) + obj = super().__new__(cls, wrt, christoffel) + # deprecated assignments + obj._wrt = wrt + obj._christoffel = christoffel + return obj + + @property + def wrt(self): + return self.args[0] + + @property + def christoffel(self): + return self.args[1] + + def __call__(self, field): + vectors = list(self._wrt.atoms(BaseVectorField)) + base_ops = [BaseCovarDerivativeOp(v._coord_sys, v._index, self._christoffel) + for v in vectors] + return self._wrt.subs(list(zip(vectors, base_ops))).rcall(field) + + +############################################################################### +# Integral curves on vector fields +############################################################################### +def intcurve_series(vector_field, param, start_point, n=6, coord_sys=None, coeffs=False): + r"""Return the series expansion for an integral curve of the field. + + Explanation + =========== + + Integral curve is a function `\gamma` taking a parameter in `R` to a point + in the manifold. It verifies the equation: + + `V(f)\big(\gamma(t)\big) = \frac{d}{dt}f\big(\gamma(t)\big)` + + where the given ``vector_field`` is denoted as `V`. This holds for any + value `t` for the parameter and any scalar field `f`. + + This equation can also be decomposed of a basis of coordinate functions + `V(f_i)\big(\gamma(t)\big) = \frac{d}{dt}f_i\big(\gamma(t)\big) \quad \forall i` + + This function returns a series expansion of `\gamma(t)` in terms of the + coordinate system ``coord_sys``. The equations and expansions are necessarily + done in coordinate-system-dependent way as there is no other way to + represent movement between points on the manifold (i.e. there is no such + thing as a difference of points for a general manifold). + + Parameters + ========== + vector_field + the vector field for which an integral curve will be given + + param + the argument of the function `\gamma` from R to the curve + + start_point + the point which corresponds to `\gamma(0)` + + n + the order to which to expand + + coord_sys + the coordinate system in which to expand + coeffs (default False) - if True return a list of elements of the expansion + + Examples + ======== + + Use the predefined R2 manifold: + + >>> from sympy.abc import t, x, y + >>> from sympy.diffgeom.rn import R2_p, R2_r + >>> from sympy.diffgeom import intcurve_series + + Specify a starting point and a vector field: + + >>> start_point = R2_r.point([x, y]) + >>> vector_field = R2_r.e_x + + Calculate the series: + + >>> intcurve_series(vector_field, t, start_point, n=3) + Matrix([ + [t + x], + [ y]]) + + Or get the elements of the expansion in a list: + + >>> series = intcurve_series(vector_field, t, start_point, n=3, coeffs=True) + >>> series[0] + Matrix([ + [x], + [y]]) + >>> series[1] + Matrix([ + [t], + [0]]) + >>> series[2] + Matrix([ + [0], + [0]]) + + The series in the polar coordinate system: + + >>> series = intcurve_series(vector_field, t, start_point, + ... n=3, coord_sys=R2_p, coeffs=True) + >>> series[0] + Matrix([ + [sqrt(x**2 + y**2)], + [ atan2(y, x)]]) + >>> series[1] + Matrix([ + [t*x/sqrt(x**2 + y**2)], + [ -t*y/(x**2 + y**2)]]) + >>> series[2] + Matrix([ + [t**2*(-x**2/(x**2 + y**2)**(3/2) + 1/sqrt(x**2 + y**2))/2], + [ t**2*x*y/(x**2 + y**2)**2]]) + + See Also + ======== + + intcurve_diffequ + + """ + if contravariant_order(vector_field) != 1 or covariant_order(vector_field): + raise ValueError('The supplied field was not a vector field.') + + def iter_vfield(scalar_field, i): + """Return ``vector_field`` called `i` times on ``scalar_field``.""" + return reduce(lambda s, v: v.rcall(s), [vector_field, ]*i, scalar_field) + + def taylor_terms_per_coord(coord_function): + """Return the series for one of the coordinates.""" + return [param**i*iter_vfield(coord_function, i).rcall(start_point)/factorial(i) + for i in range(n)] + coord_sys = coord_sys if coord_sys else start_point._coord_sys + coord_functions = coord_sys.coord_functions() + taylor_terms = [taylor_terms_per_coord(f) for f in coord_functions] + if coeffs: + return [Matrix(t) for t in zip(*taylor_terms)] + else: + return Matrix([sum(c) for c in taylor_terms]) + + +def intcurve_diffequ(vector_field, param, start_point, coord_sys=None): + r"""Return the differential equation for an integral curve of the field. + + Explanation + =========== + + Integral curve is a function `\gamma` taking a parameter in `R` to a point + in the manifold. It verifies the equation: + + `V(f)\big(\gamma(t)\big) = \frac{d}{dt}f\big(\gamma(t)\big)` + + where the given ``vector_field`` is denoted as `V`. This holds for any + value `t` for the parameter and any scalar field `f`. + + This function returns the differential equation of `\gamma(t)` in terms of the + coordinate system ``coord_sys``. The equations and expansions are necessarily + done in coordinate-system-dependent way as there is no other way to + represent movement between points on the manifold (i.e. there is no such + thing as a difference of points for a general manifold). + + Parameters + ========== + + vector_field + the vector field for which an integral curve will be given + + param + the argument of the function `\gamma` from R to the curve + + start_point + the point which corresponds to `\gamma(0)` + + coord_sys + the coordinate system in which to give the equations + + Returns + ======= + + a tuple of (equations, initial conditions) + + Examples + ======== + + Use the predefined R2 manifold: + + >>> from sympy.abc import t + >>> from sympy.diffgeom.rn import R2, R2_p, R2_r + >>> from sympy.diffgeom import intcurve_diffequ + + Specify a starting point and a vector field: + + >>> start_point = R2_r.point([0, 1]) + >>> vector_field = -R2.y*R2.e_x + R2.x*R2.e_y + + Get the equation: + + >>> equations, init_cond = intcurve_diffequ(vector_field, t, start_point) + >>> equations + [f_1(t) + Derivative(f_0(t), t), -f_0(t) + Derivative(f_1(t), t)] + >>> init_cond + [f_0(0), f_1(0) - 1] + + The series in the polar coordinate system: + + >>> equations, init_cond = intcurve_diffequ(vector_field, t, start_point, R2_p) + >>> equations + [Derivative(f_0(t), t), Derivative(f_1(t), t) - 1] + >>> init_cond + [f_0(0) - 1, f_1(0) - pi/2] + + See Also + ======== + + intcurve_series + + """ + if contravariant_order(vector_field) != 1 or covariant_order(vector_field): + raise ValueError('The supplied field was not a vector field.') + coord_sys = coord_sys if coord_sys else start_point._coord_sys + gammas = [Function('f_%d' % i)(param) for i in range( + start_point._coord_sys.dim)] + arbitrary_p = Point(coord_sys, gammas) + coord_functions = coord_sys.coord_functions() + equations = [simplify(diff(cf.rcall(arbitrary_p), param) - vector_field.rcall(cf).rcall(arbitrary_p)) + for cf in coord_functions] + init_cond = [simplify(cf.rcall(arbitrary_p).subs(param, 0) - cf.rcall(start_point)) + for cf in coord_functions] + return equations, init_cond + + +############################################################################### +# Helpers +############################################################################### +def dummyfy(args, exprs): + # TODO Is this a good idea? + d_args = Matrix([s.as_dummy() for s in args]) + reps = dict(zip(args, d_args)) + d_exprs = Matrix([_sympify(expr).subs(reps) for expr in exprs]) + return d_args, d_exprs + +############################################################################### +# Helpers +############################################################################### +def contravariant_order(expr, _strict=False): + """Return the contravariant order of an expression. + + Examples + ======== + + >>> from sympy.diffgeom import contravariant_order + >>> from sympy.diffgeom.rn import R2 + >>> from sympy.abc import a + + >>> contravariant_order(a) + 0 + >>> contravariant_order(a*R2.x + 2) + 0 + >>> contravariant_order(a*R2.x*R2.e_y + R2.e_x) + 1 + + """ + # TODO move some of this to class methods. + # TODO rewrite using the .as_blah_blah methods + if isinstance(expr, Add): + orders = [contravariant_order(e) for e in expr.args] + if len(set(orders)) != 1: + raise ValueError('Misformed expression containing contravariant fields of varying order.') + return orders[0] + elif isinstance(expr, Mul): + orders = [contravariant_order(e) for e in expr.args] + not_zero = [o for o in orders if o != 0] + if len(not_zero) > 1: + raise ValueError('Misformed expression containing multiplication between vectors.') + return 0 if not not_zero else not_zero[0] + elif isinstance(expr, Pow): + if covariant_order(expr.base) or covariant_order(expr.exp): + raise ValueError( + 'Misformed expression containing a power of a vector.') + return 0 + elif isinstance(expr, BaseVectorField): + return 1 + elif isinstance(expr, TensorProduct): + return sum(contravariant_order(a) for a in expr.args) + elif not _strict or expr.atoms(BaseScalarField): + return 0 + else: # If it does not contain anything related to the diffgeom module and it is _strict + return -1 + + +def covariant_order(expr, _strict=False): + """Return the covariant order of an expression. + + Examples + ======== + + >>> from sympy.diffgeom import covariant_order + >>> from sympy.diffgeom.rn import R2 + >>> from sympy.abc import a + + >>> covariant_order(a) + 0 + >>> covariant_order(a*R2.x + 2) + 0 + >>> covariant_order(a*R2.x*R2.dy + R2.dx) + 1 + + """ + # TODO move some of this to class methods. + # TODO rewrite using the .as_blah_blah methods + if isinstance(expr, Add): + orders = [covariant_order(e) for e in expr.args] + if len(set(orders)) != 1: + raise ValueError('Misformed expression containing form fields of varying order.') + return orders[0] + elif isinstance(expr, Mul): + orders = [covariant_order(e) for e in expr.args] + not_zero = [o for o in orders if o != 0] + if len(not_zero) > 1: + raise ValueError('Misformed expression containing multiplication between forms.') + return 0 if not not_zero else not_zero[0] + elif isinstance(expr, Pow): + if covariant_order(expr.base) or covariant_order(expr.exp): + raise ValueError( + 'Misformed expression containing a power of a form.') + return 0 + elif isinstance(expr, Differential): + return covariant_order(*expr.args) + 1 + elif isinstance(expr, TensorProduct): + return sum(covariant_order(a) for a in expr.args) + elif not _strict or expr.atoms(BaseScalarField): + return 0 + else: # If it does not contain anything related to the diffgeom module and it is _strict + return -1 + + +############################################################################### +# Coordinate transformation functions +############################################################################### +def vectors_in_basis(expr, to_sys): + """Transform all base vectors in base vectors of a specified coord basis. + While the new base vectors are in the new coordinate system basis, any + coefficients are kept in the old system. + + Examples + ======== + + >>> from sympy.diffgeom import vectors_in_basis + >>> from sympy.diffgeom.rn import R2_r, R2_p + + >>> vectors_in_basis(R2_r.e_x, R2_p) + -y*e_theta/(x**2 + y**2) + x*e_rho/sqrt(x**2 + y**2) + >>> vectors_in_basis(R2_p.e_r, R2_r) + sin(theta)*e_y + cos(theta)*e_x + + """ + vectors = list(expr.atoms(BaseVectorField)) + new_vectors = [] + for v in vectors: + cs = v._coord_sys + jac = cs.jacobian(to_sys, cs.coord_functions()) + new = (jac.T*Matrix(to_sys.base_vectors()))[v._index] + new_vectors.append(new) + return expr.subs(list(zip(vectors, new_vectors))) + + +############################################################################### +# Coordinate-dependent functions +############################################################################### +def twoform_to_matrix(expr): + """Return the matrix representing the twoform. + + For the twoform `w` return the matrix `M` such that `M[i,j]=w(e_i, e_j)`, + where `e_i` is the i-th base vector field for the coordinate system in + which the expression of `w` is given. + + Examples + ======== + + >>> from sympy.diffgeom.rn import R2 + >>> from sympy.diffgeom import twoform_to_matrix, TensorProduct + >>> TP = TensorProduct + + >>> twoform_to_matrix(TP(R2.dx, R2.dx) + TP(R2.dy, R2.dy)) + Matrix([ + [1, 0], + [0, 1]]) + >>> twoform_to_matrix(R2.x*TP(R2.dx, R2.dx) + TP(R2.dy, R2.dy)) + Matrix([ + [x, 0], + [0, 1]]) + >>> twoform_to_matrix(TP(R2.dx, R2.dx) + TP(R2.dy, R2.dy) - TP(R2.dx, R2.dy)/2) + Matrix([ + [ 1, 0], + [-1/2, 1]]) + + """ + if covariant_order(expr) != 2 or contravariant_order(expr): + raise ValueError('The input expression is not a two-form.') + coord_sys = _find_coords(expr) + if len(coord_sys) != 1: + raise ValueError('The input expression concerns more than one ' + 'coordinate systems, hence there is no unambiguous ' + 'way to choose a coordinate system for the matrix.') + coord_sys = coord_sys.pop() + vectors = coord_sys.base_vectors() + expr = expr.expand() + matrix_content = [[expr.rcall(v1, v2) for v1 in vectors] + for v2 in vectors] + return Matrix(matrix_content) + + +def metric_to_Christoffel_1st(expr): + """Return the nested list of Christoffel symbols for the given metric. + This returns the Christoffel symbol of first kind that represents the + Levi-Civita connection for the given metric. + + Examples + ======== + + >>> from sympy.diffgeom.rn import R2 + >>> from sympy.diffgeom import metric_to_Christoffel_1st, TensorProduct + >>> TP = TensorProduct + + >>> metric_to_Christoffel_1st(TP(R2.dx, R2.dx) + TP(R2.dy, R2.dy)) + [[[0, 0], [0, 0]], [[0, 0], [0, 0]]] + >>> metric_to_Christoffel_1st(R2.x*TP(R2.dx, R2.dx) + TP(R2.dy, R2.dy)) + [[[1/2, 0], [0, 0]], [[0, 0], [0, 0]]] + + """ + matrix = twoform_to_matrix(expr) + if not matrix.is_symmetric(): + raise ValueError( + 'The two-form representing the metric is not symmetric.') + coord_sys = _find_coords(expr).pop() + deriv_matrices = [matrix.applyfunc(d) for d in coord_sys.base_vectors()] + indices = list(range(coord_sys.dim)) + christoffel = [[[(deriv_matrices[k][i, j] + deriv_matrices[j][i, k] - deriv_matrices[i][j, k])/2 + for k in indices] + for j in indices] + for i in indices] + return ImmutableDenseNDimArray(christoffel) + + +def metric_to_Christoffel_2nd(expr): + """Return the nested list of Christoffel symbols for the given metric. + This returns the Christoffel symbol of second kind that represents the + Levi-Civita connection for the given metric. + + Examples + ======== + + >>> from sympy.diffgeom.rn import R2 + >>> from sympy.diffgeom import metric_to_Christoffel_2nd, TensorProduct + >>> TP = TensorProduct + + >>> metric_to_Christoffel_2nd(TP(R2.dx, R2.dx) + TP(R2.dy, R2.dy)) + [[[0, 0], [0, 0]], [[0, 0], [0, 0]]] + >>> metric_to_Christoffel_2nd(R2.x*TP(R2.dx, R2.dx) + TP(R2.dy, R2.dy)) + [[[1/(2*x), 0], [0, 0]], [[0, 0], [0, 0]]] + + """ + ch_1st = metric_to_Christoffel_1st(expr) + coord_sys = _find_coords(expr).pop() + indices = list(range(coord_sys.dim)) + # XXX workaround, inverting a matrix does not work if it contains non + # symbols + #matrix = twoform_to_matrix(expr).inv() + matrix = twoform_to_matrix(expr) + s_fields = set() + for e in matrix: + s_fields.update(e.atoms(BaseScalarField)) + s_fields = list(s_fields) + dums = coord_sys.symbols + matrix = matrix.subs(list(zip(s_fields, dums))).inv().subs(list(zip(dums, s_fields))) + # XXX end of workaround + christoffel = [[[Add(*[matrix[i, l]*ch_1st[l, j, k] for l in indices]) + for k in indices] + for j in indices] + for i in indices] + return ImmutableDenseNDimArray(christoffel) + + +def metric_to_Riemann_components(expr): + """Return the components of the Riemann tensor expressed in a given basis. + + Given a metric it calculates the components of the Riemann tensor in the + canonical basis of the coordinate system in which the metric expression is + given. + + Examples + ======== + + >>> from sympy import exp + >>> from sympy.diffgeom.rn import R2 + >>> from sympy.diffgeom import metric_to_Riemann_components, TensorProduct + >>> TP = TensorProduct + + >>> metric_to_Riemann_components(TP(R2.dx, R2.dx) + TP(R2.dy, R2.dy)) + [[[[0, 0], [0, 0]], [[0, 0], [0, 0]]], [[[0, 0], [0, 0]], [[0, 0], [0, 0]]]] + >>> non_trivial_metric = exp(2*R2.r)*TP(R2.dr, R2.dr) + \ + R2.r**2*TP(R2.dtheta, R2.dtheta) + >>> non_trivial_metric + exp(2*rho)*TensorProduct(drho, drho) + rho**2*TensorProduct(dtheta, dtheta) + >>> riemann = metric_to_Riemann_components(non_trivial_metric) + >>> riemann[0, :, :, :] + [[[0, 0], [0, 0]], [[0, exp(-2*rho)*rho], [-exp(-2*rho)*rho, 0]]] + >>> riemann[1, :, :, :] + [[[0, -1/rho], [1/rho, 0]], [[0, 0], [0, 0]]] + + """ + ch_2nd = metric_to_Christoffel_2nd(expr) + coord_sys = _find_coords(expr).pop() + indices = list(range(coord_sys.dim)) + deriv_ch = [[[[d(ch_2nd[i, j, k]) + for d in coord_sys.base_vectors()] + for k in indices] + for j in indices] + for i in indices] + riemann_a = [[[[deriv_ch[rho][sig][nu][mu] - deriv_ch[rho][sig][mu][nu] + for nu in indices] + for mu in indices] + for sig in indices] + for rho in indices] + riemann_b = [[[[Add(*[ch_2nd[rho, l, mu]*ch_2nd[l, sig, nu] - ch_2nd[rho, l, nu]*ch_2nd[l, sig, mu] for l in indices]) + for nu in indices] + for mu in indices] + for sig in indices] + for rho in indices] + riemann = [[[[riemann_a[rho][sig][mu][nu] + riemann_b[rho][sig][mu][nu] + for nu in indices] + for mu in indices] + for sig in indices] + for rho in indices] + return ImmutableDenseNDimArray(riemann) + + +def metric_to_Ricci_components(expr): + + """Return the components of the Ricci tensor expressed in a given basis. + + Given a metric it calculates the components of the Ricci tensor in the + canonical basis of the coordinate system in which the metric expression is + given. + + Examples + ======== + + >>> from sympy import exp + >>> from sympy.diffgeom.rn import R2 + >>> from sympy.diffgeom import metric_to_Ricci_components, TensorProduct + >>> TP = TensorProduct + + >>> metric_to_Ricci_components(TP(R2.dx, R2.dx) + TP(R2.dy, R2.dy)) + [[0, 0], [0, 0]] + >>> non_trivial_metric = exp(2*R2.r)*TP(R2.dr, R2.dr) + \ + R2.r**2*TP(R2.dtheta, R2.dtheta) + >>> non_trivial_metric + exp(2*rho)*TensorProduct(drho, drho) + rho**2*TensorProduct(dtheta, dtheta) + >>> metric_to_Ricci_components(non_trivial_metric) + [[1/rho, 0], [0, exp(-2*rho)*rho]] + + """ + riemann = metric_to_Riemann_components(expr) + coord_sys = _find_coords(expr).pop() + indices = list(range(coord_sys.dim)) + ricci = [[Add(*[riemann[k, i, k, j] for k in indices]) + for j in indices] + for i in indices] + return ImmutableDenseNDimArray(ricci) + +############################################################################### +# Classes for deprecation +############################################################################### + +class _deprecated_container: + # This class gives deprecation warning. + # When deprecated features are completely deleted, this should be removed as well. + # See https://github.com/sympy/sympy/pull/19368 + def __init__(self, message, data): + super().__init__(data) + self.message = message + + def warn(self): + sympy_deprecation_warning( + self.message, + deprecated_since_version="1.7", + active_deprecations_target="deprecated-diffgeom-mutable", + stacklevel=4 + ) + + def __iter__(self): + self.warn() + return super().__iter__() + + def __getitem__(self, key): + self.warn() + return super().__getitem__(key) + + def __contains__(self, key): + self.warn() + return super().__contains__(key) + + +class _deprecated_list(_deprecated_container, list): + pass + + +class _deprecated_dict(_deprecated_container, dict): + pass + + +# Import at end to avoid cyclic imports +from sympy.simplify.simplify import simplify diff --git a/.venv/lib/python3.13/site-packages/sympy/diffgeom/rn.py b/.venv/lib/python3.13/site-packages/sympy/diffgeom/rn.py new file mode 100644 index 0000000000000000000000000000000000000000..897c7e82bc804d260612f79c820af92632f3b281 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/diffgeom/rn.py @@ -0,0 +1,143 @@ +"""Predefined R^n manifolds together with common coord. systems. + +Coordinate systems are predefined as well as the transformation laws between +them. + +Coordinate functions can be accessed as attributes of the manifold (eg `R2.x`), +as attributes of the coordinate systems (eg `R2_r.x` and `R2_p.theta`), or by +using the usual `coord_sys.coord_function(index, name)` interface. +""" + +from typing import Any +import warnings + +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, atan2, cos, sin) +from .diffgeom import Manifold, Patch, CoordSystem + +__all__ = [ + 'R2', 'R2_origin', 'relations_2d', 'R2_r', 'R2_p', + 'R3', 'R3_origin', 'relations_3d', 'R3_r', 'R3_c', 'R3_s' +] + +############################################################################### +# R2 +############################################################################### +R2: Any = Manifold('R^2', 2) + +R2_origin: Any = Patch('origin', R2) + +x, y = symbols('x y', real=True) +r, theta = symbols('rho theta', nonnegative=True) + +relations_2d = { + ('rectangular', 'polar'): [(x, y), (sqrt(x**2 + y**2), atan2(y, x))], + ('polar', 'rectangular'): [(r, theta), (r*cos(theta), r*sin(theta))], +} + +R2_r: Any = CoordSystem('rectangular', R2_origin, (x, y), relations_2d) +R2_p: Any = CoordSystem('polar', R2_origin, (r, theta), relations_2d) + +# support deprecated feature +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + x, y, r, theta = symbols('x y r theta', cls=Dummy) + R2_r.connect_to(R2_p, [x, y], + [sqrt(x**2 + y**2), atan2(y, x)], + inverse=False, fill_in_gaps=False) + R2_p.connect_to(R2_r, [r, theta], + [r*cos(theta), r*sin(theta)], + inverse=False, fill_in_gaps=False) + +# Defining the basis coordinate functions and adding shortcuts for them to the +# manifold and the patch. +R2.x, R2.y = R2_origin.x, R2_origin.y = R2_r.x, R2_r.y = R2_r.coord_functions() +R2.r, R2.theta = R2_origin.r, R2_origin.theta = R2_p.r, R2_p.theta = R2_p.coord_functions() + +# Defining the basis vector fields and adding shortcuts for them to the +# manifold and the patch. +R2.e_x, R2.e_y = R2_origin.e_x, R2_origin.e_y = R2_r.e_x, R2_r.e_y = R2_r.base_vectors() +R2.e_r, R2.e_theta = R2_origin.e_r, R2_origin.e_theta = R2_p.e_r, R2_p.e_theta = R2_p.base_vectors() + +# Defining the basis oneform fields and adding shortcuts for them to the +# manifold and the patch. +R2.dx, R2.dy = R2_origin.dx, R2_origin.dy = R2_r.dx, R2_r.dy = R2_r.base_oneforms() +R2.dr, R2.dtheta = R2_origin.dr, R2_origin.dtheta = R2_p.dr, R2_p.dtheta = R2_p.base_oneforms() + +############################################################################### +# R3 +############################################################################### +R3: Any = Manifold('R^3', 3) + +R3_origin: Any = Patch('origin', R3) + +x, y, z = symbols('x y z', real=True) +rho, psi, r, theta, phi = symbols('rho psi r theta phi', nonnegative=True) + +relations_3d = { + ('rectangular', 'cylindrical'): [(x, y, z), + (sqrt(x**2 + y**2), atan2(y, x), z)], + ('cylindrical', 'rectangular'): [(rho, psi, z), + (rho*cos(psi), rho*sin(psi), z)], + ('rectangular', 'spherical'): [(x, y, z), + (sqrt(x**2 + y**2 + z**2), + acos(z/sqrt(x**2 + y**2 + z**2)), + atan2(y, x))], + ('spherical', 'rectangular'): [(r, theta, phi), + (r*sin(theta)*cos(phi), + r*sin(theta)*sin(phi), + r*cos(theta))], + ('cylindrical', 'spherical'): [(rho, psi, z), + (sqrt(rho**2 + z**2), + acos(z/sqrt(rho**2 + z**2)), + psi)], + ('spherical', 'cylindrical'): [(r, theta, phi), + (r*sin(theta), phi, r*cos(theta))], +} + +R3_r: Any = CoordSystem('rectangular', R3_origin, (x, y, z), relations_3d) +R3_c: Any = CoordSystem('cylindrical', R3_origin, (rho, psi, z), relations_3d) +R3_s: Any = CoordSystem('spherical', R3_origin, (r, theta, phi), relations_3d) + +# support deprecated feature +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + x, y, z, rho, psi, r, theta, phi = symbols('x y z rho psi r theta phi', cls=Dummy) + R3_r.connect_to(R3_c, [x, y, z], + [sqrt(x**2 + y**2), atan2(y, x), z], + inverse=False, fill_in_gaps=False) + R3_c.connect_to(R3_r, [rho, psi, z], + [rho*cos(psi), rho*sin(psi), z], + inverse=False, fill_in_gaps=False) + ## rectangular <-> spherical + R3_r.connect_to(R3_s, [x, y, z], + [sqrt(x**2 + y**2 + z**2), acos(z/ + sqrt(x**2 + y**2 + z**2)), atan2(y, x)], + inverse=False, fill_in_gaps=False) + R3_s.connect_to(R3_r, [r, theta, phi], + [r*sin(theta)*cos(phi), r*sin( + theta)*sin(phi), r*cos(theta)], + inverse=False, fill_in_gaps=False) + ## cylindrical <-> spherical + R3_c.connect_to(R3_s, [rho, psi, z], + [sqrt(rho**2 + z**2), acos(z/sqrt(rho**2 + z**2)), psi], + inverse=False, fill_in_gaps=False) + R3_s.connect_to(R3_c, [r, theta, phi], + [r*sin(theta), phi, r*cos(theta)], + inverse=False, fill_in_gaps=False) + +# Defining the basis coordinate functions. +R3_r.x, R3_r.y, R3_r.z = R3_r.coord_functions() +R3_c.rho, R3_c.psi, R3_c.z = R3_c.coord_functions() +R3_s.r, R3_s.theta, R3_s.phi = R3_s.coord_functions() + +# Defining the basis vector fields. +R3_r.e_x, R3_r.e_y, R3_r.e_z = R3_r.base_vectors() +R3_c.e_rho, R3_c.e_psi, R3_c.e_z = R3_c.base_vectors() +R3_s.e_r, R3_s.e_theta, R3_s.e_phi = R3_s.base_vectors() + +# Defining the basis oneform fields. +R3_r.dx, R3_r.dy, R3_r.dz = R3_r.base_oneforms() +R3_c.drho, R3_c.dpsi, R3_c.dz = R3_c.base_oneforms() +R3_s.dr, R3_s.dtheta, R3_s.dphi = R3_s.base_oneforms() diff --git a/.venv/lib/python3.13/site-packages/sympy/diffgeom/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/diffgeom/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/diffgeom/tests/test_class_structure.py b/.venv/lib/python3.13/site-packages/sympy/diffgeom/tests/test_class_structure.py new file mode 100644 index 0000000000000000000000000000000000000000..c649fd9fcb9acdf1f410a021966c6e0fee62cc2b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/diffgeom/tests/test_class_structure.py @@ -0,0 +1,33 @@ +from sympy.diffgeom import Manifold, Patch, CoordSystem, Point +from sympy.core.function import Function +from sympy.core.symbol import symbols +from sympy.testing.pytest import warns_deprecated_sympy + +m = Manifold('m', 2) +p = Patch('p', m) +a, b = symbols('a b') +cs = CoordSystem('cs', p, [a, b]) +x, y = symbols('x y') +f = Function('f') +s1, s2 = cs.coord_functions() +v1, v2 = cs.base_vectors() +f1, f2 = cs.base_oneforms() + +def test_point(): + point = Point(cs, [x, y]) + assert point != Point(cs, [2, y]) + #TODO assert point.subs(x, 2) == Point(cs, [2, y]) + #TODO assert point.free_symbols == set([x, y]) + +def test_subs(): + assert s1.subs(s1, s2) == s2 + assert v1.subs(v1, v2) == v2 + assert f1.subs(f1, f2) == f2 + assert (x*f(s1) + y).subs(s1, s2) == x*f(s2) + y + assert (f(s1)*v1).subs(v1, v2) == f(s1)*v2 + assert (y*f(s1)*f1).subs(f1, f2) == y*f(s1)*f2 + +def test_deprecated(): + with warns_deprecated_sympy(): + cs_wname = CoordSystem('cs', p, ['a', 'b']) + assert cs_wname == cs_wname.func(*cs_wname.args) diff --git a/.venv/lib/python3.13/site-packages/sympy/diffgeom/tests/test_diffgeom.py b/.venv/lib/python3.13/site-packages/sympy/diffgeom/tests/test_diffgeom.py new file mode 100644 index 0000000000000000000000000000000000000000..7c3c9265785896b8f4ffa3a2b41816ca90579758 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/diffgeom/tests/test_diffgeom.py @@ -0,0 +1,342 @@ +from sympy.core import Lambda, Symbol, symbols +from sympy.diffgeom.rn import R2, R2_p, R2_r, R3_r, R3_c, R3_s, R2_origin +from sympy.diffgeom import (Manifold, Patch, CoordSystem, Commutator, Differential, TensorProduct, + WedgeProduct, BaseCovarDerivativeOp, CovarDerivativeOp, LieDerivative, + covariant_order, contravariant_order, twoform_to_matrix, metric_to_Christoffel_1st, + metric_to_Christoffel_2nd, metric_to_Riemann_components, + metric_to_Ricci_components, intcurve_diffequ, intcurve_series) +from sympy.simplify import trigsimp, simplify +from sympy.functions import sqrt, atan2, sin +from sympy.matrices import Matrix +from sympy.testing.pytest import raises, nocache_fail +from sympy.testing.pytest import warns_deprecated_sympy + +TP = TensorProduct + + +def test_coordsys_transform(): + # test inverse transforms + p, q, r, s = symbols('p q r s') + rel = {('first', 'second'): [(p, q), (q, -p)]} + R2_pq = CoordSystem('first', R2_origin, [p, q], rel) + R2_rs = CoordSystem('second', R2_origin, [r, s], rel) + r, s = R2_rs.symbols + assert R2_rs.transform(R2_pq) == Matrix([[-s], [r]]) + + # inverse transform impossible case + a, b = symbols('a b', positive=True) + rel = {('first', 'second'): [(a,), (-a,)]} + R2_a = CoordSystem('first', R2_origin, [a], rel) + R2_b = CoordSystem('second', R2_origin, [b], rel) + # This transformation is uninvertible because there is no positive a, b satisfying a = -b + with raises(NotImplementedError): + R2_b.transform(R2_a) + + # inverse transform ambiguous case + c, d = symbols('c d') + rel = {('first', 'second'): [(c,), (c**2,)]} + R2_c = CoordSystem('first', R2_origin, [c], rel) + R2_d = CoordSystem('second', R2_origin, [d], rel) + # The transform method should throw if it finds multiple inverses for a coordinate transformation. + with raises(ValueError): + R2_d.transform(R2_c) + + # test indirect transformation + a, b, c, d, e, f = symbols('a, b, c, d, e, f') + rel = {('C1', 'C2'): [(a, b), (2*a, 3*b)], + ('C2', 'C3'): [(c, d), (3*c, 2*d)]} + C1 = CoordSystem('C1', R2_origin, (a, b), rel) + C2 = CoordSystem('C2', R2_origin, (c, d), rel) + C3 = CoordSystem('C3', R2_origin, (e, f), rel) + a, b = C1.symbols + c, d = C2.symbols + e, f = C3.symbols + assert C2.transform(C1) == Matrix([c/2, d/3]) + assert C1.transform(C3) == Matrix([6*a, 6*b]) + assert C3.transform(C1) == Matrix([e/6, f/6]) + assert C3.transform(C2) == Matrix([e/3, f/2]) + + a, b, c, d, e, f = symbols('a, b, c, d, e, f') + rel = {('C1', 'C2'): [(a, b), (2*a, 3*b + 1)], + ('C3', 'C2'): [(e, f), (-e - 2, 2*f)]} + C1 = CoordSystem('C1', R2_origin, (a, b), rel) + C2 = CoordSystem('C2', R2_origin, (c, d), rel) + C3 = CoordSystem('C3', R2_origin, (e, f), rel) + a, b = C1.symbols + c, d = C2.symbols + e, f = C3.symbols + assert C2.transform(C1) == Matrix([c/2, (d - 1)/3]) + assert C1.transform(C3) == Matrix([-2*a - 2, (3*b + 1)/2]) + assert C3.transform(C1) == Matrix([-e/2 - 1, (2*f - 1)/3]) + assert C3.transform(C2) == Matrix([-e - 2, 2*f]) + + # old signature uses Lambda + a, b, c, d, e, f = symbols('a, b, c, d, e, f') + rel = {('C1', 'C2'): Lambda((a, b), (2*a, 3*b + 1)), + ('C3', 'C2'): Lambda((e, f), (-e - 2, 2*f))} + C1 = CoordSystem('C1', R2_origin, (a, b), rel) + C2 = CoordSystem('C2', R2_origin, (c, d), rel) + C3 = CoordSystem('C3', R2_origin, (e, f), rel) + a, b = C1.symbols + c, d = C2.symbols + e, f = C3.symbols + assert C2.transform(C1) == Matrix([c/2, (d - 1)/3]) + assert C1.transform(C3) == Matrix([-2*a - 2, (3*b + 1)/2]) + assert C3.transform(C1) == Matrix([-e/2 - 1, (2*f - 1)/3]) + assert C3.transform(C2) == Matrix([-e - 2, 2*f]) + + +def test_R2(): + x0, y0, r0, theta0 = symbols('x0, y0, r0, theta0', real=True) + point_r = R2_r.point([x0, y0]) + point_p = R2_p.point([r0, theta0]) + + # r**2 = x**2 + y**2 + assert (R2.r**2 - R2.x**2 - R2.y**2).rcall(point_r) == 0 + assert trigsimp( (R2.r**2 - R2.x**2 - R2.y**2).rcall(point_p) ) == 0 + assert trigsimp(R2.e_r(R2.x**2 + R2.y**2).rcall(point_p).doit()) == 2*r0 + + # polar->rect->polar == Id + a, b = symbols('a b', positive=True) + m = Matrix([[a], [b]]) + + #TODO assert m == R2_r.transform(R2_p, R2_p.transform(R2_r, [a, b])).applyfunc(simplify) + assert m == R2_p.transform(R2_r, R2_r.transform(R2_p, m)).applyfunc(simplify) + + # deprecated method + with warns_deprecated_sympy(): + assert m == R2_p.coord_tuple_transform_to( + R2_r, R2_r.coord_tuple_transform_to(R2_p, m)).applyfunc(simplify) + + +def test_R3(): + a, b, c = symbols('a b c', positive=True) + m = Matrix([[a], [b], [c]]) + + assert m == R3_c.transform(R3_r, R3_r.transform(R3_c, m)).applyfunc(simplify) + #TODO assert m == R3_r.transform(R3_c, R3_c.transform(R3_r, m)).applyfunc(simplify) + assert m == R3_s.transform( + R3_r, R3_r.transform(R3_s, m)).applyfunc(simplify) + #TODO assert m == R3_r.transform(R3_s, R3_s.transform(R3_r, m)).applyfunc(simplify) + assert m == R3_s.transform( + R3_c, R3_c.transform(R3_s, m)).applyfunc(simplify) + #TODO assert m == R3_c.transform(R3_s, R3_s.transform(R3_c, m)).applyfunc(simplify) + + with warns_deprecated_sympy(): + assert m == R3_c.coord_tuple_transform_to( + R3_r, R3_r.coord_tuple_transform_to(R3_c, m)).applyfunc(simplify) + #TODO assert m == R3_r.coord_tuple_transform_to(R3_c, R3_c.coord_tuple_transform_to(R3_r, m)).applyfunc(simplify) + assert m == R3_s.coord_tuple_transform_to( + R3_r, R3_r.coord_tuple_transform_to(R3_s, m)).applyfunc(simplify) + #TODO assert m == R3_r.coord_tuple_transform_to(R3_s, R3_s.coord_tuple_transform_to(R3_r, m)).applyfunc(simplify) + assert m == R3_s.coord_tuple_transform_to( + R3_c, R3_c.coord_tuple_transform_to(R3_s, m)).applyfunc(simplify) + #TODO assert m == R3_c.coord_tuple_transform_to(R3_s, R3_s.coord_tuple_transform_to(R3_c, m)).applyfunc(simplify) + + +def test_CoordinateSymbol(): + x, y = R2_r.symbols + r, theta = R2_p.symbols + assert y.rewrite(R2_p) == r*sin(theta) + + +def test_point(): + x, y = symbols('x, y') + p = R2_r.point([x, y]) + assert p.free_symbols == {x, y} + assert p.coords(R2_r) == p.coords() == Matrix([x, y]) + assert p.coords(R2_p) == Matrix([sqrt(x**2 + y**2), atan2(y, x)]) + + +def test_commutator(): + assert Commutator(R2.e_x, R2.e_y) == 0 + assert Commutator(R2.x*R2.e_x, R2.x*R2.e_x) == 0 + assert Commutator(R2.x*R2.e_x, R2.x*R2.e_y) == R2.x*R2.e_y + c = Commutator(R2.e_x, R2.e_r) + assert c(R2.x) == R2.y*(R2.x**2 + R2.y**2)**(-1)*sin(R2.theta) + + +def test_differential(): + xdy = R2.x*R2.dy + dxdy = Differential(xdy) + assert xdy.rcall(None) == xdy + assert dxdy(R2.e_x, R2.e_y) == 1 + assert dxdy(R2.e_x, R2.x*R2.e_y) == R2.x + assert Differential(dxdy) == 0 + + +def test_products(): + assert TensorProduct( + R2.dx, R2.dy)(R2.e_x, R2.e_y) == R2.dx(R2.e_x)*R2.dy(R2.e_y) == 1 + assert TensorProduct(R2.dx, R2.dy)(None, R2.e_y) == R2.dx + assert TensorProduct(R2.dx, R2.dy)(R2.e_x, None) == R2.dy + assert TensorProduct(R2.dx, R2.dy)(R2.e_x) == R2.dy + assert TensorProduct(R2.x, R2.dx) == R2.x*R2.dx + assert TensorProduct( + R2.e_x, R2.e_y)(R2.x, R2.y) == R2.e_x(R2.x) * R2.e_y(R2.y) == 1 + assert TensorProduct(R2.e_x, R2.e_y)(None, R2.y) == R2.e_x + assert TensorProduct(R2.e_x, R2.e_y)(R2.x, None) == R2.e_y + assert TensorProduct(R2.e_x, R2.e_y)(R2.x) == R2.e_y + assert TensorProduct(R2.x, R2.e_x) == R2.x * R2.e_x + assert TensorProduct( + R2.dx, R2.e_y)(R2.e_x, R2.y) == R2.dx(R2.e_x) * R2.e_y(R2.y) == 1 + assert TensorProduct(R2.dx, R2.e_y)(None, R2.y) == R2.dx + assert TensorProduct(R2.dx, R2.e_y)(R2.e_x, None) == R2.e_y + assert TensorProduct(R2.dx, R2.e_y)(R2.e_x) == R2.e_y + assert TensorProduct(R2.x, R2.e_x) == R2.x * R2.e_x + assert TensorProduct( + R2.e_x, R2.dy)(R2.x, R2.e_y) == R2.e_x(R2.x) * R2.dy(R2.e_y) == 1 + assert TensorProduct(R2.e_x, R2.dy)(None, R2.e_y) == R2.e_x + assert TensorProduct(R2.e_x, R2.dy)(R2.x, None) == R2.dy + assert TensorProduct(R2.e_x, R2.dy)(R2.x) == R2.dy + assert TensorProduct(R2.e_y,R2.e_x)(R2.x**2 + R2.y**2,R2.x**2 + R2.y**2) == 4*R2.x*R2.y + + assert WedgeProduct(R2.dx, R2.dy)(R2.e_x, R2.e_y) == 1 + assert WedgeProduct(R2.e_x, R2.e_y)(R2.x, R2.y) == 1 + + +def test_lie_derivative(): + assert LieDerivative(R2.e_x, R2.y) == R2.e_x(R2.y) == 0 + assert LieDerivative(R2.e_x, R2.x) == R2.e_x(R2.x) == 1 + assert LieDerivative(R2.e_x, R2.e_x) == Commutator(R2.e_x, R2.e_x) == 0 + assert LieDerivative(R2.e_x, R2.e_r) == Commutator(R2.e_x, R2.e_r) + assert LieDerivative(R2.e_x + R2.e_y, R2.x) == 1 + assert LieDerivative( + R2.e_x, TensorProduct(R2.dx, R2.dy))(R2.e_x, R2.e_y) == 0 + + +@nocache_fail +def test_covar_deriv(): + ch = metric_to_Christoffel_2nd(TP(R2.dx, R2.dx) + TP(R2.dy, R2.dy)) + cvd = BaseCovarDerivativeOp(R2_r, 0, ch) + assert cvd(R2.x) == 1 + # This line fails if the cache is disabled: + assert cvd(R2.x*R2.e_x) == R2.e_x + cvd = CovarDerivativeOp(R2.x*R2.e_x, ch) + assert cvd(R2.x) == R2.x + assert cvd(R2.x*R2.e_x) == R2.x*R2.e_x + + +def test_intcurve_diffequ(): + t = symbols('t') + start_point = R2_r.point([1, 0]) + vector_field = -R2.y*R2.e_x + R2.x*R2.e_y + equations, init_cond = intcurve_diffequ(vector_field, t, start_point) + assert str(equations) == '[f_1(t) + Derivative(f_0(t), t), -f_0(t) + Derivative(f_1(t), t)]' + assert str(init_cond) == '[f_0(0) - 1, f_1(0)]' + equations, init_cond = intcurve_diffequ(vector_field, t, start_point, R2_p) + assert str( + equations) == '[Derivative(f_0(t), t), Derivative(f_1(t), t) - 1]' + assert str(init_cond) == '[f_0(0) - 1, f_1(0)]' + + +def test_helpers_and_coordinate_dependent(): + one_form = R2.dr + R2.dx + two_form = Differential(R2.x*R2.dr + R2.r*R2.dx) + three_form = Differential( + R2.y*two_form) + Differential(R2.x*Differential(R2.r*R2.dr)) + metric = TensorProduct(R2.dx, R2.dx) + TensorProduct(R2.dy, R2.dy) + metric_ambig = TensorProduct(R2.dx, R2.dx) + TensorProduct(R2.dr, R2.dr) + misform_a = TensorProduct(R2.dr, R2.dr) + R2.dr + misform_b = R2.dr**4 + misform_c = R2.dx*R2.dy + twoform_not_sym = TensorProduct(R2.dx, R2.dx) + TensorProduct(R2.dx, R2.dy) + twoform_not_TP = WedgeProduct(R2.dx, R2.dy) + + one_vector = R2.e_x + R2.e_y + two_vector = TensorProduct(R2.e_x, R2.e_y) + three_vector = TensorProduct(R2.e_x, R2.e_y, R2.e_x) + two_wp = WedgeProduct(R2.e_x,R2.e_y) + + assert covariant_order(one_form) == 1 + assert covariant_order(two_form) == 2 + assert covariant_order(three_form) == 3 + assert covariant_order(two_form + metric) == 2 + assert covariant_order(two_form + metric_ambig) == 2 + assert covariant_order(two_form + twoform_not_sym) == 2 + assert covariant_order(two_form + twoform_not_TP) == 2 + + assert contravariant_order(one_vector) == 1 + assert contravariant_order(two_vector) == 2 + assert contravariant_order(three_vector) == 3 + assert contravariant_order(two_vector + two_wp) == 2 + + raises(ValueError, lambda: covariant_order(misform_a)) + raises(ValueError, lambda: covariant_order(misform_b)) + raises(ValueError, lambda: covariant_order(misform_c)) + + assert twoform_to_matrix(metric) == Matrix([[1, 0], [0, 1]]) + assert twoform_to_matrix(twoform_not_sym) == Matrix([[1, 0], [1, 0]]) + assert twoform_to_matrix(twoform_not_TP) == Matrix([[0, -1], [1, 0]]) + + raises(ValueError, lambda: twoform_to_matrix(one_form)) + raises(ValueError, lambda: twoform_to_matrix(three_form)) + raises(ValueError, lambda: twoform_to_matrix(metric_ambig)) + + raises(ValueError, lambda: metric_to_Christoffel_1st(twoform_not_sym)) + raises(ValueError, lambda: metric_to_Christoffel_2nd(twoform_not_sym)) + raises(ValueError, lambda: metric_to_Riemann_components(twoform_not_sym)) + raises(ValueError, lambda: metric_to_Ricci_components(twoform_not_sym)) + + +def test_correct_arguments(): + raises(ValueError, lambda: R2.e_x(R2.e_x)) + raises(ValueError, lambda: R2.e_x(R2.dx)) + + raises(ValueError, lambda: Commutator(R2.e_x, R2.x)) + raises(ValueError, lambda: Commutator(R2.dx, R2.e_x)) + + raises(ValueError, lambda: Differential(Differential(R2.e_x))) + + raises(ValueError, lambda: R2.dx(R2.x)) + + raises(ValueError, lambda: LieDerivative(R2.dx, R2.dx)) + raises(ValueError, lambda: LieDerivative(R2.x, R2.dx)) + + raises(ValueError, lambda: CovarDerivativeOp(R2.dx, [])) + raises(ValueError, lambda: CovarDerivativeOp(R2.x, [])) + + a = Symbol('a') + raises(ValueError, lambda: intcurve_series(R2.dx, a, R2_r.point([1, 2]))) + raises(ValueError, lambda: intcurve_series(R2.x, a, R2_r.point([1, 2]))) + + raises(ValueError, lambda: intcurve_diffequ(R2.dx, a, R2_r.point([1, 2]))) + raises(ValueError, lambda: intcurve_diffequ(R2.x, a, R2_r.point([1, 2]))) + + raises(ValueError, lambda: contravariant_order(R2.e_x + R2.dx)) + raises(ValueError, lambda: covariant_order(R2.e_x + R2.dx)) + + raises(ValueError, lambda: contravariant_order(R2.e_x*R2.e_y)) + raises(ValueError, lambda: covariant_order(R2.dx*R2.dy)) + +def test_simplify(): + x, y = R2_r.coord_functions() + dx, dy = R2_r.base_oneforms() + ex, ey = R2_r.base_vectors() + assert simplify(x) == x + assert simplify(x*y) == x*y + assert simplify(dx*dy) == dx*dy + assert simplify(ex*ey) == ex*ey + assert ((1-x)*dx)/(1-x)**2 == dx/(1-x) + + +def test_issue_17917(): + X = R2.x*R2.e_x - R2.y*R2.e_y + Y = (R2.x**2 + R2.y**2)*R2.e_x - R2.x*R2.y*R2.e_y + assert LieDerivative(X, Y).expand() == ( + R2.x**2*R2.e_x - 3*R2.y**2*R2.e_x - R2.x*R2.y*R2.e_y) + +def test_deprecations(): + m = Manifold('M', 2) + p = Patch('P', m) + with warns_deprecated_sympy(): + CoordSystem('Car2d', p, names=['x', 'y']) + + with warns_deprecated_sympy(): + c = CoordSystem('Car2d', p, ['x', 'y']) + + with warns_deprecated_sympy(): + list(m.patches) + + with warns_deprecated_sympy(): + list(c.transforms) diff --git a/.venv/lib/python3.13/site-packages/sympy/diffgeom/tests/test_function_diffgeom_book.py b/.venv/lib/python3.13/site-packages/sympy/diffgeom/tests/test_function_diffgeom_book.py new file mode 100644 index 0000000000000000000000000000000000000000..44d9623bc34ab73c7d575d9d9fd5b6d84f8e4a94 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/diffgeom/tests/test_function_diffgeom_book.py @@ -0,0 +1,145 @@ +from sympy.diffgeom.rn import R2, R2_p, R2_r, R3_r +from sympy.diffgeom import intcurve_series, Differential, WedgeProduct +from sympy.core import symbols, Function, Derivative +from sympy.simplify import trigsimp, simplify +from sympy.functions import sqrt, atan2, sin, cos +from sympy.matrices import Matrix + +# Most of the functionality is covered in the +# test_functional_diffgeom_ch* tests which are based on the +# example from the paper of Sussman and Wisdom. +# If they do not cover something, additional tests are added in other test +# functions. + +# From "Functional Differential Geometry" as of 2011 +# by Sussman and Wisdom. + + +def test_functional_diffgeom_ch2(): + x0, y0, r0, theta0 = symbols('x0, y0, r0, theta0', real=True) + x, y = symbols('x, y', real=True) + f = Function('f') + + assert (R2_p.point_to_coords(R2_r.point([x0, y0])) == + Matrix([sqrt(x0**2 + y0**2), atan2(y0, x0)])) + assert (R2_r.point_to_coords(R2_p.point([r0, theta0])) == + Matrix([r0*cos(theta0), r0*sin(theta0)])) + + assert R2_p.jacobian(R2_r, [r0, theta0]) == Matrix( + [[cos(theta0), -r0*sin(theta0)], [sin(theta0), r0*cos(theta0)]]) + + field = f(R2.x, R2.y) + p1_in_rect = R2_r.point([x0, y0]) + p1_in_polar = R2_p.point([sqrt(x0**2 + y0**2), atan2(y0, x0)]) + assert field.rcall(p1_in_rect) == f(x0, y0) + assert field.rcall(p1_in_polar) == f(x0, y0) + + p_r = R2_r.point([x0, y0]) + p_p = R2_p.point([r0, theta0]) + assert R2.x(p_r) == x0 + assert R2.x(p_p) == r0*cos(theta0) + assert R2.r(p_p) == r0 + assert R2.r(p_r) == sqrt(x0**2 + y0**2) + assert R2.theta(p_r) == atan2(y0, x0) + + h = R2.x*R2.r**2 + R2.y**3 + assert h.rcall(p_r) == x0*(x0**2 + y0**2) + y0**3 + assert h.rcall(p_p) == r0**3*sin(theta0)**3 + r0**3*cos(theta0) + + +def test_functional_diffgeom_ch3(): + x0, y0 = symbols('x0, y0', real=True) + x, y, t = symbols('x, y, t', real=True) + f = Function('f') + b1 = Function('b1') + b2 = Function('b2') + p_r = R2_r.point([x0, y0]) + + s_field = f(R2.x, R2.y) + v_field = b1(R2.x)*R2.e_x + b2(R2.y)*R2.e_y + assert v_field.rcall(s_field).rcall(p_r).doit() == b1( + x0)*Derivative(f(x0, y0), x0) + b2(y0)*Derivative(f(x0, y0), y0) + + assert R2.e_x(R2.r**2).rcall(p_r) == 2*x0 + v = R2.e_x + 2*R2.e_y + s = R2.r**2 + 3*R2.x + assert v.rcall(s).rcall(p_r).doit() == 2*x0 + 4*y0 + 3 + + circ = -R2.y*R2.e_x + R2.x*R2.e_y + series = intcurve_series(circ, t, R2_r.point([1, 0]), coeffs=True) + series_x, series_y = zip(*series) + assert all( + term == cos(t).taylor_term(i, t) for i, term in enumerate(series_x)) + assert all( + term == sin(t).taylor_term(i, t) for i, term in enumerate(series_y)) + + +def test_functional_diffgeom_ch4(): + x0, y0, theta0 = symbols('x0, y0, theta0', real=True) + x, y, r, theta = symbols('x, y, r, theta', real=True) + r0 = symbols('r0', positive=True) + f = Function('f') + b1 = Function('b1') + b2 = Function('b2') + p_r = R2_r.point([x0, y0]) + p_p = R2_p.point([r0, theta0]) + + f_field = b1(R2.x, R2.y)*R2.dx + b2(R2.x, R2.y)*R2.dy + assert f_field.rcall(R2.e_x).rcall(p_r) == b1(x0, y0) + assert f_field.rcall(R2.e_y).rcall(p_r) == b2(x0, y0) + + s_field_r = f(R2.x, R2.y) + df = Differential(s_field_r) + assert df(R2.e_x).rcall(p_r).doit() == Derivative(f(x0, y0), x0) + assert df(R2.e_y).rcall(p_r).doit() == Derivative(f(x0, y0), y0) + + s_field_p = f(R2.r, R2.theta) + df = Differential(s_field_p) + assert trigsimp(df(R2.e_x).rcall(p_p).doit()) == ( + cos(theta0)*Derivative(f(r0, theta0), r0) - + sin(theta0)*Derivative(f(r0, theta0), theta0)/r0) + assert trigsimp(df(R2.e_y).rcall(p_p).doit()) == ( + sin(theta0)*Derivative(f(r0, theta0), r0) + + cos(theta0)*Derivative(f(r0, theta0), theta0)/r0) + + assert R2.dx(R2.e_x).rcall(p_r) == 1 + assert R2.dx(R2.e_x) == 1 + assert R2.dx(R2.e_y).rcall(p_r) == 0 + assert R2.dx(R2.e_y) == 0 + + circ = -R2.y*R2.e_x + R2.x*R2.e_y + assert R2.dx(circ).rcall(p_r).doit() == -y0 + assert R2.dy(circ).rcall(p_r) == x0 + assert R2.dr(circ).rcall(p_r) == 0 + assert simplify(R2.dtheta(circ).rcall(p_r)) == 1 + + assert (circ - R2.e_theta).rcall(s_field_r).rcall(p_r) == 0 + + +def test_functional_diffgeom_ch6(): + u0, u1, u2, v0, v1, v2, w0, w1, w2 = symbols('u0:3, v0:3, w0:3', real=True) + + u = u0*R2.e_x + u1*R2.e_y + v = v0*R2.e_x + v1*R2.e_y + wp = WedgeProduct(R2.dx, R2.dy) + assert wp(u, v) == u0*v1 - u1*v0 + + u = u0*R3_r.e_x + u1*R3_r.e_y + u2*R3_r.e_z + v = v0*R3_r.e_x + v1*R3_r.e_y + v2*R3_r.e_z + w = w0*R3_r.e_x + w1*R3_r.e_y + w2*R3_r.e_z + wp = WedgeProduct(R3_r.dx, R3_r.dy, R3_r.dz) + assert wp( + u, v, w) == Matrix(3, 3, [u0, u1, u2, v0, v1, v2, w0, w1, w2]).det() + + a, b, c = symbols('a, b, c', cls=Function) + a_f = a(R3_r.x, R3_r.y, R3_r.z) + b_f = b(R3_r.x, R3_r.y, R3_r.z) + c_f = c(R3_r.x, R3_r.y, R3_r.z) + theta = a_f*R3_r.dx + b_f*R3_r.dy + c_f*R3_r.dz + dtheta = Differential(theta) + da = Differential(a_f) + db = Differential(b_f) + dc = Differential(c_f) + expr = dtheta - WedgeProduct( + da, R3_r.dx) - WedgeProduct(db, R3_r.dy) - WedgeProduct(dc, R3_r.dz) + assert expr.rcall(R3_r.e_x, R3_r.e_y) == 0 diff --git a/.venv/lib/python3.13/site-packages/sympy/diffgeom/tests/test_hyperbolic_space.py b/.venv/lib/python3.13/site-packages/sympy/diffgeom/tests/test_hyperbolic_space.py new file mode 100644 index 0000000000000000000000000000000000000000..48ddc7f8065f2b69bcd8eca4726a21c5901514ec --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/diffgeom/tests/test_hyperbolic_space.py @@ -0,0 +1,91 @@ +r''' +unit test describing the hyperbolic half-plane with the Poincare metric. This +is a basic model of hyperbolic geometry on the (positive) half-space + +{(x,y) \in R^2 | y > 0} + +with the Riemannian metric + +ds^2 = (dx^2 + dy^2)/y^2 + +It has constant negative scalar curvature = -2 + +https://en.wikipedia.org/wiki/Poincare_half-plane_model +''' +from sympy.matrices.dense import diag +from sympy.diffgeom import (twoform_to_matrix, + metric_to_Christoffel_1st, metric_to_Christoffel_2nd, + metric_to_Riemann_components, metric_to_Ricci_components) +import sympy.diffgeom.rn +from sympy.tensor.array import ImmutableDenseNDimArray + + +def test_H2(): + TP = sympy.diffgeom.TensorProduct + R2 = sympy.diffgeom.rn.R2 + y = R2.y + dy = R2.dy + dx = R2.dx + g = (TP(dx, dx) + TP(dy, dy))*y**(-2) + automat = twoform_to_matrix(g) + mat = diag(y**(-2), y**(-2)) + assert mat == automat + + gamma1 = metric_to_Christoffel_1st(g) + assert gamma1[0, 0, 0] == 0 + assert gamma1[0, 0, 1] == -y**(-3) + assert gamma1[0, 1, 0] == -y**(-3) + assert gamma1[0, 1, 1] == 0 + + assert gamma1[1, 1, 1] == -y**(-3) + assert gamma1[1, 1, 0] == 0 + assert gamma1[1, 0, 1] == 0 + assert gamma1[1, 0, 0] == y**(-3) + + gamma2 = metric_to_Christoffel_2nd(g) + assert gamma2[0, 0, 0] == 0 + assert gamma2[0, 0, 1] == -y**(-1) + assert gamma2[0, 1, 0] == -y**(-1) + assert gamma2[0, 1, 1] == 0 + + assert gamma2[1, 1, 1] == -y**(-1) + assert gamma2[1, 1, 0] == 0 + assert gamma2[1, 0, 1] == 0 + assert gamma2[1, 0, 0] == y**(-1) + + Rm = metric_to_Riemann_components(g) + assert Rm[0, 0, 0, 0] == 0 + assert Rm[0, 0, 0, 1] == 0 + assert Rm[0, 0, 1, 0] == 0 + assert Rm[0, 0, 1, 1] == 0 + + assert Rm[0, 1, 0, 0] == 0 + assert Rm[0, 1, 0, 1] == -y**(-2) + assert Rm[0, 1, 1, 0] == y**(-2) + assert Rm[0, 1, 1, 1] == 0 + + assert Rm[1, 0, 0, 0] == 0 + assert Rm[1, 0, 0, 1] == y**(-2) + assert Rm[1, 0, 1, 0] == -y**(-2) + assert Rm[1, 0, 1, 1] == 0 + + assert Rm[1, 1, 0, 0] == 0 + assert Rm[1, 1, 0, 1] == 0 + assert Rm[1, 1, 1, 0] == 0 + assert Rm[1, 1, 1, 1] == 0 + + Ric = metric_to_Ricci_components(g) + assert Ric[0, 0] == -y**(-2) + assert Ric[0, 1] == 0 + assert Ric[1, 0] == 0 + assert Ric[0, 0] == -y**(-2) + + assert Ric == ImmutableDenseNDimArray([-y**(-2), 0, 0, -y**(-2)], (2, 2)) + + ## scalar curvature is -2 + #TODO - it would be nice to have index contraction built-in + R = (Ric[0, 0] + Ric[1, 1])*y**2 + assert R == -2 + + ## Gauss curvature is -1 + assert R/2 == -1 diff --git a/.venv/lib/python3.13/site-packages/sympy/external/__init__.py b/.venv/lib/python3.13/site-packages/sympy/external/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..549b4b96cdce0ee4d31960e89cb9dc26af0e105d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/external/__init__.py @@ -0,0 +1,20 @@ +""" +Unified place for determining if external dependencies are installed or not. + +You should import all external modules using the import_module() function. + +For example + +>>> from sympy.external import import_module +>>> numpy = import_module('numpy') + +If the resulting library is not installed, or if the installed version +is less than a given minimum version, the function will return None. +Otherwise, it will return the library. See the docstring of +import_module() for more information. + +""" + +from sympy.external.importtools import import_module + +__all__ = ['import_module'] diff --git a/.venv/lib/python3.13/site-packages/sympy/external/gmpy.py b/.venv/lib/python3.13/site-packages/sympy/external/gmpy.py new file mode 100644 index 0000000000000000000000000000000000000000..d26942864bf4786e72198d3640d488857b3313f4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/external/gmpy.py @@ -0,0 +1,342 @@ +from __future__ import annotations +import os +from ctypes import c_long, sizeof +from functools import reduce +from typing import Type +from warnings import warn + +from sympy.external import import_module + +from .pythonmpq import PythonMPQ + +from .ntheory import ( + bit_scan1 as python_bit_scan1, + bit_scan0 as python_bit_scan0, + remove as python_remove, + factorial as python_factorial, + sqrt as python_sqrt, + sqrtrem as python_sqrtrem, + gcd as python_gcd, + lcm as python_lcm, + gcdext as python_gcdext, + is_square as python_is_square, + invert as python_invert, + legendre as python_legendre, + jacobi as python_jacobi, + kronecker as python_kronecker, + iroot as python_iroot, + is_fermat_prp as python_is_fermat_prp, + is_euler_prp as python_is_euler_prp, + is_strong_prp as python_is_strong_prp, + is_fibonacci_prp as python_is_fibonacci_prp, + is_lucas_prp as python_is_lucas_prp, + is_selfridge_prp as python_is_selfridge_prp, + is_strong_lucas_prp as python_is_strong_lucas_prp, + is_strong_selfridge_prp as python_is_strong_selfridge_prp, + is_bpsw_prp as python_is_bpsw_prp, + is_strong_bpsw_prp as python_is_strong_bpsw_prp, +) + + +__all__ = [ + # GROUND_TYPES is either 'gmpy' or 'python' depending on which is used. If + # gmpy is installed then it will be used unless the environment variable + # SYMPY_GROUND_TYPES is set to something other than 'auto', 'gmpy', or + # 'gmpy2'. + 'GROUND_TYPES', + + # If HAS_GMPY is 0, no supported version of gmpy is available. Otherwise, + # HAS_GMPY will be 2 for gmpy2 if GROUND_TYPES is 'gmpy'. It used to be + # possible for HAS_GMPY to be 1 for gmpy but gmpy is no longer supported. + 'HAS_GMPY', + + # SYMPY_INTS is a tuple containing the base types for valid integer types. + # This is either (int,) or (int, type(mpz(0))) depending on GROUND_TYPES. + 'SYMPY_INTS', + + # MPQ is either gmpy.mpq or the Python equivalent from + # sympy.external.pythonmpq + 'MPQ', + + # MPZ is either gmpy.mpz or int. + 'MPZ', + + 'bit_scan1', + 'bit_scan0', + 'remove', + 'factorial', + 'sqrt', + 'is_square', + 'sqrtrem', + 'gcd', + 'lcm', + 'gcdext', + 'invert', + 'legendre', + 'jacobi', + 'kronecker', + 'iroot', + 'is_fermat_prp', + 'is_euler_prp', + 'is_strong_prp', + 'is_fibonacci_prp', + 'is_lucas_prp', + 'is_selfridge_prp', + 'is_strong_lucas_prp', + 'is_strong_selfridge_prp', + 'is_bpsw_prp', + 'is_strong_bpsw_prp', +] + + +# +# Tested python-flint version. Future versions might work but we will only use +# them if explicitly requested by SYMPY_GROUND_TYPES=flint. +# +_PYTHON_FLINT_VERSION_NEEDED = ["0.6", "0.7", "0.8", "0.9", "0.10"] + + +def _flint_version_okay(flint_version): + major, minor = flint_version.split('.')[:2] + flint_ver = f'{major}.{minor}' + return flint_ver in _PYTHON_FLINT_VERSION_NEEDED + +# +# We will only use gmpy2 >= 2.0.0 +# +_GMPY2_MIN_VERSION = '2.0.0' + + +def _get_flint(sympy_ground_types): + if sympy_ground_types not in ('auto', 'flint'): + return None + + try: + import flint + # Earlier versions of python-flint may not have __version__. + from flint import __version__ as _flint_version + except ImportError: + if sympy_ground_types == 'flint': + warn("SYMPY_GROUND_TYPES was set to flint but python-flint is not " + "installed. Falling back to other ground types.") + return None + + if _flint_version_okay(_flint_version): + return flint + elif sympy_ground_types == 'auto': + return None + else: + warn(f"Using python-flint {_flint_version} because SYMPY_GROUND_TYPES " + f"is set to flint but this version of SymPy is only tested " + f"with python-flint versions {_PYTHON_FLINT_VERSION_NEEDED}.") + return flint + + +def _get_gmpy2(sympy_ground_types): + if sympy_ground_types not in ('auto', 'gmpy', 'gmpy2'): + return None + + gmpy = import_module('gmpy2', min_module_version=_GMPY2_MIN_VERSION, + module_version_attr='version', module_version_attr_call_args=()) + + if sympy_ground_types != 'auto' and gmpy is None: + warn("gmpy2 library is not installed, switching to 'python' ground types") + + return gmpy + + +# +# SYMPY_GROUND_TYPES can be flint, gmpy, gmpy2, python or auto (default) +# +_SYMPY_GROUND_TYPES = os.environ.get('SYMPY_GROUND_TYPES', 'auto').lower() +_flint = None +_gmpy = None + +# +# First handle auto-detection of flint/gmpy2. We will prefer flint if available +# or otherwise gmpy2 if available and then lastly the python types. +# +if _SYMPY_GROUND_TYPES in ('auto', 'flint'): + _flint = _get_flint(_SYMPY_GROUND_TYPES) + if _flint is not None: + _SYMPY_GROUND_TYPES = 'flint' + else: + _SYMPY_GROUND_TYPES = 'auto' + +if _SYMPY_GROUND_TYPES in ('auto', 'gmpy', 'gmpy2'): + _gmpy = _get_gmpy2(_SYMPY_GROUND_TYPES) + if _gmpy is not None: + _SYMPY_GROUND_TYPES = 'gmpy' + else: + _SYMPY_GROUND_TYPES = 'python' + +if _SYMPY_GROUND_TYPES not in ('flint', 'gmpy', 'python'): + warn("SYMPY_GROUND_TYPES environment variable unrecognised. " + "Should be 'auto', 'flint', 'gmpy', 'gmpy2' or 'python'.") + _SYMPY_GROUND_TYPES = 'python' + +# +# At this point _SYMPY_GROUND_TYPES is either flint, gmpy or python. The blocks +# below define the values exported by this module in each case. +# + +# +# In gmpy2 and flint, there are functions that take a long (or unsigned long) +# argument. That is, it is not possible to input a value larger than that. +# +LONG_MAX = (1 << (8*sizeof(c_long) - 1)) - 1 + +# +# Type checkers are confused by what SYMPY_INTS is. There may be a better type +# hint for this like Type[Integral] or something. +# +SYMPY_INTS: tuple[Type, ...] + +if _SYMPY_GROUND_TYPES == 'gmpy': + + assert _gmpy is not None + + flint = None + gmpy = _gmpy + + HAS_GMPY = 2 + GROUND_TYPES = 'gmpy' + SYMPY_INTS = (int, type(gmpy.mpz(0))) + MPZ = gmpy.mpz + MPQ = gmpy.mpq + + bit_scan1 = gmpy.bit_scan1 + bit_scan0 = gmpy.bit_scan0 + remove = gmpy.remove + factorial = gmpy.fac + sqrt = gmpy.isqrt + is_square = gmpy.is_square + sqrtrem = gmpy.isqrt_rem + gcd = gmpy.gcd + lcm = gmpy.lcm + gcdext = gmpy.gcdext + invert = gmpy.invert + legendre = gmpy.legendre + jacobi = gmpy.jacobi + kronecker = gmpy.kronecker + + def iroot(x, n): + # In the latest gmpy2, the threshold for n is ULONG_MAX, + # but adjust to the older one. + if n <= LONG_MAX: + return gmpy.iroot(x, n) + return python_iroot(x, n) + + is_fermat_prp = gmpy.is_fermat_prp + is_euler_prp = gmpy.is_euler_prp + is_strong_prp = gmpy.is_strong_prp + is_fibonacci_prp = gmpy.is_fibonacci_prp + is_lucas_prp = gmpy.is_lucas_prp + is_selfridge_prp = gmpy.is_selfridge_prp + is_strong_lucas_prp = gmpy.is_strong_lucas_prp + is_strong_selfridge_prp = gmpy.is_strong_selfridge_prp + is_bpsw_prp = gmpy.is_bpsw_prp + is_strong_bpsw_prp = gmpy.is_strong_bpsw_prp + +elif _SYMPY_GROUND_TYPES == 'flint': + + assert _flint is not None + + flint = _flint + gmpy = None + + HAS_GMPY = 0 + GROUND_TYPES = 'flint' + SYMPY_INTS = (int, flint.fmpz) # type: ignore + MPZ = flint.fmpz # type: ignore + MPQ = flint.fmpq # type: ignore + + bit_scan1 = python_bit_scan1 + bit_scan0 = python_bit_scan0 + remove = python_remove + factorial = python_factorial + + def sqrt(x): + return flint.fmpz(x).isqrt() + + def is_square(x): + if x < 0: + return False + return flint.fmpz(x).sqrtrem()[1] == 0 + + def sqrtrem(x): + return flint.fmpz(x).sqrtrem() + + def gcd(*args): + return reduce(flint.fmpz.gcd, args, flint.fmpz(0)) + + def lcm(*args): + return reduce(flint.fmpz.lcm, args, flint.fmpz(1)) + + gcdext = python_gcdext + invert = python_invert + legendre = python_legendre + + def jacobi(x, y): + if y <= 0 or not y % 2: + raise ValueError("y should be an odd positive integer") + return flint.fmpz(x).jacobi(y) + + kronecker = python_kronecker + + def iroot(x, n): + if n <= LONG_MAX: + y = flint.fmpz(x).root(n) + return y, y**n == x + return python_iroot(x, n) + + is_fermat_prp = python_is_fermat_prp + is_euler_prp = python_is_euler_prp + is_strong_prp = python_is_strong_prp + is_fibonacci_prp = python_is_fibonacci_prp + is_lucas_prp = python_is_lucas_prp + is_selfridge_prp = python_is_selfridge_prp + is_strong_lucas_prp = python_is_strong_lucas_prp + is_strong_selfridge_prp = python_is_strong_selfridge_prp + is_bpsw_prp = python_is_bpsw_prp + is_strong_bpsw_prp = python_is_strong_bpsw_prp + +elif _SYMPY_GROUND_TYPES == 'python': + + flint = None + gmpy = None + + HAS_GMPY = 0 + GROUND_TYPES = 'python' + SYMPY_INTS = (int,) + MPZ = int + MPQ = PythonMPQ + + bit_scan1 = python_bit_scan1 + bit_scan0 = python_bit_scan0 + remove = python_remove + factorial = python_factorial + sqrt = python_sqrt + is_square = python_is_square + sqrtrem = python_sqrtrem + gcd = python_gcd + lcm = python_lcm + gcdext = python_gcdext + invert = python_invert + legendre = python_legendre + jacobi = python_jacobi + kronecker = python_kronecker + iroot = python_iroot + is_fermat_prp = python_is_fermat_prp + is_euler_prp = python_is_euler_prp + is_strong_prp = python_is_strong_prp + is_fibonacci_prp = python_is_fibonacci_prp + is_lucas_prp = python_is_lucas_prp + is_selfridge_prp = python_is_selfridge_prp + is_strong_lucas_prp = python_is_strong_lucas_prp + is_strong_selfridge_prp = python_is_strong_selfridge_prp + is_bpsw_prp = python_is_bpsw_prp + is_strong_bpsw_prp = python_is_strong_bpsw_prp + +else: + assert False diff --git a/.venv/lib/python3.13/site-packages/sympy/external/importtools.py b/.venv/lib/python3.13/site-packages/sympy/external/importtools.py new file mode 100644 index 0000000000000000000000000000000000000000..5008b3dd4634d3cee10744a0a92b1204051f07cc --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/external/importtools.py @@ -0,0 +1,187 @@ +"""Tools to assist importing optional external modules.""" + +import sys +import re + +# Override these in the module to change the default warning behavior. +# For example, you might set both to False before running the tests so that +# warnings are not printed to the console, or set both to True for debugging. + +WARN_NOT_INSTALLED = None # Default is False +WARN_OLD_VERSION = None # Default is True + + +def __sympy_debug(): + # helper function from sympy/__init__.py + # We don't just import SYMPY_DEBUG from that file because we don't want to + # import all of SymPy just to use this module. + import os + debug_str = os.getenv('SYMPY_DEBUG', 'False') + if debug_str in ('True', 'False'): + return eval(debug_str) + else: + raise RuntimeError("unrecognized value for SYMPY_DEBUG: %s" % + debug_str) + +if __sympy_debug(): + WARN_OLD_VERSION = True + WARN_NOT_INSTALLED = True + + +_component_re = re.compile(r'(\d+ | [a-z]+ | \.)', re.VERBOSE) + +def version_tuple(vstring): + # Parse a version string to a tuple e.g. '1.2' -> (1, 2) + # Simplified from distutils.version.LooseVersion which was deprecated in + # Python 3.10. + components = [] + for x in _component_re.split(vstring): + if x and x != '.': + try: + x = int(x) + except ValueError: + pass + components.append(x) + return tuple(components) + + +def import_module(module, min_module_version=None, min_python_version=None, + warn_not_installed=None, warn_old_version=None, + module_version_attr='__version__', module_version_attr_call_args=None, + import_kwargs={}, catch=()): + """ + Import and return a module if it is installed. + + If the module is not installed, it returns None. + + A minimum version for the module can be given as the keyword argument + min_module_version. This should be comparable against the module version. + By default, module.__version__ is used to get the module version. To + override this, set the module_version_attr keyword argument. If the + attribute of the module to get the version should be called (e.g., + module.version()), then set module_version_attr_call_args to the args such + that module.module_version_attr(*module_version_attr_call_args) returns the + module's version. + + If the module version is less than min_module_version using the Python < + comparison, None will be returned, even if the module is installed. You can + use this to keep from importing an incompatible older version of a module. + + You can also specify a minimum Python version by using the + min_python_version keyword argument. This should be comparable against + sys.version_info. + + If the keyword argument warn_not_installed is set to True, the function will + emit a UserWarning when the module is not installed. + + If the keyword argument warn_old_version is set to True, the function will + emit a UserWarning when the library is installed, but cannot be imported + because of the min_module_version or min_python_version options. + + Note that because of the way warnings are handled, a warning will be + emitted for each module only once. You can change the default warning + behavior by overriding the values of WARN_NOT_INSTALLED and WARN_OLD_VERSION + in sympy.external.importtools. By default, WARN_NOT_INSTALLED is False and + WARN_OLD_VERSION is True. + + This function uses __import__() to import the module. To pass additional + options to __import__(), use the import_kwargs keyword argument. For + example, to import a submodule A.B, you must pass a nonempty fromlist option + to __import__. See the docstring of __import__(). + + This catches ImportError to determine if the module is not installed. To + catch additional errors, pass them as a tuple to the catch keyword + argument. + + Examples + ======== + + >>> from sympy.external import import_module + + >>> numpy = import_module('numpy') + + >>> numpy = import_module('numpy', min_python_version=(2, 7), + ... warn_old_version=False) + + >>> numpy = import_module('numpy', min_module_version='1.5', + ... warn_old_version=False) # numpy.__version__ is a string + + >>> # gmpy does not have __version__, but it does have gmpy.version() + + >>> gmpy = import_module('gmpy', min_module_version='1.14', + ... module_version_attr='version', module_version_attr_call_args=(), + ... warn_old_version=False) + + >>> # To import a submodule, you must pass a nonempty fromlist to + >>> # __import__(). The values do not matter. + >>> p3 = import_module('mpl_toolkits.mplot3d', + ... import_kwargs={'fromlist':['something']}) + + >>> # matplotlib.pyplot can raise RuntimeError when the display cannot be opened + >>> matplotlib = import_module('matplotlib', + ... import_kwargs={'fromlist':['pyplot']}, catch=(RuntimeError,)) + + """ + # keyword argument overrides default, and global variable overrides + # keyword argument. + warn_old_version = (WARN_OLD_VERSION if WARN_OLD_VERSION is not None + else warn_old_version or True) + warn_not_installed = (WARN_NOT_INSTALLED if WARN_NOT_INSTALLED is not None + else warn_not_installed or False) + + import warnings + + # Check Python first so we don't waste time importing a module we can't use + if min_python_version: + if sys.version_info < min_python_version: + if warn_old_version: + warnings.warn("Python version is too old to use %s " + "(%s or newer required)" % ( + module, '.'.join(map(str, min_python_version))), + UserWarning, stacklevel=2) + return + + try: + mod = __import__(module, **import_kwargs) + + ## there's something funny about imports with matplotlib and py3k. doing + ## from matplotlib import collections + ## gives python's stdlib collections module. explicitly re-importing + ## the module fixes this. + from_list = import_kwargs.get('fromlist', ()) + for submod in from_list: + if submod == 'collections' and mod.__name__ == 'matplotlib': + __import__(module + '.' + submod) + except ImportError: + if warn_not_installed: + warnings.warn("%s module is not installed" % module, UserWarning, + stacklevel=2) + return + except catch as e: + if warn_not_installed: + warnings.warn( + "%s module could not be used (%s)" % (module, repr(e)), + stacklevel=2) + return + + if min_module_version: + modversion = getattr(mod, module_version_attr) + if module_version_attr_call_args is not None: + modversion = modversion(*module_version_attr_call_args) + if version_tuple(modversion) < version_tuple(min_module_version): + if warn_old_version: + # Attempt to create a pretty string version of the version + if isinstance(min_module_version, str): + verstr = min_module_version + elif isinstance(min_module_version, (tuple, list)): + verstr = '.'.join(map(str, min_module_version)) + else: + # Either don't know what this is. Hopefully + # it's something that has a nice str version, like an int. + verstr = str(min_module_version) + warnings.warn("%s version is too old to use " + "(%s or newer required)" % (module, verstr), + UserWarning, stacklevel=2) + return + + return mod diff --git a/.venv/lib/python3.13/site-packages/sympy/external/ntheory.py b/.venv/lib/python3.13/site-packages/sympy/external/ntheory.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c9bf813cf02b311f9a12ee7fbc4932ed551f3b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/external/ntheory.py @@ -0,0 +1,618 @@ +# sympy.external.ntheory +# +# This module provides pure Python implementations of some number theory +# functions that are alternately used from gmpy2 if it is installed. + +import math + +import mpmath.libmp as mlib + + +_small_trailing = [0] * 256 +for j in range(1, 8): + _small_trailing[1 << j :: 1 << (j + 1)] = [j] * (1 << (7 - j)) + + +def bit_scan1(x, n=0): + if not x: + return + x = abs(x >> n) + low_byte = x & 0xFF + if low_byte: + return _small_trailing[low_byte] + n + + t = 8 + n + x >>= 8 + # 2**m is quick for z up through 2**30 + z = x.bit_length() - 1 + if x == 1 << z: + return z + t + + if z < 300: + # fixed 8-byte reduction + while not x & 0xFF: + x >>= 8 + t += 8 + else: + # binary reduction important when there might be a large + # number of trailing 0s + p = z >> 1 + while not x & 0xFF: + while x & ((1 << p) - 1): + p >>= 1 + x >>= p + t += p + return t + _small_trailing[x & 0xFF] + + +def bit_scan0(x, n=0): + return bit_scan1(x + (1 << n), n) + + +def remove(x, f): + if f < 2: + raise ValueError("factor must be > 1") + if x == 0: + return 0, 0 + if f == 2: + b = bit_scan1(x) + return x >> b, b + m = 0 + y, rem = divmod(x, f) + while not rem: + x = y + m += 1 + if m > 5: + pow_list = [f**2] + while pow_list: + _f = pow_list[-1] + y, rem = divmod(x, _f) + if not rem: + m += 1 << len(pow_list) + x = y + pow_list.append(_f**2) + else: + pow_list.pop() + y, rem = divmod(x, f) + return x, m + + +def factorial(x): + """Return x!.""" + return int(mlib.ifac(int(x))) + + +def sqrt(x): + """Integer square root of x.""" + return int(mlib.isqrt(int(x))) + + +def sqrtrem(x): + """Integer square root of x and remainder.""" + s, r = mlib.sqrtrem(int(x)) + return (int(s), int(r)) + + +gcd = math.gcd +lcm = math.lcm + + +def _sign(n): + if n < 0: + return -1, -n + return 1, n + + +def gcdext(a, b): + if not a or not b: + g = abs(a) or abs(b) + if not g: + return (0, 0, 0) + return (g, a // g, b // g) + + x_sign, a = _sign(a) + y_sign, b = _sign(b) + x, r = 1, 0 + y, s = 0, 1 + + while b: + q, c = divmod(a, b) + a, b = b, c + x, r = r, x - q*r + y, s = s, y - q*s + + return (a, x * x_sign, y * y_sign) + + +def is_square(x): + """Return True if x is a square number.""" + if x < 0: + return False + + # Note that the possible values of y**2 % n for a given n are limited. + # For example, when n=4, y**2 % n can only take 0 or 1. + # In other words, if x % 4 is 2 or 3, then x is not a square number. + # Mathematically, it determines if it belongs to the set {y**2 % n}, + # but implementationally, it can be realized as a logical conjunction + # with an n-bit integer. + # see https://mersenneforum.org/showpost.php?p=110896 + # def magic(n): + # s = {y**2 % n for y in range(n)} + # s = set(range(n)) - s + # return sum(1 << bit for bit in s) + # >>> print(hex(magic(128))) + # 0xfdfdfdedfdfdfdecfdfdfdedfdfcfdec + # >>> print(hex(magic(99))) + # 0x5f6f9ffb6fb7ddfcb75befdec + # >>> print(hex(magic(91))) + # 0x6fd1bfcfed5f3679d3ebdec + # >>> print(hex(magic(85))) + # 0xdef9ae771ffe3b9d67dec + if 0xfdfdfdedfdfdfdecfdfdfdedfdfcfdec & (1 << (x & 127)): + return False # e.g. 2, 3 + m = x % 765765 # 765765 = 99 * 91 * 85 + if 0x5f6f9ffb6fb7ddfcb75befdec & (1 << (m % 99)): + return False # e.g. 17, 68 + if 0x6fd1bfcfed5f3679d3ebdec & (1 << (m % 91)): + return False # e.g. 97, 388 + if 0xdef9ae771ffe3b9d67dec & (1 << (m % 85)): + return False # e.g. 793, 1408 + return mlib.sqrtrem(int(x))[1] == 0 + + +def invert(x, m): + """Modular inverse of x modulo m. + + Returns y such that x*y == 1 mod m. + + Uses ``math.pow`` but reproduces the behaviour of ``gmpy2.invert`` + which raises ZeroDivisionError if no inverse exists. + """ + try: + return pow(x, -1, m) + except ValueError: + raise ZeroDivisionError("invert() no inverse exists") + + +def legendre(x, y): + """Legendre symbol (x / y). + + Following the implementation of gmpy2, + the error is raised only when y is an even number. + """ + if y <= 0 or not y % 2: + raise ValueError("y should be an odd prime") + x %= y + if not x: + return 0 + if pow(x, (y - 1) // 2, y) == 1: + return 1 + return -1 + + +def jacobi(x, y): + """Jacobi symbol (x / y).""" + if y <= 0 or not y % 2: + raise ValueError("y should be an odd positive integer") + x %= y + if not x: + return int(y == 1) + if y == 1 or x == 1: + return 1 + if gcd(x, y) != 1: + return 0 + j = 1 + while x != 0: + while x % 2 == 0 and x > 0: + x >>= 1 + if y % 8 in [3, 5]: + j = -j + x, y = y, x + if x % 4 == y % 4 == 3: + j = -j + x %= y + return j + + +def kronecker(x, y): + """Kronecker symbol (x / y).""" + if gcd(x, y) != 1: + return 0 + if y == 0: + return 1 + sign = -1 if y < 0 and x < 0 else 1 + y = abs(y) + s = bit_scan1(y) + y >>= s + if s % 2 and x % 8 in [3, 5]: + sign = -sign + return sign * jacobi(x, y) + + +def iroot(y, n): + if y < 0: + raise ValueError("y must be nonnegative") + if n < 1: + raise ValueError("n must be positive") + if y in (0, 1): + return y, True + if n == 1: + return y, True + if n == 2: + x, rem = mlib.sqrtrem(y) + return int(x), not rem + if n >= y.bit_length(): + return 1, False + # Get initial estimate for Newton's method. Care must be taken to + # avoid overflow + try: + guess = int(y**(1./n) + 0.5) + except OverflowError: + exp = math.log2(y)/n + if exp > 53: + shift = int(exp - 53) + guess = int(2.0**(exp - shift) + 1) << shift + else: + guess = int(2.0**exp) + if guess > 2**50: + # Newton iteration + xprev, x = -1, guess + while 1: + t = x**(n - 1) + xprev, x = x, ((n - 1)*x + y//t)//n + if abs(x - xprev) < 2: + break + else: + x = guess + # Compensate + t = x**n + while t < y: + x += 1 + t = x**n + while t > y: + x -= 1 + t = x**n + return x, t == y + + +def is_fermat_prp(n, a): + if a < 2: + raise ValueError("is_fermat_prp() requires 'a' greater than or equal to 2") + if n < 1: + raise ValueError("is_fermat_prp() requires 'n' be greater than 0") + if n == 1: + return False + if n % 2 == 0: + return n == 2 + a %= n + if gcd(n, a) != 1: + raise ValueError("is_fermat_prp() requires gcd(n,a) == 1") + return pow(a, n - 1, n) == 1 + + +def is_euler_prp(n, a): + if a < 2: + raise ValueError("is_euler_prp() requires 'a' greater than or equal to 2") + if n < 1: + raise ValueError("is_euler_prp() requires 'n' be greater than 0") + if n == 1: + return False + if n % 2 == 0: + return n == 2 + a %= n + if gcd(n, a) != 1: + raise ValueError("is_euler_prp() requires gcd(n,a) == 1") + return pow(a, n >> 1, n) == jacobi(a, n) % n + + +def _is_strong_prp(n, a): + s = bit_scan1(n - 1) + a = pow(a, n >> s, n) + if a == 1 or a == n - 1: + return True + for _ in range(s - 1): + a = pow(a, 2, n) + if a == n - 1: + return True + if a == 1: + return False + return False + + +def is_strong_prp(n, a): + if a < 2: + raise ValueError("is_strong_prp() requires 'a' greater than or equal to 2") + if n < 1: + raise ValueError("is_strong_prp() requires 'n' be greater than 0") + if n == 1: + return False + if n % 2 == 0: + return n == 2 + a %= n + if gcd(n, a) != 1: + raise ValueError("is_strong_prp() requires gcd(n,a) == 1") + return _is_strong_prp(n, a) + + +def _lucas_sequence(n, P, Q, k): + r"""Return the modular Lucas sequence (U_k, V_k, Q_k). + + Explanation + =========== + + Given a Lucas sequence defined by P, Q, returns the kth values for + U and V, along with Q^k, all modulo n. This is intended for use with + possibly very large values of n and k, where the combinatorial functions + would be completely unusable. + + .. math :: + U_k = \begin{cases} + 0 & \text{if } k = 0\\ + 1 & \text{if } k = 1\\ + PU_{k-1} - QU_{k-2} & \text{if } k > 1 + \end{cases}\\ + V_k = \begin{cases} + 2 & \text{if } k = 0\\ + P & \text{if } k = 1\\ + PV_{k-1} - QV_{k-2} & \text{if } k > 1 + \end{cases} + + The modular Lucas sequences are used in numerous places in number theory, + especially in the Lucas compositeness tests and the various n + 1 proofs. + + Parameters + ========== + + n : int + n is an odd number greater than or equal to 3 + P : int + Q : int + D determined by D = P**2 - 4*Q is non-zero + k : int + k is a nonnegative integer + + Returns + ======= + + U, V, Qk : (int, int, int) + `(U_k \bmod{n}, V_k \bmod{n}, Q^k \bmod{n})` + + Examples + ======== + + >>> from sympy.external.ntheory import _lucas_sequence + >>> N = 10**2000 + 4561 + >>> sol = U, V, Qk = _lucas_sequence(N, 3, 1, N//2); sol + (0, 2, 1) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Lucas_sequence + + """ + if k == 0: + return (0, 2, 1) + D = P**2 - 4*Q + U = 1 + V = P + Qk = Q % n + if Q == 1: + # Optimization for extra strong tests. + for b in bin(k)[3:]: + U = (U*V) % n + V = (V*V - 2) % n + if b == "1": + U, V = U*P + V, V*P + U*D + if U & 1: + U += n + if V & 1: + V += n + U, V = U >> 1, V >> 1 + elif P == 1 and Q == -1: + # Small optimization for 50% of Selfridge parameters. + for b in bin(k)[3:]: + U = (U*V) % n + if Qk == 1: + V = (V*V - 2) % n + else: + V = (V*V + 2) % n + Qk = 1 + if b == "1": + # new_U = (U + V) // 2 + # new_V = (5*U + V) // 2 = 2*U + new_U + U, V = U + V, U << 1 + if U & 1: + U += n + U >>= 1 + V += U + Qk = -1 + Qk %= n + elif P == 1: + for b in bin(k)[3:]: + U = (U*V) % n + V = (V*V - 2*Qk) % n + Qk *= Qk + if b == "1": + # new_U = (U + V) // 2 + # new_V = new_U - 2*Q*U + U, V = U + V, (Q*U) << 1 + if U & 1: + U += n + U >>= 1 + V = U - V + Qk *= Q + Qk %= n + else: + # The general case with any P and Q. + for b in bin(k)[3:]: + U = (U*V) % n + V = (V*V - 2*Qk) % n + Qk *= Qk + if b == "1": + U, V = U*P + V, V*P + U*D + if U & 1: + U += n + if V & 1: + V += n + U, V = U >> 1, V >> 1 + Qk *= Q + Qk %= n + return (U % n, V % n, Qk) + + +def is_fibonacci_prp(n, p, q): + d = p**2 - 4*q + if d == 0 or p <= 0 or q not in [1, -1]: + raise ValueError("invalid values for p,q in is_fibonacci_prp()") + if n < 1: + raise ValueError("is_fibonacci_prp() requires 'n' be greater than 0") + if n == 1: + return False + if n % 2 == 0: + return n == 2 + return _lucas_sequence(n, p, q, n)[1] == p % n + + +def is_lucas_prp(n, p, q): + d = p**2 - 4*q + if d == 0: + raise ValueError("invalid values for p,q in is_lucas_prp()") + if n < 1: + raise ValueError("is_lucas_prp() requires 'n' be greater than 0") + if n == 1: + return False + if n % 2 == 0: + return n == 2 + if gcd(n, q*d) not in [1, n]: + raise ValueError("is_lucas_prp() requires gcd(n,2*q*D) == 1") + return _lucas_sequence(n, p, q, n - jacobi(d, n))[0] == 0 + + +def _is_selfridge_prp(n): + """Lucas compositeness test with the Selfridge parameters for n. + + Explanation + =========== + + The Lucas compositeness test checks whether n is a prime number. + The test can be run with arbitrary parameters ``P`` and ``Q``, which also change the performance of the test. + So, which parameters are most effective for running the Lucas compositeness test? + As an algorithm for determining ``P`` and ``Q``, Selfridge proposed method A [1]_ page 1401 + (Since two methods were proposed, referred to simply as A and B in the paper, + we will refer to one of them as "method A"). + + method A fixes ``P = 1``. Then, ``D`` defined by ``D = P**2 - 4Q`` is varied from 5, -7, 9, -11, 13, and so on, + with the first ``D`` being ``jacobi(D, n) == -1``. Once ``D`` is determined, + ``Q`` is determined to be ``(P**2 - D)//4``. + + References + ========== + + .. [1] Robert Baillie, Samuel S. Wagstaff, Lucas Pseudoprimes, + Math. Comp. Vol 35, Number 152 (1980), pp. 1391-1417, + https://doi.org/10.1090%2FS0025-5718-1980-0583518-6 + http://mpqs.free.fr/LucasPseudoprimes.pdf + + """ + for D in range(5, 1_000_000, 2): + if D & 2: # if D % 4 == 3 + D = -D + j = jacobi(D, n) + if j == -1: + return _lucas_sequence(n, 1, (1-D) // 4, n + 1)[0] == 0 + if j == 0 and D % n: + return False + # When j == -1 is hard to find, suspect a square number + if D == 13 and is_square(n): + return False + raise ValueError("appropriate value for D cannot be found in is_selfridge_prp()") + + +def is_selfridge_prp(n): + if n < 1: + raise ValueError("is_selfridge_prp() requires 'n' be greater than 0") + if n == 1: + return False + if n % 2 == 0: + return n == 2 + return _is_selfridge_prp(n) + + +def is_strong_lucas_prp(n, p, q): + D = p**2 - 4*q + if D == 0: + raise ValueError("invalid values for p,q in is_strong_lucas_prp()") + if n < 1: + raise ValueError("is_selfridge_prp() requires 'n' be greater than 0") + if n == 1: + return False + if n % 2 == 0: + return n == 2 + if gcd(n, q*D) not in [1, n]: + raise ValueError("is_strong_lucas_prp() requires gcd(n,2*q*D) == 1") + j = jacobi(D, n) + s = bit_scan1(n - j) + U, V, Qk = _lucas_sequence(n, p, q, (n - j) >> s) + if U == 0 or V == 0: + return True + for _ in range(s - 1): + V = (V*V - 2*Qk) % n + if V == 0: + return True + Qk = pow(Qk, 2, n) + return False + + +def _is_strong_selfridge_prp(n): + for D in range(5, 1_000_000, 2): + if D & 2: # if D % 4 == 3 + D = -D + j = jacobi(D, n) + if j == -1: + s = bit_scan1(n + 1) + U, V, Qk = _lucas_sequence(n, 1, (1-D) // 4, (n + 1) >> s) + if U == 0 or V == 0: + return True + for _ in range(s - 1): + V = (V*V - 2*Qk) % n + if V == 0: + return True + Qk = pow(Qk, 2, n) + return False + if j == 0 and D % n: + return False + # When j == -1 is hard to find, suspect a square number + if D == 13 and is_square(n): + return False + raise ValueError("appropriate value for D cannot be found in is_strong_selfridge_prp()") + + +def is_strong_selfridge_prp(n): + if n < 1: + raise ValueError("is_strong_selfridge_prp() requires 'n' be greater than 0") + if n == 1: + return False + if n % 2 == 0: + return n == 2 + return _is_strong_selfridge_prp(n) + + +def is_bpsw_prp(n): + if n < 1: + raise ValueError("is_bpsw_prp() requires 'n' be greater than 0") + if n == 1: + return False + if n % 2 == 0: + return n == 2 + return _is_strong_prp(n, 2) and _is_selfridge_prp(n) + + +def is_strong_bpsw_prp(n): + if n < 1: + raise ValueError("is_strong_bpsw_prp() requires 'n' be greater than 0") + if n == 1: + return False + if n % 2 == 0: + return n == 2 + return _is_strong_prp(n, 2) and _is_strong_selfridge_prp(n) diff --git a/.venv/lib/python3.13/site-packages/sympy/external/pythonmpq.py b/.venv/lib/python3.13/site-packages/sympy/external/pythonmpq.py new file mode 100644 index 0000000000000000000000000000000000000000..4f2d102974e04e139c00a39057976b5a5bf90776 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/external/pythonmpq.py @@ -0,0 +1,341 @@ +""" +PythonMPQ: Rational number type based on Python integers. + +This class is intended as a pure Python fallback for when gmpy2 is not +installed. If gmpy2 is installed then its mpq type will be used instead. The +mpq type is around 20x faster. We could just use the stdlib Fraction class +here but that is slower: + + from fractions import Fraction + from sympy.external.pythonmpq import PythonMPQ + nums = range(1000) + dens = range(5, 1005) + rats = [Fraction(n, d) for n, d in zip(nums, dens)] + sum(rats) # <--- 24 milliseconds + rats = [PythonMPQ(n, d) for n, d in zip(nums, dens)] + sum(rats) # <--- 7 milliseconds + +Both mpq and Fraction have some awkward features like the behaviour of +division with // and %: + + >>> from fractions import Fraction + >>> Fraction(2, 3) % Fraction(1, 4) + 1/6 + +For the QQ domain we do not want this behaviour because there should be no +remainder when dividing rational numbers. SymPy does not make use of this +aspect of mpq when gmpy2 is installed. Since this class is a fallback for that +case we do not bother implementing e.g. __mod__ so that we can be sure we +are not using it when gmpy2 is installed either. +""" + +from __future__ import annotations +import operator +from math import gcd +from decimal import Decimal +from fractions import Fraction +import sys +from typing import Type + + +# Used for __hash__ +_PyHASH_MODULUS = sys.hash_info.modulus +_PyHASH_INF = sys.hash_info.inf + + +class PythonMPQ: + """Rational number implementation that is intended to be compatible with + gmpy2's mpq. + + Also slightly faster than fractions.Fraction. + + PythonMPQ should be treated as immutable although no effort is made to + prevent mutation (since that might slow down calculations). + """ + __slots__ = ('numerator', 'denominator') + + def __new__(cls, numerator, denominator=None): + """Construct PythonMPQ with gcd computation and checks""" + if denominator is not None: + # + # PythonMPQ(n, d): require n and d to be int and d != 0 + # + if isinstance(numerator, int) and isinstance(denominator, int): + # This is the slow part: + divisor = gcd(numerator, denominator) + numerator //= divisor + denominator //= divisor + return cls._new_check(numerator, denominator) + else: + # + # PythonMPQ(q) + # + # Here q can be PythonMPQ, int, Decimal, float, Fraction or str + # + if isinstance(numerator, int): + return cls._new(numerator, 1) + elif isinstance(numerator, PythonMPQ): + return cls._new(numerator.numerator, numerator.denominator) + + # Let Fraction handle Decimal/float conversion and str parsing + if isinstance(numerator, (Decimal, float, str)): + numerator = Fraction(numerator) + if isinstance(numerator, Fraction): + return cls._new(numerator.numerator, numerator.denominator) + # + # Reject everything else. This is more strict than mpq which allows + # things like mpq(Fraction, Fraction) or mpq(Decimal, any). The mpq + # behaviour is somewhat inconsistent so we choose to accept only a + # more strict subset of what mpq allows. + # + raise TypeError("PythonMPQ() requires numeric or string argument") + + @classmethod + def _new_check(cls, numerator, denominator): + """Construct PythonMPQ, check divide by zero and canonicalize signs""" + if not denominator: + raise ZeroDivisionError(f'Zero divisor {numerator}/{denominator}') + elif denominator < 0: + numerator = -numerator + denominator = -denominator + return cls._new(numerator, denominator) + + @classmethod + def _new(cls, numerator, denominator): + """Construct PythonMPQ efficiently (no checks)""" + obj = super().__new__(cls) + obj.numerator = numerator + obj.denominator = denominator + return obj + + def __int__(self): + """Convert to int (truncates towards zero)""" + p, q = self.numerator, self.denominator + if p < 0: + return -(-p//q) + return p//q + + def __float__(self): + """Convert to float (approximately)""" + return self.numerator / self.denominator + + def __bool__(self): + """True/False if nonzero/zero""" + return bool(self.numerator) + + def __eq__(self, other): + """Compare equal with PythonMPQ, int, float, Decimal or Fraction""" + if isinstance(other, PythonMPQ): + return (self.numerator == other.numerator + and self.denominator == other.denominator) + elif isinstance(other, self._compatible_types): + return self.__eq__(PythonMPQ(other)) + else: + return NotImplemented + + def __hash__(self): + """hash - same as mpq/Fraction""" + try: + dinv = pow(self.denominator, -1, _PyHASH_MODULUS) + except ValueError: + hash_ = _PyHASH_INF + else: + hash_ = hash(hash(abs(self.numerator)) * dinv) + result = hash_ if self.numerator >= 0 else -hash_ + return -2 if result == -1 else result + + def __reduce__(self): + """Deconstruct for pickling""" + return type(self), (self.numerator, self.denominator) + + def __str__(self): + """Convert to string""" + if self.denominator != 1: + return f"{self.numerator}/{self.denominator}" + else: + return f"{self.numerator}" + + def __repr__(self): + """Convert to string""" + return f"MPQ({self.numerator},{self.denominator})" + + def _cmp(self, other, op): + """Helper for lt/le/gt/ge""" + if not isinstance(other, self._compatible_types): + return NotImplemented + lhs = self.numerator * other.denominator + rhs = other.numerator * self.denominator + return op(lhs, rhs) + + def __lt__(self, other): + """self < other""" + return self._cmp(other, operator.lt) + + def __le__(self, other): + """self <= other""" + return self._cmp(other, operator.le) + + def __gt__(self, other): + """self > other""" + return self._cmp(other, operator.gt) + + def __ge__(self, other): + """self >= other""" + return self._cmp(other, operator.ge) + + def __abs__(self): + """abs(q)""" + return self._new(abs(self.numerator), self.denominator) + + def __pos__(self): + """+q""" + return self + + def __neg__(self): + """-q""" + return self._new(-self.numerator, self.denominator) + + def __add__(self, other): + """q1 + q2""" + if isinstance(other, PythonMPQ): + # + # This is much faster than the naive method used in the stdlib + # fractions module. Not sure where this method comes from + # though... + # + # Compare timings for something like: + # nums = range(1000) + # rats = [PythonMPQ(n, d) for n, d in zip(nums[:-5], nums[5:])] + # sum(rats) # <-- time this + # + ap, aq = self.numerator, self.denominator + bp, bq = other.numerator, other.denominator + g = gcd(aq, bq) + if g == 1: + p = ap*bq + aq*bp + q = bq*aq + else: + q1, q2 = aq//g, bq//g + p, q = ap*q2 + bp*q1, q1*q2 + g2 = gcd(p, g) + p, q = (p // g2), q * (g // g2) + + elif isinstance(other, int): + p = self.numerator + self.denominator * other + q = self.denominator + else: + return NotImplemented + + return self._new(p, q) + + def __radd__(self, other): + """z1 + q2""" + if isinstance(other, int): + p = self.numerator + self.denominator * other + q = self.denominator + return self._new(p, q) + else: + return NotImplemented + + def __sub__(self ,other): + """q1 - q2""" + if isinstance(other, PythonMPQ): + ap, aq = self.numerator, self.denominator + bp, bq = other.numerator, other.denominator + g = gcd(aq, bq) + if g == 1: + p = ap*bq - aq*bp + q = bq*aq + else: + q1, q2 = aq//g, bq//g + p, q = ap*q2 - bp*q1, q1*q2 + g2 = gcd(p, g) + p, q = (p // g2), q * (g // g2) + elif isinstance(other, int): + p = self.numerator - self.denominator*other + q = self.denominator + else: + return NotImplemented + + return self._new(p, q) + + def __rsub__(self, other): + """z1 - q2""" + if isinstance(other, int): + p = self.denominator * other - self.numerator + q = self.denominator + return self._new(p, q) + else: + return NotImplemented + + def __mul__(self, other): + """q1 * q2""" + if isinstance(other, PythonMPQ): + ap, aq = self.numerator, self.denominator + bp, bq = other.numerator, other.denominator + x1 = gcd(ap, bq) + x2 = gcd(bp, aq) + p, q = ((ap//x1)*(bp//x2), (aq//x2)*(bq//x1)) + elif isinstance(other, int): + x = gcd(other, self.denominator) + p = self.numerator*(other//x) + q = self.denominator//x + else: + return NotImplemented + + return self._new(p, q) + + def __rmul__(self, other): + """z1 * q2""" + if isinstance(other, int): + x = gcd(self.denominator, other) + p = self.numerator*(other//x) + q = self.denominator//x + return self._new(p, q) + else: + return NotImplemented + + def __pow__(self, exp): + """q ** z""" + p, q = self.numerator, self.denominator + + if exp < 0: + p, q, exp = q, p, -exp + + return self._new_check(p**exp, q**exp) + + def __truediv__(self, other): + """q1 / q2""" + if isinstance(other, PythonMPQ): + ap, aq = self.numerator, self.denominator + bp, bq = other.numerator, other.denominator + x1 = gcd(ap, bp) + x2 = gcd(bq, aq) + p, q = ((ap//x1)*(bq//x2), (aq//x2)*(bp//x1)) + elif isinstance(other, int): + x = gcd(other, self.numerator) + p = self.numerator//x + q = self.denominator*(other//x) + else: + return NotImplemented + + return self._new_check(p, q) + + def __rtruediv__(self, other): + """z / q""" + if isinstance(other, int): + x = gcd(self.numerator, other) + p = self.denominator*(other//x) + q = self.numerator//x + return self._new_check(p, q) + else: + return NotImplemented + + _compatible_types: tuple[Type, ...] = () + +# +# These are the types that PythonMPQ will interoperate with for operations +# and comparisons such as ==, + etc. We define this down here so that we can +# include PythonMPQ in the list as well. +# +PythonMPQ._compatible_types = (PythonMPQ, int, Decimal, Fraction) diff --git a/.venv/lib/python3.13/site-packages/sympy/external/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/external/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/external/tests/test_autowrap.py b/.venv/lib/python3.13/site-packages/sympy/external/tests/test_autowrap.py new file mode 100644 index 0000000000000000000000000000000000000000..d469b552995b7625f786f3296089e41f42da75cb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/external/tests/test_autowrap.py @@ -0,0 +1,313 @@ +import sympy +import tempfile +import os +from pathlib import Path +from sympy.core.mod import Mod +from sympy.core.relational import Eq +from sympy.core.symbol import symbols +from sympy.external import import_module +from sympy.tensor import IndexedBase, Idx +from sympy.utilities.autowrap import autowrap, ufuncify, CodeWrapError +from sympy.testing.pytest import skip + +numpy = import_module('numpy', min_module_version='1.6.1') +Cython = import_module('Cython', min_module_version='0.15.1') +f2py = import_module('numpy.f2py', import_kwargs={'fromlist': ['f2py']}) + +f2pyworks = False +if f2py: + try: + autowrap(symbols('x'), 'f95', 'f2py') + except (CodeWrapError, ImportError, OSError): + f2pyworks = False + else: + f2pyworks = True + +a, b, c = symbols('a b c') +n, m, d = symbols('n m d', integer=True) +A, B, C = symbols('A B C', cls=IndexedBase) +i = Idx('i', m) +j = Idx('j', n) +k = Idx('k', d) + + +def has_module(module): + """ + Return True if module exists, otherwise run skip(). + + module should be a string. + """ + # To give a string of the module name to skip(), this function takes a + # string. So we don't waste time running import_module() more than once, + # just map the three modules tested here in this dict. + modnames = {'numpy': numpy, 'Cython': Cython, 'f2py': f2py} + + if modnames[module]: + if module == 'f2py' and not f2pyworks: + skip("Couldn't run f2py.") + return True + skip("Couldn't import %s." % module) + +# +# test runners used by several language-backend combinations +# + +def runtest_autowrap_twice(language, backend): + f = autowrap((((a + b)/c)**5).expand(), language, backend) + g = autowrap((((a + b)/c)**4).expand(), language, backend) + + # check that autowrap updates the module name. Else, g gives the same as f + assert f(1, -2, 1) == -1.0 + assert g(1, -2, 1) == 1.0 + + +def runtest_autowrap_trace(language, backend): + has_module('numpy') + trace = autowrap(A[i, i], language, backend) + assert trace(numpy.eye(100)) == 100 + + +def runtest_autowrap_matrix_vector(language, backend): + has_module('numpy') + x, y = symbols('x y', cls=IndexedBase) + expr = Eq(y[i], A[i, j]*x[j]) + mv = autowrap(expr, language, backend) + + # compare with numpy's dot product + M = numpy.random.rand(10, 20) + x = numpy.random.rand(20) + y = numpy.dot(M, x) + assert numpy.sum(numpy.abs(y - mv(M, x))) < 1e-13 + + +def runtest_autowrap_matrix_matrix(language, backend): + has_module('numpy') + expr = Eq(C[i, j], A[i, k]*B[k, j]) + matmat = autowrap(expr, language, backend) + + # compare with numpy's dot product + M1 = numpy.random.rand(10, 20) + M2 = numpy.random.rand(20, 15) + M3 = numpy.dot(M1, M2) + assert numpy.sum(numpy.abs(M3 - matmat(M1, M2))) < 1e-13 + + +def runtest_ufuncify(language, backend): + has_module('numpy') + a, b, c = symbols('a b c') + fabc = ufuncify([a, b, c], a*b + c, backend=backend) + facb = ufuncify([a, c, b], a*b + c, backend=backend) + grid = numpy.linspace(-2, 2, 50) + b = numpy.linspace(-5, 4, 50) + c = numpy.linspace(-1, 1, 50) + expected = grid*b + c + numpy.testing.assert_allclose(fabc(grid, b, c), expected) + numpy.testing.assert_allclose(facb(grid, c, b), expected) + + +def runtest_issue_10274(language, backend): + expr = (a - b + c)**(13) + tmp = tempfile.mkdtemp() + f = autowrap(expr, language, backend, tempdir=tmp, + helpers=('helper', a - b + c, (a, b, c))) + assert f(1, 1, 1) == 1 + + for file in os.listdir(tmp): + if not (file.startswith("wrapped_code_") and file.endswith(".c")): + continue + + with open(tmp + '/' + file) as fil: + lines = fil.readlines() + assert lines[0] == "/******************************************************************************\n" + assert "Code generated with SymPy " + sympy.__version__ in lines[1] + assert lines[2:] == [ + " * *\n", + " * See http://www.sympy.org/ for more information. *\n", + " * *\n", + " * This file is part of 'autowrap' *\n", + " ******************************************************************************/\n", + "#include " + '"' + file[:-1]+ 'h"' + "\n", + "#include \n", + "\n", + "double helper(double a, double b, double c) {\n", + "\n", + " double helper_result;\n", + " helper_result = a - b + c;\n", + " return helper_result;\n", + "\n", + "}\n", + "\n", + "double autofunc(double a, double b, double c) {\n", + "\n", + " double autofunc_result;\n", + " autofunc_result = pow(helper(a, b, c), 13);\n", + " return autofunc_result;\n", + "\n", + "}\n", + ] + + +def runtest_issue_15337(language, backend): + has_module('numpy') + # NOTE : autowrap was originally designed to only accept an iterable for + # the kwarg "helpers", but in issue 10274 the user mistakenly thought that + # if there was only a single helper it did not need to be passed via an + # iterable that wrapped the helper tuple. There were no tests for this + # behavior so when the code was changed to accept a single tuple it broke + # the original behavior. These tests below ensure that both now work. + a, b, c, d, e = symbols('a, b, c, d, e') + expr = (a - b + c - d + e)**13 + exp_res = (1. - 2. + 3. - 4. + 5.)**13 + + f = autowrap(expr, language, backend, args=(a, b, c, d, e), + helpers=('f1', a - b + c, (a, b, c))) + numpy.testing.assert_allclose(f(1, 2, 3, 4, 5), exp_res) + + f = autowrap(expr, language, backend, args=(a, b, c, d, e), + helpers=(('f1', a - b, (a, b)), ('f2', c - d, (c, d)))) + numpy.testing.assert_allclose(f(1, 2, 3, 4, 5), exp_res) + + +def test_issue_15230(): + has_module('f2py') + + x, y = symbols('x, y') + expr = Mod(x, 3.0) - Mod(y, -2.0) + f = autowrap(expr, args=[x, y], language='F95') + exp_res = float(expr.xreplace({x: 3.5, y: 2.7}).evalf()) + assert abs(f(3.5, 2.7) - exp_res) < 1e-14 + + x, y = symbols('x, y', integer=True) + expr = Mod(x, 3) - Mod(y, -2) + f = autowrap(expr, args=[x, y], language='F95') + assert f(3, 2) == expr.xreplace({x: 3, y: 2}) + +# +# tests of language-backend combinations +# + +# f2py + + +def test_wrap_twice_f95_f2py(): + has_module('f2py') + runtest_autowrap_twice('f95', 'f2py') + + +def test_autowrap_trace_f95_f2py(): + has_module('f2py') + runtest_autowrap_trace('f95', 'f2py') + + +def test_autowrap_matrix_vector_f95_f2py(): + has_module('f2py') + runtest_autowrap_matrix_vector('f95', 'f2py') + + +def test_autowrap_matrix_matrix_f95_f2py(): + has_module('f2py') + runtest_autowrap_matrix_matrix('f95', 'f2py') + + +def test_ufuncify_f95_f2py(): + has_module('f2py') + runtest_ufuncify('f95', 'f2py') + + +def test_issue_15337_f95_f2py(): + has_module('f2py') + runtest_issue_15337('f95', 'f2py') + +# Cython + + +def test_wrap_twice_c_cython(): + has_module('Cython') + runtest_autowrap_twice('C', 'cython') + + +def test_autowrap_trace_C_Cython(): + has_module('Cython') + runtest_autowrap_trace('C99', 'cython') + + +def test_autowrap_matrix_vector_C_cython(): + has_module('Cython') + runtest_autowrap_matrix_vector('C99', 'cython') + + +def test_autowrap_matrix_matrix_C_cython(): + has_module('Cython') + runtest_autowrap_matrix_matrix('C99', 'cython') + + +def test_ufuncify_C_Cython(): + has_module('Cython') + runtest_ufuncify('C99', 'cython') + + +def test_issue_10274_C_cython(): + has_module('Cython') + runtest_issue_10274('C89', 'cython') + + +def test_issue_15337_C_cython(): + has_module('Cython') + runtest_issue_15337('C89', 'cython') + + +def test_autowrap_custom_printer(): + has_module('Cython') + + from sympy.core.numbers import pi + from sympy.utilities.codegen import C99CodeGen + from sympy.printing.c import C99CodePrinter + + class PiPrinter(C99CodePrinter): + def _print_Pi(self, expr): + return "S_PI" + + printer = PiPrinter() + gen = C99CodeGen(printer=printer) + gen.preprocessor_statements.append('#include "shortpi.h"') + + expr = pi * a + + expected = ( + '#include "%s"\n' + '#include \n' + '#include "shortpi.h"\n' + '\n' + 'double autofunc(double a) {\n' + '\n' + ' double autofunc_result;\n' + ' autofunc_result = S_PI*a;\n' + ' return autofunc_result;\n' + '\n' + '}\n' + ) + + tmpdir = tempfile.mkdtemp() + # write a trivial header file to use in the generated code + Path(os.path.join(tmpdir, 'shortpi.h')).write_text('#define S_PI 3.14') + + func = autowrap(expr, backend='cython', tempdir=tmpdir, code_gen=gen) + + assert func(4.2) == 3.14 * 4.2 + + # check that the generated code is correct + for filename in os.listdir(tmpdir): + if filename.startswith('wrapped_code') and filename.endswith('.c'): + with open(os.path.join(tmpdir, filename)) as f: + lines = f.readlines() + expected = expected % filename.replace('.c', '.h') + assert ''.join(lines[7:]) == expected + + +# Numpy + +def test_ufuncify_numpy(): + # This test doesn't use Cython, but if Cython works, then there is a valid + # C compiler, which is needed. + has_module('Cython') + runtest_ufuncify('C99', 'numpy') diff --git a/.venv/lib/python3.13/site-packages/sympy/external/tests/test_codegen.py b/.venv/lib/python3.13/site-packages/sympy/external/tests/test_codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..8a4fe28300b86fb0b38d98fcf2fcbbe514cf720f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/external/tests/test_codegen.py @@ -0,0 +1,375 @@ +# This tests the compilation and execution of the source code generated with +# utilities.codegen. The compilation takes place in a temporary directory that +# is removed after the test. By default the test directory is always removed, +# but this behavior can be changed by setting the environment variable +# SYMPY_TEST_CLEAN_TEMP to: +# export SYMPY_TEST_CLEAN_TEMP=always : the default behavior. +# export SYMPY_TEST_CLEAN_TEMP=success : only remove the directories of working tests. +# export SYMPY_TEST_CLEAN_TEMP=never : never remove the directories with the test code. +# When a directory is not removed, the necessary information is printed on +# screen to find the files that belong to the (failed) tests. If a test does +# not fail, py.test captures all the output and you will not see the directories +# corresponding to the successful tests. Use the --nocapture option to see all +# the output. + +# All tests below have a counterpart in utilities/test/test_codegen.py. In the +# latter file, the resulting code is compared with predefined strings, without +# compilation or execution. + +# All the generated Fortran code should conform with the Fortran 95 standard, +# and all the generated C code should be ANSI C, which facilitates the +# incorporation in various projects. The tests below assume that the binary cc +# is somewhere in the path and that it can compile ANSI C code. + +from sympy.abc import x, y, z +from sympy.testing.pytest import IS_WASM, skip +from sympy.utilities.codegen import codegen, make_routine, get_code_generator +import sys +import os +import tempfile +import subprocess +from pathlib import Path + + +# templates for the main program that will test the generated code. + +main_template = {} +main_template['F95'] = """ +program main + include "codegen.h" + integer :: result; + result = 0 + + %(statements)s + + call exit(result) +end program +""" + +main_template['C89'] = """ +#include "codegen.h" +#include +#include + +int main() { + int result = 0; + + %(statements)s + + return result; +} +""" +main_template['C99'] = main_template['C89'] +# templates for the numerical tests + +numerical_test_template = {} +numerical_test_template['C89'] = """ + if (fabs(%(call)s)>%(threshold)s) { + printf("Numerical validation failed: %(call)s=%%e threshold=%(threshold)s\\n", %(call)s); + result = -1; + } +""" +numerical_test_template['C99'] = numerical_test_template['C89'] + +numerical_test_template['F95'] = """ + if (abs(%(call)s)>%(threshold)s) then + write(6,"('Numerical validation failed:')") + write(6,"('%(call)s=',e15.5,'threshold=',e15.5)") %(call)s, %(threshold)s + result = -1; + end if +""" +# command sequences for supported compilers + +compile_commands = {} +compile_commands['cc'] = [ + "cc -c codegen.c -o codegen.o", + "cc -c main.c -o main.o", + "cc main.o codegen.o -lm -o test.exe" +] + +compile_commands['gfortran'] = [ + "gfortran -c codegen.f90 -o codegen.o", + "gfortran -ffree-line-length-none -c main.f90 -o main.o", + "gfortran main.o codegen.o -o test.exe" +] + +compile_commands['g95'] = [ + "g95 -c codegen.f90 -o codegen.o", + "g95 -ffree-line-length-huge -c main.f90 -o main.o", + "g95 main.o codegen.o -o test.exe" +] + +compile_commands['ifort'] = [ + "ifort -c codegen.f90 -o codegen.o", + "ifort -c main.f90 -o main.o", + "ifort main.o codegen.o -o test.exe" +] + +combinations_lang_compiler = [ + ('C89', 'cc'), + ('C99', 'cc'), + ('F95', 'ifort'), + ('F95', 'gfortran'), + ('F95', 'g95') +] + +def try_run(commands): + """Run a series of commands and only return True if all ran fine.""" + if IS_WASM: + return False + with open(os.devnull, 'w') as null: + for command in commands: + retcode = subprocess.call(command, stdout=null, shell=True, + stderr=subprocess.STDOUT) + if retcode != 0: + return False + return True + + +def run_test(label, routines, numerical_tests, language, commands, friendly=True): + """A driver for the codegen tests. + + This driver assumes that a compiler ifort is present in the PATH and that + ifort is (at least) a Fortran 90 compiler. The generated code is written in + a temporary directory, together with a main program that validates the + generated code. The test passes when the compilation and the validation + run correctly. + """ + + # Check input arguments before touching the file system + language = language.upper() + assert language in main_template + assert language in numerical_test_template + + # Check that environment variable makes sense + clean = os.getenv('SYMPY_TEST_CLEAN_TEMP', 'always').lower() + if clean not in ('always', 'success', 'never'): + raise ValueError("SYMPY_TEST_CLEAN_TEMP must be one of the following: 'always', 'success' or 'never'.") + + # Do all the magic to compile, run and validate the test code + # 1) prepare the temporary working directory, switch to that dir + work = tempfile.mkdtemp("_sympy_%s_test" % language, "%s_" % label) + oldwork = os.getcwd() + os.chdir(work) + + # 2) write the generated code + if friendly: + # interpret the routines as a name_expr list and call the friendly + # function codegen + codegen(routines, language, "codegen", to_files=True) + else: + code_gen = get_code_generator(language, "codegen") + code_gen.write(routines, "codegen", to_files=True) + + # 3) write a simple main program that links to the generated code, and that + # includes the numerical tests + test_strings = [] + for fn_name, args, expected, threshold in numerical_tests: + call_string = "%s(%s)-(%s)" % ( + fn_name, ",".join(str(arg) for arg in args), expected) + if language == "F95": + call_string = fortranize_double_constants(call_string) + threshold = fortranize_double_constants(str(threshold)) + test_strings.append(numerical_test_template[language] % { + "call": call_string, + "threshold": threshold, + }) + + if language == "F95": + f_name = "main.f90" + elif language.startswith("C"): + f_name = "main.c" + else: + raise NotImplementedError( + "FIXME: filename extension unknown for language: %s" % language) + + Path(f_name).write_text( + main_template[language] % {'statements': "".join(test_strings)}) + + # 4) Compile and link + compiled = try_run(commands) + + # 5) Run if compiled + if compiled: + executed = try_run(["./test.exe"]) + else: + executed = False + + # 6) Clean up stuff + if clean == 'always' or (clean == 'success' and compiled and executed): + def safe_remove(filename): + if os.path.isfile(filename): + os.remove(filename) + safe_remove("codegen.f90") + safe_remove("codegen.c") + safe_remove("codegen.h") + safe_remove("codegen.o") + safe_remove("main.f90") + safe_remove("main.c") + safe_remove("main.o") + safe_remove("test.exe") + os.chdir(oldwork) + os.rmdir(work) + else: + print("TEST NOT REMOVED: %s" % work, file=sys.stderr) + os.chdir(oldwork) + + # 7) Do the assertions in the end + assert compiled, "failed to compile %s code with:\n%s" % ( + language, "\n".join(commands)) + assert executed, "failed to execute %s code from:\n%s" % ( + language, "\n".join(commands)) + + +def fortranize_double_constants(code_string): + """ + Replaces every literal float with literal doubles + """ + import re + pattern_exp = re.compile(r'\d+(\.)?\d*[eE]-?\d+') + pattern_float = re.compile(r'\d+\.\d*(?!\d*d)') + + def subs_exp(matchobj): + return re.sub('[eE]', 'd', matchobj.group(0)) + + def subs_float(matchobj): + return "%sd0" % matchobj.group(0) + + code_string = pattern_exp.sub(subs_exp, code_string) + code_string = pattern_float.sub(subs_float, code_string) + + return code_string + + +def is_feasible(language, commands): + # This test should always work, otherwise the compiler is not present. + routine = make_routine("test", x) + numerical_tests = [ + ("test", ( 1.0,), 1.0, 1e-15), + ("test", (-1.0,), -1.0, 1e-15), + ] + try: + run_test("is_feasible", [routine], numerical_tests, language, commands, + friendly=False) + return True + except AssertionError: + return False + +valid_lang_commands = [] +invalid_lang_compilers = [] +for lang, compiler in combinations_lang_compiler: + commands = compile_commands[compiler] + if is_feasible(lang, commands): + valid_lang_commands.append((lang, commands)) + else: + invalid_lang_compilers.append((lang, compiler)) + +# We test all language-compiler combinations, just to report what is skipped + +def test_C89_cc(): + if ("C89", 'cc') in invalid_lang_compilers: + skip("`cc' command didn't work as expected (C89)") + + +def test_C99_cc(): + if ("C99", 'cc') in invalid_lang_compilers: + skip("`cc' command didn't work as expected (C99)") + + +def test_F95_ifort(): + if ("F95", 'ifort') in invalid_lang_compilers: + skip("`ifort' command didn't work as expected") + + +def test_F95_gfortran(): + if ("F95", 'gfortran') in invalid_lang_compilers: + skip("`gfortran' command didn't work as expected") + + +def test_F95_g95(): + if ("F95", 'g95') in invalid_lang_compilers: + skip("`g95' command didn't work as expected") + +# Here comes the actual tests + + +def test_basic_codegen(): + numerical_tests = [ + ("test", (1.0, 6.0, 3.0), 21.0, 1e-15), + ("test", (-1.0, 2.0, -2.5), -2.5, 1e-15), + ] + name_expr = [("test", (x + y)*z)] + for lang, commands in valid_lang_commands: + run_test("basic_codegen", name_expr, numerical_tests, lang, commands) + + +def test_intrinsic_math1_codegen(): + # not included: log10 + from sympy.core.evalf import N + from sympy.functions import ln + from sympy.functions.elementary.exponential import log + from sympy.functions.elementary.hyperbolic import (cosh, sinh, tanh) + from sympy.functions.elementary.integers import (ceiling, floor) + from sympy.functions.elementary.miscellaneous import sqrt + from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin, tan) + name_expr = [ + ("test_fabs", abs(x)), + ("test_acos", acos(x)), + ("test_asin", asin(x)), + ("test_atan", atan(x)), + ("test_cos", cos(x)), + ("test_cosh", cosh(x)), + ("test_log", log(x)), + ("test_ln", ln(x)), + ("test_sin", sin(x)), + ("test_sinh", sinh(x)), + ("test_sqrt", sqrt(x)), + ("test_tan", tan(x)), + ("test_tanh", tanh(x)), + ] + numerical_tests = [] + for name, expr in name_expr: + for xval in 0.2, 0.5, 0.8: + expected = N(expr.subs(x, xval)) + numerical_tests.append((name, (xval,), expected, 1e-14)) + for lang, commands in valid_lang_commands: + if lang.startswith("C"): + name_expr_C = [("test_floor", floor(x)), ("test_ceil", ceiling(x))] + else: + name_expr_C = [] + run_test("intrinsic_math1", name_expr + name_expr_C, + numerical_tests, lang, commands) + + +def test_instrinsic_math2_codegen(): + # not included: frexp, ldexp, modf, fmod + from sympy.core.evalf import N + from sympy.functions.elementary.trigonometric import atan2 + name_expr = [ + ("test_atan2", atan2(x, y)), + ("test_pow", x**y), + ] + numerical_tests = [] + for name, expr in name_expr: + for xval, yval in (0.2, 1.3), (0.5, -0.2), (0.8, 0.8): + expected = N(expr.subs(x, xval).subs(y, yval)) + numerical_tests.append((name, (xval, yval), expected, 1e-14)) + for lang, commands in valid_lang_commands: + run_test("intrinsic_math2", name_expr, numerical_tests, lang, commands) + + +def test_complicated_codegen(): + from sympy.core.evalf import N + from sympy.functions.elementary.trigonometric import (cos, sin, tan) + name_expr = [ + ("test1", ((sin(x) + cos(y) + tan(z))**7).expand()), + ("test2", cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))), + ] + numerical_tests = [] + for name, expr in name_expr: + for xval, yval, zval in (0.2, 1.3, -0.3), (0.5, -0.2, 0.0), (0.8, 2.1, 0.8): + expected = N(expr.subs(x, xval).subs(y, yval).subs(z, zval)) + numerical_tests.append((name, (xval, yval, zval), expected, 1e-12)) + for lang, commands in valid_lang_commands: + run_test( + "complicated_codegen", name_expr, numerical_tests, lang, commands) diff --git a/.venv/lib/python3.13/site-packages/sympy/external/tests/test_gmpy.py b/.venv/lib/python3.13/site-packages/sympy/external/tests/test_gmpy.py new file mode 100644 index 0000000000000000000000000000000000000000..d88f9da0c6c26c15f529ce485fff5b72342170ea --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/external/tests/test_gmpy.py @@ -0,0 +1,12 @@ +from sympy.external.gmpy import LONG_MAX, iroot +from sympy.testing.pytest import raises + + +def test_iroot(): + assert iroot(2, LONG_MAX) == (1, False) + assert iroot(2, LONG_MAX + 1) == (1, False) + for x in range(3): + assert iroot(x, 1) == (x, True) + raises(ValueError, lambda: iroot(-1, 1)) + raises(ValueError, lambda: iroot(0, 0)) + raises(ValueError, lambda: iroot(0, -1)) diff --git a/.venv/lib/python3.13/site-packages/sympy/external/tests/test_importtools.py b/.venv/lib/python3.13/site-packages/sympy/external/tests/test_importtools.py new file mode 100644 index 0000000000000000000000000000000000000000..0b954070c179282ed2bcf5735d802c5f22a3a261 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/external/tests/test_importtools.py @@ -0,0 +1,40 @@ +from sympy.external import import_module +from sympy.testing.pytest import warns + +# fixes issue that arose in addressing issue 6533 +def test_no_stdlib_collections(): + ''' + make sure we get the right collections when it is not part of a + larger list + ''' + import collections + matplotlib = import_module('matplotlib', + import_kwargs={'fromlist': ['cm', 'collections']}, + min_module_version='1.1.0', catch=(RuntimeError,)) + if matplotlib: + assert collections != matplotlib.collections + +def test_no_stdlib_collections2(): + ''' + make sure we get the right collections when it is not part of a + larger list + ''' + import collections + matplotlib = import_module('matplotlib', + import_kwargs={'fromlist': ['collections']}, + min_module_version='1.1.0', catch=(RuntimeError,)) + if matplotlib: + assert collections != matplotlib.collections + +def test_no_stdlib_collections3(): + '''make sure we get the right collections with no catch''' + import collections + matplotlib = import_module('matplotlib', + import_kwargs={'fromlist': ['cm', 'collections']}, + min_module_version='1.1.0') + if matplotlib: + assert collections != matplotlib.collections + +def test_min_module_version_python3_basestring_error(): + with warns(UserWarning): + import_module('mpmath', min_module_version='1000.0.1') diff --git a/.venv/lib/python3.13/site-packages/sympy/external/tests/test_ntheory.py b/.venv/lib/python3.13/site-packages/sympy/external/tests/test_ntheory.py new file mode 100644 index 0000000000000000000000000000000000000000..00824481ad27aa9071ea5801fb3bde75cacbc3c8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/external/tests/test_ntheory.py @@ -0,0 +1,307 @@ +from itertools import permutations + +from sympy.external.ntheory import (bit_scan1, remove, bit_scan0, is_fermat_prp, + is_euler_prp, is_strong_prp, gcdext, _lucas_sequence, + is_fibonacci_prp, is_lucas_prp, is_selfridge_prp, + is_strong_lucas_prp, is_strong_selfridge_prp, + is_bpsw_prp, is_strong_bpsw_prp) +from sympy.testing.pytest import raises + + +def test_bit_scan1(): + assert bit_scan1(0) is None + assert bit_scan1(1) == 0 + assert bit_scan1(-1) == 0 + assert bit_scan1(2) == 1 + assert bit_scan1(7) == 0 + assert bit_scan1(-7) == 0 + for i in range(100): + assert bit_scan1(1 << i) == i + assert bit_scan1((1 << i) * 31337) == i + for i in range(500): + n = (1 << 500) + (1 << i) + assert bit_scan1(n) == i + assert bit_scan1(1 << 1000001) == 1000001 + assert bit_scan1((1 << 273956)*7**37) == 273956 + # issue 12709 + for i in range(1, 10): + big = 1 << i + assert bit_scan1(-big) == bit_scan1(big) + + +def test_bit_scan0(): + assert bit_scan0(-1) is None + assert bit_scan0(0) == 0 + assert bit_scan0(1) == 1 + assert bit_scan0(-2) == 0 + + +def test_remove(): + raises(ValueError, lambda: remove(1, 1)) + assert remove(0, 3) == (0, 0) + for f in range(2, 10): + for y in range(2, 1000): + for z in [1, 17, 101, 1009]: + assert remove(z*f**y, f) == (z, y) + + +def test_gcdext(): + assert gcdext(0, 0) == (0, 0, 0) + assert gcdext(3, 0) == (3, 1, 0) + assert gcdext(0, 4) == (4, 0, 1) + + for n in range(1, 10): + assert gcdext(n, 1) == gcdext(-n, 1) == (1, 0, 1) + assert gcdext(n, -1) == gcdext(-n, -1) == (1, 0, -1) + assert gcdext(n, n) == gcdext(-n, n) == (n, 0, 1) + assert gcdext(n, -n) == gcdext(-n, -n) == (n, 0, -1) + + for n in range(2, 10): + assert gcdext(1, n) == gcdext(1, -n) == (1, 1, 0) + assert gcdext(-1, n) == gcdext(-1, -n) == (1, -1, 0) + + for a, b in permutations([2**5, 3, 5, 7**2, 11], 2): + g, x, y = gcdext(a, b) + assert g == a*x + b*y == 1 + + +def test_is_fermat_prp(): + # invalid input + raises(ValueError, lambda: is_fermat_prp(0, 10)) + raises(ValueError, lambda: is_fermat_prp(5, 1)) + + # n = 1 + assert not is_fermat_prp(1, 3) + + # n is prime + assert is_fermat_prp(2, 4) + assert is_fermat_prp(3, 2) + assert is_fermat_prp(11, 3) + assert is_fermat_prp(2**31-1, 5) + + # A001567 + pseudorpime = [341, 561, 645, 1105, 1387, 1729, 1905, 2047, + 2465, 2701, 2821, 3277, 4033, 4369, 4371, 4681] + for n in pseudorpime: + assert is_fermat_prp(n, 2) + + # A020136 + pseudorpime = [15, 85, 91, 341, 435, 451, 561, 645, 703, 1105, + 1247, 1271, 1387, 1581, 1695, 1729, 1891, 1905] + for n in pseudorpime: + assert is_fermat_prp(n, 4) + + +def test_is_euler_prp(): + # invalid input + raises(ValueError, lambda: is_euler_prp(0, 10)) + raises(ValueError, lambda: is_euler_prp(5, 1)) + + # n = 1 + assert not is_euler_prp(1, 3) + + # n is prime + assert is_euler_prp(2, 4) + assert is_euler_prp(3, 2) + assert is_euler_prp(11, 3) + assert is_euler_prp(2**31-1, 5) + + # A047713 + pseudorpime = [561, 1105, 1729, 1905, 2047, 2465, 3277, 4033, + 4681, 6601, 8321, 8481, 10585, 12801, 15841] + for n in pseudorpime: + assert is_euler_prp(n, 2) + + # A048950 + pseudorpime = [121, 703, 1729, 1891, 2821, 3281, 7381, 8401, + 8911, 10585, 12403, 15457, 15841, 16531, 18721] + for n in pseudorpime: + assert is_euler_prp(n, 3) + + +def test_is_strong_prp(): + # invalid input + raises(ValueError, lambda: is_strong_prp(0, 10)) + raises(ValueError, lambda: is_strong_prp(5, 1)) + + # n = 1 + assert not is_strong_prp(1, 3) + + # n is prime + assert is_strong_prp(2, 4) + assert is_strong_prp(3, 2) + assert is_strong_prp(11, 3) + assert is_strong_prp(2**31-1, 5) + + # A001262 + pseudorpime = [2047, 3277, 4033, 4681, 8321, 15841, 29341, + 42799, 49141, 52633, 65281, 74665, 80581] + for n in pseudorpime: + assert is_strong_prp(n, 2) + + # A020229 + pseudorpime = [121, 703, 1891, 3281, 8401, 8911, 10585, 12403, + 16531, 18721, 19345, 23521, 31621, 44287, 47197] + for n in pseudorpime: + assert is_strong_prp(n, 3) + + +def test_lucas_sequence(): + def lucas_u(P, Q, length): + array = [0] * length + array[1] = 1 + for k in range(2, length): + array[k] = P * array[k - 1] - Q * array[k - 2] + return array + + def lucas_v(P, Q, length): + array = [0] * length + array[0] = 2 + array[1] = P + for k in range(2, length): + array[k] = P * array[k - 1] - Q * array[k - 2] + return array + + length = 20 + for P in range(-10, 10): + for Q in range(-10, 10): + D = P**2 - 4*Q + if D == 0: + continue + us = lucas_u(P, Q, length) + vs = lucas_v(P, Q, length) + for n in range(3, 100, 2): + for k in range(length): + U, V, Qk = _lucas_sequence(n, P, Q, k) + assert U == us[k] % n + assert V == vs[k] % n + assert pow(Q, k, n) == Qk + + +def test_is_fibonacci_prp(): + # invalid input + raises(ValueError, lambda: is_fibonacci_prp(3, 2, 1)) + raises(ValueError, lambda: is_fibonacci_prp(3, -5, 1)) + raises(ValueError, lambda: is_fibonacci_prp(3, 5, 2)) + raises(ValueError, lambda: is_fibonacci_prp(0, 5, -1)) + + # n = 1 + assert not is_fibonacci_prp(1, 3, 1) + + # n is prime + assert is_fibonacci_prp(2, 5, 1) + assert is_fibonacci_prp(3, 6, -1) + assert is_fibonacci_prp(11, 7, 1) + assert is_fibonacci_prp(2**31-1, 8, -1) + + # A005845 + pseudorpime = [705, 2465, 2737, 3745, 4181, 5777, 6721, + 10877, 13201, 15251, 24465, 29281, 34561] + for n in pseudorpime: + assert is_fibonacci_prp(n, 1, -1) + + +def test_is_lucas_prp(): + # invalid input + raises(ValueError, lambda: is_lucas_prp(3, 2, 1)) + raises(ValueError, lambda: is_lucas_prp(0, 5, -1)) + raises(ValueError, lambda: is_lucas_prp(15, 3, 1)) + + # n = 1 + assert not is_lucas_prp(1, 3, 1) + + # n is prime + assert is_lucas_prp(2, 5, 2) + assert is_lucas_prp(3, 6, -1) + assert is_lucas_prp(11, 7, 5) + assert is_lucas_prp(2**31-1, 8, -3) + + # A081264 + pseudorpime = [323, 377, 1891, 3827, 4181, 5777, 6601, 6721, + 8149, 10877, 11663, 13201, 13981, 15251, 17119] + for n in pseudorpime: + assert is_lucas_prp(n, 1, -1) + + +def test_is_selfridge_prp(): + # invalid input + raises(ValueError, lambda: is_selfridge_prp(0)) + + # n = 1 + assert not is_selfridge_prp(1) + + # n is prime + assert is_selfridge_prp(2) + assert is_selfridge_prp(3) + assert is_selfridge_prp(11) + assert is_selfridge_prp(2**31-1) + + # A217120 + pseudorpime = [323, 377, 1159, 1829, 3827, 5459, 5777, 9071, + 9179, 10877, 11419, 11663, 13919, 14839, 16109] + for n in pseudorpime: + assert is_selfridge_prp(n) + + +def test_is_strong_lucas_prp(): + # invalid input + raises(ValueError, lambda: is_strong_lucas_prp(3, 2, 1)) + raises(ValueError, lambda: is_strong_lucas_prp(0, 5, -1)) + raises(ValueError, lambda: is_strong_lucas_prp(15, 3, 1)) + + # n = 1 + assert not is_strong_lucas_prp(1, 3, 1) + + # n is prime + assert is_strong_lucas_prp(2, 5, 2) + assert is_strong_lucas_prp(3, 6, -1) + assert is_strong_lucas_prp(11, 7, 5) + assert is_strong_lucas_prp(2**31-1, 8, -3) + + +def test_is_strong_selfridge_prp(): + # invalid input + raises(ValueError, lambda: is_strong_selfridge_prp(0)) + + # n = 1 + assert not is_strong_selfridge_prp(1) + + # n is prime + assert is_strong_selfridge_prp(2) + assert is_strong_selfridge_prp(3) + assert is_strong_selfridge_prp(11) + assert is_strong_selfridge_prp(2**31-1) + + # A217255 + pseudorpime = [5459, 5777, 10877, 16109, 18971, 22499, 24569, + 25199, 40309, 58519, 75077, 97439, 100127, 113573] + for n in pseudorpime: + assert is_strong_selfridge_prp(n) + + +def test_is_bpsw_prp(): + # invalid input + raises(ValueError, lambda: is_bpsw_prp(0)) + + # n = 1 + assert not is_bpsw_prp(1) + + # n is prime + assert is_bpsw_prp(2) + assert is_bpsw_prp(3) + assert is_bpsw_prp(11) + assert is_bpsw_prp(2**31-1) + + +def test_is_strong_bpsw_prp(): + # invalid input + raises(ValueError, lambda: is_strong_bpsw_prp(0)) + + # n = 1 + assert not is_strong_bpsw_prp(1) + + # n is prime + assert is_strong_bpsw_prp(2) + assert is_strong_bpsw_prp(3) + assert is_strong_bpsw_prp(11) + assert is_strong_bpsw_prp(2**31-1) diff --git a/.venv/lib/python3.13/site-packages/sympy/external/tests/test_numpy.py b/.venv/lib/python3.13/site-packages/sympy/external/tests/test_numpy.py new file mode 100644 index 0000000000000000000000000000000000000000..cd456d0d6cc49138c29d7ab28ee02694448d578f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/external/tests/test_numpy.py @@ -0,0 +1,335 @@ +# This testfile tests SymPy <-> NumPy compatibility + +# Don't test any SymPy features here. Just pure interaction with NumPy. +# Always write regular SymPy tests for anything, that can be tested in pure +# Python (without numpy). Here we test everything, that a user may need when +# using SymPy with NumPy +from sympy.external.importtools import version_tuple +from sympy.external import import_module + +numpy = import_module('numpy') +if numpy: + array, matrix, ndarray = numpy.array, numpy.matrix, numpy.ndarray +else: + #bin/test will not execute any tests now + disabled = True + + +from sympy.core.numbers import (Float, Integer, Rational) +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.trigonometric import sin +from sympy.matrices.dense import (Matrix, list2numpy, matrix2numpy, symarray) +from sympy.utilities.lambdify import lambdify +import sympy + +import mpmath +from sympy.abc import x, y, z +from sympy.utilities.decorator import conserve_mpmath_dps +from sympy.utilities.exceptions import ignore_warnings +from sympy.testing.pytest import raises + + +# first, systematically check, that all operations are implemented and don't +# raise an exception + + +def test_systematic_basic(): + def s(sympy_object, numpy_array): + _ = [sympy_object + numpy_array, + numpy_array + sympy_object, + sympy_object - numpy_array, + numpy_array - sympy_object, + sympy_object * numpy_array, + numpy_array * sympy_object, + sympy_object / numpy_array, + numpy_array / sympy_object, + sympy_object ** numpy_array, + numpy_array ** sympy_object] + x = Symbol("x") + y = Symbol("y") + sympy_objs = [ + Rational(2, 3), + Float("1.3"), + x, + y, + pow(x, y)*y, + Integer(5), + Float(5.5), + ] + numpy_objs = [ + array([1]), + array([3, 8, -1]), + array([x, x**2, Rational(5)]), + array([x/y*sin(y), 5, Rational(5)]), + ] + for x in sympy_objs: + for y in numpy_objs: + s(x, y) + + +# now some random tests, that test particular problems and that also +# check that the results of the operations are correct + +def test_basics(): + one = Rational(1) + zero = Rational(0) + assert array(1) == array(one) + assert array([one]) == array([one]) + assert array([x]) == array([x]) + assert array(x) == array(Symbol("x")) + assert array(one + x) == array(1 + x) + + X = array([one, zero, zero]) + assert (X == array([one, zero, zero])).all() + assert (X == array([one, 0, 0])).all() + + +def test_arrays(): + one = Rational(1) + zero = Rational(0) + X = array([one, zero, zero]) + Y = one*X + X = array([Symbol("a") + Rational(1, 2)]) + Y = X + X + assert Y == array([1 + 2*Symbol("a")]) + Y = Y + 1 + assert Y == array([2 + 2*Symbol("a")]) + Y = X - X + assert Y == array([0]) + + +def test_conversion1(): + a = list2numpy([x**2, x]) + #looks like an array? + assert isinstance(a, ndarray) + assert a[0] == x**2 + assert a[1] == x + assert len(a) == 2 + #yes, it's the array + + +def test_conversion2(): + a = 2*list2numpy([x**2, x]) + b = list2numpy([2*x**2, 2*x]) + assert (a == b).all() + + one = Rational(1) + zero = Rational(0) + X = list2numpy([one, zero, zero]) + Y = one*X + X = list2numpy([Symbol("a") + Rational(1, 2)]) + Y = X + X + assert Y == array([1 + 2*Symbol("a")]) + Y = Y + 1 + assert Y == array([2 + 2*Symbol("a")]) + Y = X - X + assert Y == array([0]) + + +def test_list2numpy(): + assert (array([x**2, x]) == list2numpy([x**2, x])).all() + + +def test_Matrix1(): + m = Matrix([[x, x**2], [5, 2/x]]) + assert (array(m.subs(x, 2)) == array([[2, 4], [5, 1]])).all() + m = Matrix([[sin(x), x**2], [5, 2/x]]) + assert (array(m.subs(x, 2)) == array([[sin(2), 4], [5, 1]])).all() + + +def test_Matrix2(): + m = Matrix([[x, x**2], [5, 2/x]]) + with ignore_warnings(PendingDeprecationWarning): + assert (matrix(m.subs(x, 2)) == matrix([[2, 4], [5, 1]])).all() + m = Matrix([[sin(x), x**2], [5, 2/x]]) + with ignore_warnings(PendingDeprecationWarning): + assert (matrix(m.subs(x, 2)) == matrix([[sin(2), 4], [5, 1]])).all() + + +def test_Matrix3(): + a = array([[2, 4], [5, 1]]) + assert Matrix(a) == Matrix([[2, 4], [5, 1]]) + assert Matrix(a) != Matrix([[2, 4], [5, 2]]) + a = array([[sin(2), 4], [5, 1]]) + assert Matrix(a) == Matrix([[sin(2), 4], [5, 1]]) + assert Matrix(a) != Matrix([[sin(0), 4], [5, 1]]) + + +def test_Matrix4(): + with ignore_warnings(PendingDeprecationWarning): + a = matrix([[2, 4], [5, 1]]) + assert Matrix(a) == Matrix([[2, 4], [5, 1]]) + assert Matrix(a) != Matrix([[2, 4], [5, 2]]) + with ignore_warnings(PendingDeprecationWarning): + a = matrix([[sin(2), 4], [5, 1]]) + assert Matrix(a) == Matrix([[sin(2), 4], [5, 1]]) + assert Matrix(a) != Matrix([[sin(0), 4], [5, 1]]) + + +def test_Matrix_sum(): + M = Matrix([[1, 2, 3], [x, y, x], [2*y, -50, z*x]]) + with ignore_warnings(PendingDeprecationWarning): + m = matrix([[2, 3, 4], [x, 5, 6], [x, y, z**2]]) + assert M + m == Matrix([[3, 5, 7], [2*x, y + 5, x + 6], [2*y + x, y - 50, z*x + z**2]]) + assert m + M == Matrix([[3, 5, 7], [2*x, y + 5, x + 6], [2*y + x, y - 50, z*x + z**2]]) + assert M + m == M.add(m) + + +def test_Matrix_mul(): + M = Matrix([[1, 2, 3], [x, y, x]]) + with ignore_warnings(PendingDeprecationWarning): + m = matrix([[2, 4], [x, 6], [x, z**2]]) + assert M*m == Matrix([ + [ 2 + 5*x, 16 + 3*z**2], + [2*x + x*y + x**2, 4*x + 6*y + x*z**2], + ]) + + assert m*M == Matrix([ + [ 2 + 4*x, 4 + 4*y, 6 + 4*x], + [ 7*x, 2*x + 6*y, 9*x], + [x + x*z**2, 2*x + y*z**2, 3*x + x*z**2], + ]) + a = array([2]) + assert a[0] * M == 2 * M + assert M * a[0] == 2 * M + + +def test_Matrix_array(): + class matarray: + def __array__(self, dtype=object, copy=None): + if copy is not None and not copy: + raise TypeError("Cannot implement copy=False when converting Matrix to ndarray") + from numpy import array + return array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + matarr = matarray() + assert Matrix(matarr) == Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + + +def test_matrix2numpy(): + a = matrix2numpy(Matrix([[1, x**2], [3*sin(x), 0]])) + assert isinstance(a, ndarray) + assert a.shape == (2, 2) + assert a[0, 0] == 1 + assert a[0, 1] == x**2 + assert a[1, 0] == 3*sin(x) + assert a[1, 1] == 0 + + +def test_matrix2numpy_conversion(): + a = Matrix([[1, 2, sin(x)], [x**2, x, Rational(1, 2)]]) + b = array([[1, 2, sin(x)], [x**2, x, Rational(1, 2)]]) + assert (matrix2numpy(a) == b).all() + assert matrix2numpy(a).dtype == numpy.dtype('object') + + c = matrix2numpy(Matrix([[1, 2], [10, 20]]), dtype='int8') + d = matrix2numpy(Matrix([[1, 2], [10, 20]]), dtype='float64') + assert c.dtype == numpy.dtype('int8') + assert d.dtype == numpy.dtype('float64') + + +def test_issue_3728(): + assert (Rational(1, 2)*array([2*x, 0]) == array([x, 0])).all() + assert (Rational(1, 2) + array( + [2*x, 0]) == array([2*x + Rational(1, 2), Rational(1, 2)])).all() + assert (Float("0.5")*array([2*x, 0]) == array([Float("1.0")*x, 0])).all() + assert (Float("0.5") + array( + [2*x, 0]) == array([2*x + Float("0.5"), Float("0.5")])).all() + + +@conserve_mpmath_dps +def test_lambdify(): + mpmath.mp.dps = 16 + sin02 = mpmath.mpf("0.198669330795061215459412627") + f = lambdify(x, sin(x), "numpy") + prec = 1e-15 + assert -prec < f(0.2) - sin02 < prec + + # if this succeeds, it can't be a numpy function + + if version_tuple(numpy.__version__) >= version_tuple('1.17'): + with raises(TypeError): + f(x) + else: + with raises(AttributeError): + f(x) + + +def test_lambdify_matrix(): + f = lambdify(x, Matrix([[x, 2*x], [1, 2]]), [{'ImmutableMatrix': numpy.array}, "numpy"]) + assert (f(1) == array([[1, 2], [1, 2]])).all() + + +def test_lambdify_matrix_multi_input(): + M = sympy.Matrix([[x**2, x*y, x*z], + [y*x, y**2, y*z], + [z*x, z*y, z**2]]) + f = lambdify((x, y, z), M, [{'ImmutableMatrix': numpy.array}, "numpy"]) + + xh, yh, zh = 1.0, 2.0, 3.0 + expected = array([[xh**2, xh*yh, xh*zh], + [yh*xh, yh**2, yh*zh], + [zh*xh, zh*yh, zh**2]]) + actual = f(xh, yh, zh) + assert numpy.allclose(actual, expected) + + +def test_lambdify_matrix_vec_input(): + X = sympy.DeferredVector('X') + M = Matrix([ + [X[0]**2, X[0]*X[1], X[0]*X[2]], + [X[1]*X[0], X[1]**2, X[1]*X[2]], + [X[2]*X[0], X[2]*X[1], X[2]**2]]) + f = lambdify(X, M, [{'ImmutableMatrix': numpy.array}, "numpy"]) + + Xh = array([1.0, 2.0, 3.0]) + expected = array([[Xh[0]**2, Xh[0]*Xh[1], Xh[0]*Xh[2]], + [Xh[1]*Xh[0], Xh[1]**2, Xh[1]*Xh[2]], + [Xh[2]*Xh[0], Xh[2]*Xh[1], Xh[2]**2]]) + actual = f(Xh) + assert numpy.allclose(actual, expected) + + +def test_lambdify_transl(): + from sympy.utilities.lambdify import NUMPY_TRANSLATIONS + for sym, mat in NUMPY_TRANSLATIONS.items(): + assert sym in sympy.__dict__ + assert mat in numpy.__dict__ + + +def test_symarray(): + """Test creation of numpy arrays of SymPy symbols.""" + + import numpy as np + import numpy.testing as npt + + syms = symbols('_0,_1,_2') + s1 = symarray("", 3) + s2 = symarray("", 3) + npt.assert_array_equal(s1, np.array(syms, dtype=object)) + assert s1[0] == s2[0] + + a = symarray('a', 3) + b = symarray('b', 3) + assert not(a[0] == b[0]) + + asyms = symbols('a_0,a_1,a_2') + npt.assert_array_equal(a, np.array(asyms, dtype=object)) + + # Multidimensional checks + a2d = symarray('a', (2, 3)) + assert a2d.shape == (2, 3) + a00, a12 = symbols('a_0_0,a_1_2') + assert a2d[0, 0] == a00 + assert a2d[1, 2] == a12 + + a3d = symarray('a', (2, 3, 2)) + assert a3d.shape == (2, 3, 2) + a000, a120, a121 = symbols('a_0_0_0,a_1_2_0,a_1_2_1') + assert a3d[0, 0, 0] == a000 + assert a3d[1, 2, 0] == a120 + assert a3d[1, 2, 1] == a121 + + +def test_vectorize(): + assert (numpy.vectorize( + sin)([1, 2, 3]) == numpy.array([sin(1), sin(2), sin(3)])).all() diff --git a/.venv/lib/python3.13/site-packages/sympy/external/tests/test_pythonmpq.py b/.venv/lib/python3.13/site-packages/sympy/external/tests/test_pythonmpq.py new file mode 100644 index 0000000000000000000000000000000000000000..137cfdf5c858544f0811ae666f000cfb368787a0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/external/tests/test_pythonmpq.py @@ -0,0 +1,176 @@ +""" +test_pythonmpq.py + +Test the PythonMPQ class for consistency with gmpy2's mpq type. If gmpy2 is +installed run the same tests for both. +""" +from fractions import Fraction +from decimal import Decimal +import pickle +from typing import Callable, List, Tuple, Type + +from sympy.testing.pytest import raises + +from sympy.external.pythonmpq import PythonMPQ + +# +# If gmpy2 is installed then run the tests for both mpq and PythonMPQ. +# That should ensure consistency between the implementation here and mpq. +# +rational_types: List[Tuple[Callable, Type, Callable, Type]] +rational_types = [(PythonMPQ, PythonMPQ, int, int)] +try: + from gmpy2 import mpq, mpz + rational_types.append((mpq, type(mpq(1)), mpz, type(mpz(1)))) +except ImportError: + pass + + +def test_PythonMPQ(): + # + # Test PythonMPQ and also mpq if gmpy/gmpy2 is installed. + # + for Q, TQ, Z, TZ in rational_types: + + def check_Q(q): + assert isinstance(q, TQ) + assert isinstance(q.numerator, TZ) + assert isinstance(q.denominator, TZ) + return q.numerator, q.denominator + + # Check construction from different types + assert check_Q(Q(3)) == (3, 1) + assert check_Q(Q(3, 5)) == (3, 5) + assert check_Q(Q(Q(3, 5))) == (3, 5) + assert check_Q(Q(0.5)) == (1, 2) + assert check_Q(Q('0.5')) == (1, 2) + assert check_Q(Q(Fraction(3, 5))) == (3, 5) + + # https://github.com/aleaxit/gmpy/issues/327 + if Q is PythonMPQ: + assert check_Q(Q(Decimal('0.6'))) == (3, 5) + + # Invalid types + raises(TypeError, lambda: Q([])) + raises(TypeError, lambda: Q([], [])) + + # Check normalisation of signs + assert check_Q(Q(2, 3)) == (2, 3) + assert check_Q(Q(-2, 3)) == (-2, 3) + assert check_Q(Q(2, -3)) == (-2, 3) + assert check_Q(Q(-2, -3)) == (2, 3) + + # Check gcd calculation + assert check_Q(Q(12, 8)) == (3, 2) + + # __int__/__float__ + assert int(Q(5, 3)) == 1 + assert int(Q(-5, 3)) == -1 + assert float(Q(5, 2)) == 2.5 + assert float(Q(-5, 2)) == -2.5 + + # __str__/__repr__ + assert str(Q(2, 1)) == "2" + assert str(Q(1, 2)) == "1/2" + if Q is PythonMPQ: + assert repr(Q(2, 1)) == "MPQ(2,1)" + assert repr(Q(1, 2)) == "MPQ(1,2)" + else: + assert repr(Q(2, 1)) == "mpq(2,1)" + assert repr(Q(1, 2)) == "mpq(1,2)" + + # __bool__ + assert bool(Q(1, 2)) is True + assert bool(Q(0)) is False + + # __eq__/__ne__ + assert (Q(2, 3) == Q(2, 3)) is True + assert (Q(2, 3) == Q(2, 5)) is False + assert (Q(2, 3) != Q(2, 3)) is False + assert (Q(2, 3) != Q(2, 5)) is True + + # __hash__ + assert hash(Q(3, 5)) == hash(Fraction(3, 5)) + + # __reduce__ + q = Q(2, 3) + assert pickle.loads(pickle.dumps(q)) == q + + # __ge__/__gt__/__le__/__lt__ + assert (Q(1, 3) < Q(2, 3)) is True + assert (Q(2, 3) < Q(2, 3)) is False + assert (Q(2, 3) < Q(1, 3)) is False + assert (Q(-2, 3) < Q(1, 3)) is True + assert (Q(1, 3) < Q(-2, 3)) is False + + assert (Q(1, 3) <= Q(2, 3)) is True + assert (Q(2, 3) <= Q(2, 3)) is True + assert (Q(2, 3) <= Q(1, 3)) is False + assert (Q(-2, 3) <= Q(1, 3)) is True + assert (Q(1, 3) <= Q(-2, 3)) is False + + assert (Q(1, 3) > Q(2, 3)) is False + assert (Q(2, 3) > Q(2, 3)) is False + assert (Q(2, 3) > Q(1, 3)) is True + assert (Q(-2, 3) > Q(1, 3)) is False + assert (Q(1, 3) > Q(-2, 3)) is True + + assert (Q(1, 3) >= Q(2, 3)) is False + assert (Q(2, 3) >= Q(2, 3)) is True + assert (Q(2, 3) >= Q(1, 3)) is True + assert (Q(-2, 3) >= Q(1, 3)) is False + assert (Q(1, 3) >= Q(-2, 3)) is True + + # __abs__/__pos__/__neg__ + assert abs(Q(2, 3)) == abs(Q(-2, 3)) == Q(2, 3) + assert +Q(2, 3) == Q(2, 3) + assert -Q(2, 3) == Q(-2, 3) + + # __add__/__radd__ + assert Q(2, 3) + Q(5, 7) == Q(29, 21) + assert Q(2, 3) + 1 == Q(5, 3) + assert 1 + Q(2, 3) == Q(5, 3) + raises(TypeError, lambda: [] + Q(1)) + raises(TypeError, lambda: Q(1) + []) + + # __sub__/__rsub__ + assert Q(2, 3) - Q(5, 7) == Q(-1, 21) + assert Q(2, 3) - 1 == Q(-1, 3) + assert 1 - Q(2, 3) == Q(1, 3) + raises(TypeError, lambda: [] - Q(1)) + raises(TypeError, lambda: Q(1) - []) + + # __mul__/__rmul__ + assert Q(2, 3) * Q(5, 7) == Q(10, 21) + assert Q(2, 3) * 1 == Q(2, 3) + assert 1 * Q(2, 3) == Q(2, 3) + raises(TypeError, lambda: [] * Q(1)) + raises(TypeError, lambda: Q(1) * []) + + # __pow__/__rpow__ + assert Q(2, 3) ** 2 == Q(4, 9) + assert Q(2, 3) ** 1 == Q(2, 3) + assert Q(-2, 3) ** 2 == Q(4, 9) + assert Q(-2, 3) ** -1 == Q(-3, 2) + if Q is PythonMPQ: + raises(TypeError, lambda: 1 ** Q(2, 3)) + raises(TypeError, lambda: Q(1, 4) ** Q(1, 2)) + raises(TypeError, lambda: [] ** Q(1)) + raises(TypeError, lambda: Q(1) ** []) + + # __div__/__rdiv__ + assert Q(2, 3) / Q(5, 7) == Q(14, 15) + assert Q(2, 3) / 1 == Q(2, 3) + assert 1 / Q(2, 3) == Q(3, 2) + raises(TypeError, lambda: [] / Q(1)) + raises(TypeError, lambda: Q(1) / []) + raises(ZeroDivisionError, lambda: Q(1, 2) / Q(0)) + + # __divmod__ + if Q is PythonMPQ: + raises(TypeError, lambda: Q(2, 3) // Q(1, 3)) + raises(TypeError, lambda: Q(2, 3) % Q(1, 3)) + raises(TypeError, lambda: 1 // Q(1, 3)) + raises(TypeError, lambda: 1 % Q(1, 3)) + raises(TypeError, lambda: Q(2, 3) // 1) + raises(TypeError, lambda: Q(2, 3) % 1) diff --git a/.venv/lib/python3.13/site-packages/sympy/external/tests/test_scipy.py b/.venv/lib/python3.13/site-packages/sympy/external/tests/test_scipy.py new file mode 100644 index 0000000000000000000000000000000000000000..3746d1a311eb68bb1af16e18ab152c7236b42bb5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/external/tests/test_scipy.py @@ -0,0 +1,35 @@ +# This testfile tests SymPy <-> SciPy compatibility + +# Don't test any SymPy features here. Just pure interaction with SciPy. +# Always write regular SymPy tests for anything, that can be tested in pure +# Python (without scipy). Here we test everything, that a user may need when +# using SymPy with SciPy + +from sympy.external import import_module + +scipy = import_module('scipy') +if not scipy: + #bin/test will not execute any tests now + disabled = True + +from sympy.functions.special.bessel import jn_zeros + + +def eq(a, b, tol=1e-6): + for x, y in zip(a, b): + if not (abs(x - y) < tol): + return False + return True + + +def test_jn_zeros(): + assert eq(jn_zeros(0, 4, method="scipy"), + [3.141592, 6.283185, 9.424777, 12.566370]) + assert eq(jn_zeros(1, 4, method="scipy"), + [4.493409, 7.725251, 10.904121, 14.066193]) + assert eq(jn_zeros(2, 4, method="scipy"), + [5.763459, 9.095011, 12.322940, 15.514603]) + assert eq(jn_zeros(3, 4, method="scipy"), + [6.987932, 10.417118, 13.698023, 16.923621]) + assert eq(jn_zeros(4, 4, method="scipy"), + [8.182561, 11.704907, 15.039664, 18.301255]) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/__init__.py b/.venv/lib/python3.13/site-packages/sympy/functions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed93b2a11754aa26af5eef3932d177374b3ddfd6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/__init__.py @@ -0,0 +1,115 @@ +"""A functions module, includes all the standard functions. + +Combinatorial - factorial, fibonacci, harmonic, bernoulli... +Elementary - hyperbolic, trigonometric, exponential, floor and ceiling, sqrt... +Special - gamma, zeta,spherical harmonics... +""" + +from sympy.functions.combinatorial.factorials import (factorial, factorial2, + rf, ff, binomial, RisingFactorial, FallingFactorial, subfactorial) +from sympy.functions.combinatorial.numbers import (carmichael, fibonacci, lucas, tribonacci, + harmonic, bernoulli, bell, euler, catalan, genocchi, andre, partition, divisor_sigma, + udivisor_sigma, legendre_symbol, jacobi_symbol, kronecker_symbol, mobius, + primenu, primeomega, totient, reduced_totient, primepi, motzkin) +from sympy.functions.elementary.miscellaneous import (sqrt, root, Min, Max, + Id, real_root, cbrt, Rem) +from sympy.functions.elementary.complexes import (re, im, sign, Abs, + conjugate, arg, polar_lift, periodic_argument, unbranched_argument, + principal_branch, transpose, adjoint, polarify, unpolarify) +from sympy.functions.elementary.trigonometric import (sin, cos, tan, + sec, csc, cot, sinc, asin, acos, atan, asec, acsc, acot, atan2) +from sympy.functions.elementary.exponential import (exp_polar, exp, log, + LambertW) +from sympy.functions.elementary.hyperbolic import (sinh, cosh, tanh, coth, + sech, csch, asinh, acosh, atanh, acoth, asech, acsch) +from sympy.functions.elementary.integers import floor, ceiling, frac +from sympy.functions.elementary.piecewise import (Piecewise, piecewise_fold, + piecewise_exclusive) +from sympy.functions.special.error_functions import (erf, erfc, erfi, erf2, + erfinv, erfcinv, erf2inv, Ei, expint, E1, li, Li, Si, Ci, Shi, Chi, + fresnels, fresnelc) +from sympy.functions.special.gamma_functions import (gamma, lowergamma, + uppergamma, polygamma, loggamma, digamma, trigamma, multigamma) +from sympy.functions.special.zeta_functions import (dirichlet_eta, zeta, + lerchphi, polylog, stieltjes, riemann_xi) +from sympy.functions.special.tensor_functions import (Eijk, LeviCivita, + KroneckerDelta) +from sympy.functions.special.singularity_functions import SingularityFunction +from sympy.functions.special.delta_functions import DiracDelta, Heaviside +from sympy.functions.special.bsplines import bspline_basis, bspline_basis_set, interpolating_spline +from sympy.functions.special.bessel import (besselj, bessely, besseli, besselk, + hankel1, hankel2, jn, yn, jn_zeros, hn1, hn2, airyai, airybi, airyaiprime, airybiprime, marcumq) +from sympy.functions.special.hyper import hyper, meijerg, appellf1 +from sympy.functions.special.polynomials import (legendre, assoc_legendre, + hermite, hermite_prob, chebyshevt, chebyshevu, chebyshevu_root, + chebyshevt_root, laguerre, assoc_laguerre, gegenbauer, jacobi, jacobi_normalized) +from sympy.functions.special.spherical_harmonics import Ynm, Ynm_c, Znm +from sympy.functions.special.elliptic_integrals import (elliptic_k, + elliptic_f, elliptic_e, elliptic_pi) +from sympy.functions.special.beta_functions import beta, betainc, betainc_regularized +from sympy.functions.special.mathieu_functions import (mathieus, mathieuc, + mathieusprime, mathieucprime) +ln = log + +__all__ = [ + 'factorial', 'factorial2', 'rf', 'ff', 'binomial', 'RisingFactorial', + 'FallingFactorial', 'subfactorial', + + 'carmichael', 'fibonacci', 'lucas', 'motzkin', 'tribonacci', 'harmonic', + 'bernoulli', 'bell', 'euler', 'catalan', 'genocchi', 'andre', 'partition', + 'divisor_sigma', 'udivisor_sigma', 'legendre_symbol', 'jacobi_symbol', 'kronecker_symbol', + 'mobius', 'primenu', 'primeomega', 'totient', 'reduced_totient', 'primepi', + + 'sqrt', 'root', 'Min', 'Max', 'Id', 'real_root', 'cbrt', 'Rem', + + 're', 'im', 'sign', 'Abs', 'conjugate', 'arg', 'polar_lift', + 'periodic_argument', 'unbranched_argument', 'principal_branch', + 'transpose', 'adjoint', 'polarify', 'unpolarify', + + 'sin', 'cos', 'tan', 'sec', 'csc', 'cot', 'sinc', 'asin', 'acos', 'atan', + 'asec', 'acsc', 'acot', 'atan2', + + 'exp_polar', 'exp', 'ln', 'log', 'LambertW', + + 'sinh', 'cosh', 'tanh', 'coth', 'sech', 'csch', 'asinh', 'acosh', 'atanh', + 'acoth', 'asech', 'acsch', + + 'floor', 'ceiling', 'frac', + + 'Piecewise', 'piecewise_fold', 'piecewise_exclusive', + + 'erf', 'erfc', 'erfi', 'erf2', 'erfinv', 'erfcinv', 'erf2inv', 'Ei', + 'expint', 'E1', 'li', 'Li', 'Si', 'Ci', 'Shi', 'Chi', 'fresnels', + 'fresnelc', + + 'gamma', 'lowergamma', 'uppergamma', 'polygamma', 'loggamma', 'digamma', + 'trigamma', 'multigamma', + + 'dirichlet_eta', 'zeta', 'lerchphi', 'polylog', 'stieltjes', 'riemann_xi', + + 'Eijk', 'LeviCivita', 'KroneckerDelta', + + 'SingularityFunction', + + 'DiracDelta', 'Heaviside', + + 'bspline_basis', 'bspline_basis_set', 'interpolating_spline', + + 'besselj', 'bessely', 'besseli', 'besselk', 'hankel1', 'hankel2', 'jn', + 'yn', 'jn_zeros', 'hn1', 'hn2', 'airyai', 'airybi', 'airyaiprime', + 'airybiprime', 'marcumq', + + 'hyper', 'meijerg', 'appellf1', + + 'legendre', 'assoc_legendre', 'hermite', 'hermite_prob', 'chebyshevt', + 'chebyshevu', 'chebyshevu_root', 'chebyshevt_root', 'laguerre', + 'assoc_laguerre', 'gegenbauer', 'jacobi', 'jacobi_normalized', + + 'Ynm', 'Ynm_c', 'Znm', + + 'elliptic_k', 'elliptic_f', 'elliptic_e', 'elliptic_pi', + + 'beta', 'betainc', 'betainc_regularized', + + 'mathieus', 'mathieuc', 'mathieusprime', 'mathieucprime', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/geometry/__init__.py b/.venv/lib/python3.13/site-packages/sympy/geometry/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb85d4ff5d53eb44a039a95cfc2fff687322cc76 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/geometry/__init__.py @@ -0,0 +1,45 @@ +""" +A geometry module for the SymPy library. This module contains all of the +entities and functions needed to construct basic geometrical data and to +perform simple informational queries. + +Usage: +====== + +Examples +======== + +""" +from sympy.geometry.point import Point, Point2D, Point3D +from sympy.geometry.line import Line, Ray, Segment, Line2D, Segment2D, Ray2D, \ + Line3D, Segment3D, Ray3D +from sympy.geometry.plane import Plane +from sympy.geometry.ellipse import Ellipse, Circle +from sympy.geometry.polygon import Polygon, RegularPolygon, Triangle, rad, deg +from sympy.geometry.util import are_similar, centroid, convex_hull, idiff, \ + intersection, closest_points, farthest_points +from sympy.geometry.exceptions import GeometryError +from sympy.geometry.curve import Curve +from sympy.geometry.parabola import Parabola + +__all__ = [ + 'Point', 'Point2D', 'Point3D', + + 'Line', 'Ray', 'Segment', 'Line2D', 'Segment2D', 'Ray2D', 'Line3D', + 'Segment3D', 'Ray3D', + + 'Plane', + + 'Ellipse', 'Circle', + + 'Polygon', 'RegularPolygon', 'Triangle', 'rad', 'deg', + + 'are_similar', 'centroid', 'convex_hull', 'idiff', 'intersection', + 'closest_points', 'farthest_points', + + 'GeometryError', + + 'Curve', + + 'Parabola', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/geometry/curve.py b/.venv/lib/python3.13/site-packages/sympy/geometry/curve.py new file mode 100644 index 0000000000000000000000000000000000000000..c074f22cad79b1261ad44be4ccface972cdd3b82 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/geometry/curve.py @@ -0,0 +1,424 @@ +"""Curves in 2-dimensional Euclidean space. + +Contains +======== +Curve + +""" + +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.core import diff +from sympy.core.containers import Tuple +from sympy.core.symbol import _symbol +from sympy.geometry.entity import GeometryEntity, GeometrySet +from sympy.geometry.point import Point +from sympy.integrals import integrate +from sympy.matrices import Matrix, rot_axis3 +from sympy.utilities.iterables import is_sequence + +from mpmath.libmp.libmpf import prec_to_dps + + +class Curve(GeometrySet): + """A curve in space. + + A curve is defined by parametric functions for the coordinates, a + parameter and the lower and upper bounds for the parameter value. + + Parameters + ========== + + function : list of functions + limits : 3-tuple + Function parameter and lower and upper bounds. + + Attributes + ========== + + functions + parameter + limits + + Raises + ====== + + ValueError + When `functions` are specified incorrectly. + When `limits` are specified incorrectly. + + Examples + ======== + + >>> from sympy import Curve, sin, cos, interpolate + >>> from sympy.abc import t, a + >>> C = Curve((sin(t), cos(t)), (t, 0, 2)) + >>> C.functions + (sin(t), cos(t)) + >>> C.limits + (t, 0, 2) + >>> C.parameter + t + >>> C = Curve((t, interpolate([1, 4, 9, 16], t)), (t, 0, 1)); C + Curve((t, t**2), (t, 0, 1)) + >>> C.subs(t, 4) + Point2D(4, 16) + >>> C.arbitrary_point(a) + Point2D(a, a**2) + + See Also + ======== + + sympy.core.function.Function + sympy.polys.polyfuncs.interpolate + + """ + + def __new__(cls, function, limits): + if not is_sequence(function) or len(function) != 2: + raise ValueError("Function argument should be (x(t), y(t)) " + "but got %s" % str(function)) + if not is_sequence(limits) or len(limits) != 3: + raise ValueError("Limit argument should be (t, tmin, tmax) " + "but got %s" % str(limits)) + + return GeometryEntity.__new__(cls, Tuple(*function), Tuple(*limits)) + + def __call__(self, f): + return self.subs(self.parameter, f) + + def _eval_subs(self, old, new): + if old == self.parameter: + return Point(*[f.subs(old, new) for f in self.functions]) + + def _eval_evalf(self, prec=15, **options): + f, (t, a, b) = self.args + dps = prec_to_dps(prec) + f = tuple([i.evalf(n=dps, **options) for i in f]) + a, b = [i.evalf(n=dps, **options) for i in (a, b)] + return self.func(f, (t, a, b)) + + def arbitrary_point(self, parameter='t'): + """A parameterized point on the curve. + + Parameters + ========== + + parameter : str or Symbol, optional + Default value is 't'. + The Curve's parameter is selected with None or self.parameter + otherwise the provided symbol is used. + + Returns + ======= + + Point : + Returns a point in parametric form. + + Raises + ====== + + ValueError + When `parameter` already appears in the functions. + + Examples + ======== + + >>> from sympy import Curve, Symbol + >>> from sympy.abc import s + >>> C = Curve([2*s, s**2], (s, 0, 2)) + >>> C.arbitrary_point() + Point2D(2*t, t**2) + >>> C.arbitrary_point(C.parameter) + Point2D(2*s, s**2) + >>> C.arbitrary_point(None) + Point2D(2*s, s**2) + >>> C.arbitrary_point(Symbol('a')) + Point2D(2*a, a**2) + + See Also + ======== + + sympy.geometry.point.Point + + """ + if parameter is None: + return Point(*self.functions) + + tnew = _symbol(parameter, self.parameter, real=True) + t = self.parameter + if (tnew.name != t.name and + tnew.name in (f.name for f in self.free_symbols)): + raise ValueError('Symbol %s already appears in object ' + 'and cannot be used as a parameter.' % tnew.name) + return Point(*[w.subs(t, tnew) for w in self.functions]) + + @property + def free_symbols(self): + """Return a set of symbols other than the bound symbols used to + parametrically define the Curve. + + Returns + ======= + + set : + Set of all non-parameterized symbols. + + Examples + ======== + + >>> from sympy.abc import t, a + >>> from sympy import Curve + >>> Curve((t, t**2), (t, 0, 2)).free_symbols + set() + >>> Curve((t, t**2), (t, a, 2)).free_symbols + {a} + + """ + free = set() + for a in self.functions + self.limits[1:]: + free |= a.free_symbols + free = free.difference({self.parameter}) + return free + + @property + def ambient_dimension(self): + """The dimension of the curve. + + Returns + ======= + + int : + the dimension of curve. + + Examples + ======== + + >>> from sympy.abc import t + >>> from sympy import Curve + >>> C = Curve((t, t**2), (t, 0, 2)) + >>> C.ambient_dimension + 2 + + """ + + return len(self.args[0]) + + @property + def functions(self): + """The functions specifying the curve. + + Returns + ======= + + functions : + list of parameterized coordinate functions. + + Examples + ======== + + >>> from sympy.abc import t + >>> from sympy import Curve + >>> C = Curve((t, t**2), (t, 0, 2)) + >>> C.functions + (t, t**2) + + See Also + ======== + + parameter + + """ + return self.args[0] + + @property + def limits(self): + """The limits for the curve. + + Returns + ======= + + limits : tuple + Contains parameter and lower and upper limits. + + Examples + ======== + + >>> from sympy.abc import t + >>> from sympy import Curve + >>> C = Curve([t, t**3], (t, -2, 2)) + >>> C.limits + (t, -2, 2) + + See Also + ======== + + plot_interval + + """ + return self.args[1] + + @property + def parameter(self): + """The curve function variable. + + Returns + ======= + + Symbol : + returns a bound symbol. + + Examples + ======== + + >>> from sympy.abc import t + >>> from sympy import Curve + >>> C = Curve([t, t**2], (t, 0, 2)) + >>> C.parameter + t + + See Also + ======== + + functions + + """ + return self.args[1][0] + + @property + def length(self): + """The curve length. + + Examples + ======== + + >>> from sympy import Curve + >>> from sympy.abc import t + >>> Curve((t, t), (t, 0, 1)).length + sqrt(2) + + """ + integrand = sqrt(sum(diff(func, self.limits[0])**2 for func in self.functions)) + return integrate(integrand, self.limits) + + def plot_interval(self, parameter='t'): + """The plot interval for the default geometric plot of the curve. + + Parameters + ========== + + parameter : str or Symbol, optional + Default value is 't'; + otherwise the provided symbol is used. + + Returns + ======= + + List : + the plot interval as below: + [parameter, lower_bound, upper_bound] + + Examples + ======== + + >>> from sympy import Curve, sin + >>> from sympy.abc import x, s + >>> Curve((x, sin(x)), (x, 1, 2)).plot_interval() + [t, 1, 2] + >>> Curve((x, sin(x)), (x, 1, 2)).plot_interval(s) + [s, 1, 2] + + See Also + ======== + + limits : Returns limits of the parameter interval + + """ + t = _symbol(parameter, self.parameter, real=True) + return [t] + list(self.limits[1:]) + + def rotate(self, angle=0, pt=None): + """This function is used to rotate a curve along given point ``pt`` at given angle(in radian). + + Parameters + ========== + + angle : + the angle at which the curve will be rotated(in radian) in counterclockwise direction. + default value of angle is 0. + + pt : Point + the point along which the curve will be rotated. + If no point given, the curve will be rotated around origin. + + Returns + ======= + + Curve : + returns a curve rotated at given angle along given point. + + Examples + ======== + + >>> from sympy import Curve, pi + >>> from sympy.abc import x + >>> Curve((x, x), (x, 0, 1)).rotate(pi/2) + Curve((-x, x), (x, 0, 1)) + + """ + if pt: + pt = -Point(pt, dim=2) + else: + pt = Point(0,0) + rv = self.translate(*pt.args) + f = list(rv.functions) + f.append(0) + f = Matrix(1, 3, f) + f *= rot_axis3(angle) + rv = self.func(f[0, :2].tolist()[0], self.limits) + pt = -pt + return rv.translate(*pt.args) + + def scale(self, x=1, y=1, pt=None): + """Override GeometryEntity.scale since Curve is not made up of Points. + + Returns + ======= + + Curve : + returns scaled curve. + + Examples + ======== + + >>> from sympy import Curve + >>> from sympy.abc import x + >>> Curve((x, x), (x, 0, 1)).scale(2) + Curve((2*x, x), (x, 0, 1)) + + """ + if pt: + pt = Point(pt, dim=2) + return self.translate(*(-pt).args).scale(x, y).translate(*pt.args) + fx, fy = self.functions + return self.func((fx*x, fy*y), self.limits) + + def translate(self, x=0, y=0): + """Translate the Curve by (x, y). + + Returns + ======= + + Curve : + returns a translated curve. + + Examples + ======== + + >>> from sympy import Curve + >>> from sympy.abc import x + >>> Curve((x, x), (x, 0, 1)).translate(1, 2) + Curve((x + 1, x + 2), (x, 0, 1)) + + """ + fx, fy = self.functions + return self.func((fx + x, fy + y), self.limits) diff --git a/.venv/lib/python3.13/site-packages/sympy/geometry/ellipse.py b/.venv/lib/python3.13/site-packages/sympy/geometry/ellipse.py new file mode 100644 index 0000000000000000000000000000000000000000..199db25fde9b019893a275d69959154990e8a4a7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/geometry/ellipse.py @@ -0,0 +1,1768 @@ +"""Elliptical geometrical entities. + +Contains +* Ellipse +* Circle + +""" + +from sympy.core.expr import Expr +from sympy.core.relational import Eq +from sympy.core import S, pi, sympify +from sympy.core.evalf import N +from sympy.core.parameters import global_parameters +from sympy.core.logic import fuzzy_bool +from sympy.core.numbers import Rational, oo +from sympy.core.sorting import ordered +from sympy.core.symbol import Dummy, uniquely_named_symbol, _symbol +from sympy.simplify.simplify import simplify +from sympy.simplify.trigsimp import trigsimp +from sympy.functions.elementary.miscellaneous import sqrt, Max +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.functions.special.elliptic_integrals import elliptic_e +from .entity import GeometryEntity, GeometrySet +from .exceptions import GeometryError +from .line import Line, Segment, Ray2D, Segment2D, Line2D, LinearEntity3D +from .point import Point, Point2D, Point3D +from .util import idiff, find +from sympy.polys import DomainError, Poly, PolynomialError +from sympy.polys.polyutils import _not_a_coeff, _nsort +from sympy.solvers import solve +from sympy.solvers.solveset import linear_coeffs +from sympy.utilities.misc import filldedent, func_name + +from mpmath.libmp.libmpf import prec_to_dps + +import random + +x, y = [Dummy('ellipse_dummy', real=True) for i in range(2)] + + +class Ellipse(GeometrySet): + """An elliptical GeometryEntity. + + Parameters + ========== + + center : Point, optional + Default value is Point(0, 0) + hradius : number or SymPy expression, optional + vradius : number or SymPy expression, optional + eccentricity : number or SymPy expression, optional + Two of `hradius`, `vradius` and `eccentricity` must be supplied to + create an Ellipse. The third is derived from the two supplied. + + Attributes + ========== + + center + hradius + vradius + area + circumference + eccentricity + periapsis + apoapsis + focus_distance + foci + + Raises + ====== + + GeometryError + When `hradius`, `vradius` and `eccentricity` are incorrectly supplied + as parameters. + TypeError + When `center` is not a Point. + + See Also + ======== + + Circle + + Notes + ----- + Constructed from a center and two radii, the first being the horizontal + radius (along the x-axis) and the second being the vertical radius (along + the y-axis). + + When symbolic value for hradius and vradius are used, any calculation that + refers to the foci or the major or minor axis will assume that the ellipse + has its major radius on the x-axis. If this is not true then a manual + rotation is necessary. + + Examples + ======== + + >>> from sympy import Ellipse, Point, Rational + >>> e1 = Ellipse(Point(0, 0), 5, 1) + >>> e1.hradius, e1.vradius + (5, 1) + >>> e2 = Ellipse(Point(3, 1), hradius=3, eccentricity=Rational(4, 5)) + >>> e2 + Ellipse(Point2D(3, 1), 3, 9/5) + + """ + + def __contains__(self, o): + if isinstance(o, Point): + res = self.equation(x, y).subs({x: o.x, y: o.y}) + return trigsimp(simplify(res)) is S.Zero + elif isinstance(o, Ellipse): + return self == o + return False + + def __eq__(self, o): + """Is the other GeometryEntity the same as this ellipse?""" + return isinstance(o, Ellipse) and (self.center == o.center and + self.hradius == o.hradius and + self.vradius == o.vradius) + + def __hash__(self): + return super().__hash__() + + def __new__( + cls, center=None, hradius=None, vradius=None, eccentricity=None, **kwargs): + + hradius = sympify(hradius) + vradius = sympify(vradius) + + if center is None: + center = Point(0, 0) + else: + if len(center) != 2: + raise ValueError('The center of "{}" must be a two dimensional point'.format(cls)) + center = Point(center, dim=2) + + if len(list(filter(lambda x: x is not None, (hradius, vradius, eccentricity)))) != 2: + raise ValueError(filldedent(''' + Exactly two arguments of "hradius", "vradius", and + "eccentricity" must not be None.''')) + + if eccentricity is not None: + eccentricity = sympify(eccentricity) + if eccentricity.is_negative: + raise GeometryError("Eccentricity of ellipse/circle should lie between [0, 1)") + elif hradius is None: + hradius = vradius / sqrt(1 - eccentricity**2) + elif vradius is None: + vradius = hradius * sqrt(1 - eccentricity**2) + + if hradius == vradius: + return Circle(center, hradius, **kwargs) + + if S.Zero in (hradius, vradius): + return Segment(Point(center[0] - hradius, center[1] - vradius), Point(center[0] + hradius, center[1] + vradius)) + + if hradius.is_real is False or vradius.is_real is False: + raise GeometryError("Invalid value encountered when computing hradius / vradius.") + + return GeometryEntity.__new__(cls, center, hradius, vradius, **kwargs) + + def _svg(self, scale_factor=1., fill_color="#66cc99"): + """Returns SVG ellipse element for the Ellipse. + + Parameters + ========== + + scale_factor : float + Multiplication factor for the SVG stroke-width. Default is 1. + fill_color : str, optional + Hex string for fill color. Default is "#66cc99". + """ + + c = N(self.center) + h, v = N(self.hradius), N(self.vradius) + return ( + '' + ).format(2. * scale_factor, fill_color, c.x, c.y, h, v) + + @property + def ambient_dimension(self): + return 2 + + @property + def apoapsis(self): + """The apoapsis of the ellipse. + + The greatest distance between the focus and the contour. + + Returns + ======= + + apoapsis : number + + See Also + ======== + + periapsis : Returns shortest distance between foci and contour + + Examples + ======== + + >>> from sympy import Point, Ellipse + >>> p1 = Point(0, 0) + >>> e1 = Ellipse(p1, 3, 1) + >>> e1.apoapsis + 2*sqrt(2) + 3 + + """ + return self.major * (1 + self.eccentricity) + + def arbitrary_point(self, parameter='t'): + """A parameterized point on the ellipse. + + Parameters + ========== + + parameter : str, optional + Default value is 't'. + + Returns + ======= + + arbitrary_point : Point + + Raises + ====== + + ValueError + When `parameter` already appears in the functions. + + See Also + ======== + + sympy.geometry.point.Point + + Examples + ======== + + >>> from sympy import Point, Ellipse + >>> e1 = Ellipse(Point(0, 0), 3, 2) + >>> e1.arbitrary_point() + Point2D(3*cos(t), 2*sin(t)) + + """ + t = _symbol(parameter, real=True) + if t.name in (f.name for f in self.free_symbols): + raise ValueError(filldedent('Symbol %s already appears in object ' + 'and cannot be used as a parameter.' % t.name)) + return Point(self.center.x + self.hradius*cos(t), + self.center.y + self.vradius*sin(t)) + + @property + def area(self): + """The area of the ellipse. + + Returns + ======= + + area : number + + Examples + ======== + + >>> from sympy import Point, Ellipse + >>> p1 = Point(0, 0) + >>> e1 = Ellipse(p1, 3, 1) + >>> e1.area + 3*pi + + """ + return simplify(S.Pi * self.hradius * self.vradius) + + @property + def bounds(self): + """Return a tuple (xmin, ymin, xmax, ymax) representing the bounding + rectangle for the geometric figure. + + """ + + h, v = self.hradius, self.vradius + return (self.center.x - h, self.center.y - v, self.center.x + h, self.center.y + v) + + @property + def center(self): + """The center of the ellipse. + + Returns + ======= + + center : number + + See Also + ======== + + sympy.geometry.point.Point + + Examples + ======== + + >>> from sympy import Point, Ellipse + >>> p1 = Point(0, 0) + >>> e1 = Ellipse(p1, 3, 1) + >>> e1.center + Point2D(0, 0) + + """ + return self.args[0] + + @property + def circumference(self): + """The circumference of the ellipse. + + Examples + ======== + + >>> from sympy import Point, Ellipse + >>> p1 = Point(0, 0) + >>> e1 = Ellipse(p1, 3, 1) + >>> e1.circumference + 12*elliptic_e(8/9) + + """ + if self.eccentricity == 1: + # degenerate + return 4*self.major + elif self.eccentricity == 0: + # circle + return 2*pi*self.hradius + else: + return 4*self.major*elliptic_e(self.eccentricity**2) + + @property + def eccentricity(self): + """The eccentricity of the ellipse. + + Returns + ======= + + eccentricity : number + + Examples + ======== + + >>> from sympy import Point, Ellipse, sqrt + >>> p1 = Point(0, 0) + >>> e1 = Ellipse(p1, 3, sqrt(2)) + >>> e1.eccentricity + sqrt(7)/3 + + """ + return self.focus_distance / self.major + + def encloses_point(self, p): + """ + Return True if p is enclosed by (is inside of) self. + + Notes + ----- + Being on the border of self is considered False. + + Parameters + ========== + + p : Point + + Returns + ======= + + encloses_point : True, False or None + + See Also + ======== + + sympy.geometry.point.Point + + Examples + ======== + + >>> from sympy import Ellipse, S + >>> from sympy.abc import t + >>> e = Ellipse((0, 0), 3, 2) + >>> e.encloses_point((0, 0)) + True + >>> e.encloses_point(e.arbitrary_point(t).subs(t, S.Half)) + False + >>> e.encloses_point((4, 0)) + False + + """ + p = Point(p, dim=2) + if p in self: + return False + + if len(self.foci) == 2: + # if the combined distance from the foci to p (h1 + h2) is less + # than the combined distance from the foci to the minor axis + # (which is the same as the major axis length) then p is inside + # the ellipse + h1, h2 = [f.distance(p) for f in self.foci] + test = 2*self.major - (h1 + h2) + else: + test = self.radius - self.center.distance(p) + + return fuzzy_bool(test.is_positive) + + def equation(self, x='x', y='y', _slope=None): + """ + Returns the equation of an ellipse aligned with the x and y axes; + when slope is given, the equation returned corresponds to an ellipse + with a major axis having that slope. + + Parameters + ========== + + x : str, optional + Label for the x-axis. Default value is 'x'. + y : str, optional + Label for the y-axis. Default value is 'y'. + _slope : Expr, optional + The slope of the major axis. Ignored when 'None'. + + Returns + ======= + + equation : SymPy expression + + See Also + ======== + + arbitrary_point : Returns parameterized point on ellipse + + Examples + ======== + + >>> from sympy import Point, Ellipse, pi + >>> from sympy.abc import x, y + >>> e1 = Ellipse(Point(1, 0), 3, 2) + >>> eq1 = e1.equation(x, y); eq1 + y**2/4 + (x/3 - 1/3)**2 - 1 + >>> eq2 = e1.equation(x, y, _slope=1); eq2 + (-x + y + 1)**2/8 + (x + y - 1)**2/18 - 1 + + A point on e1 satisfies eq1. Let's use one on the x-axis: + + >>> p1 = e1.center + Point(e1.major, 0) + >>> assert eq1.subs(x, p1.x).subs(y, p1.y) == 0 + + When rotated the same as the rotated ellipse, about the center + point of the ellipse, it will satisfy the rotated ellipse's + equation, too: + + >>> r1 = p1.rotate(pi/4, e1.center) + >>> assert eq2.subs(x, r1.x).subs(y, r1.y) == 0 + + References + ========== + + .. [1] https://math.stackexchange.com/questions/108270/what-is-the-equation-of-an-ellipse-that-is-not-aligned-with-the-axis + .. [2] https://en.wikipedia.org/wiki/Ellipse#Shifted_ellipse + + """ + + x = _symbol(x, real=True) + y = _symbol(y, real=True) + + dx = x - self.center.x + dy = y - self.center.y + + if _slope is not None: + L = (dy - _slope*dx)**2 + l = (_slope*dy + dx)**2 + h = 1 + _slope**2 + b = h*self.major**2 + a = h*self.minor**2 + return l/b + L/a - 1 + + else: + t1 = (dx/self.hradius)**2 + t2 = (dy/self.vradius)**2 + return t1 + t2 - 1 + + def evolute(self, x='x', y='y'): + """The equation of evolute of the ellipse. + + Parameters + ========== + + x : str, optional + Label for the x-axis. Default value is 'x'. + y : str, optional + Label for the y-axis. Default value is 'y'. + + Returns + ======= + + equation : SymPy expression + + Examples + ======== + + >>> from sympy import Point, Ellipse + >>> e1 = Ellipse(Point(1, 0), 3, 2) + >>> e1.evolute() + 2**(2/3)*y**(2/3) + (3*x - 3)**(2/3) - 5**(2/3) + """ + if len(self.args) != 3: + raise NotImplementedError('Evolute of arbitrary Ellipse is not supported.') + x = _symbol(x, real=True) + y = _symbol(y, real=True) + t1 = (self.hradius*(x - self.center.x))**Rational(2, 3) + t2 = (self.vradius*(y - self.center.y))**Rational(2, 3) + return t1 + t2 - (self.hradius**2 - self.vradius**2)**Rational(2, 3) + + @property + def foci(self): + """The foci of the ellipse. + + Notes + ----- + The foci can only be calculated if the major/minor axes are known. + + Raises + ====== + + ValueError + When the major and minor axis cannot be determined. + + See Also + ======== + + sympy.geometry.point.Point + focus_distance : Returns the distance between focus and center + + Examples + ======== + + >>> from sympy import Point, Ellipse + >>> p1 = Point(0, 0) + >>> e1 = Ellipse(p1, 3, 1) + >>> e1.foci + (Point2D(-2*sqrt(2), 0), Point2D(2*sqrt(2), 0)) + + """ + c = self.center + hr, vr = self.hradius, self.vradius + if hr == vr: + return (c, c) + + # calculate focus distance manually, since focus_distance calls this + # routine + fd = sqrt(self.major**2 - self.minor**2) + if hr == self.minor: + # foci on the y-axis + return (c + Point(0, -fd), c + Point(0, fd)) + elif hr == self.major: + # foci on the x-axis + return (c + Point(-fd, 0), c + Point(fd, 0)) + + @property + def focus_distance(self): + """The focal distance of the ellipse. + + The distance between the center and one focus. + + Returns + ======= + + focus_distance : number + + See Also + ======== + + foci + + Examples + ======== + + >>> from sympy import Point, Ellipse + >>> p1 = Point(0, 0) + >>> e1 = Ellipse(p1, 3, 1) + >>> e1.focus_distance + 2*sqrt(2) + + """ + return Point.distance(self.center, self.foci[0]) + + @property + def hradius(self): + """The horizontal radius of the ellipse. + + Returns + ======= + + hradius : number + + See Also + ======== + + vradius, major, minor + + Examples + ======== + + >>> from sympy import Point, Ellipse + >>> p1 = Point(0, 0) + >>> e1 = Ellipse(p1, 3, 1) + >>> e1.hradius + 3 + + """ + return self.args[1] + + def intersection(self, o): + """The intersection of this ellipse and another geometrical entity + `o`. + + Parameters + ========== + + o : GeometryEntity + + Returns + ======= + + intersection : list of GeometryEntity objects + + Notes + ----- + Currently supports intersections with Point, Line, Segment, Ray, + Circle and Ellipse types. + + See Also + ======== + + sympy.geometry.entity.GeometryEntity + + Examples + ======== + + >>> from sympy import Ellipse, Point, Line + >>> e = Ellipse(Point(0, 0), 5, 7) + >>> e.intersection(Point(0, 0)) + [] + >>> e.intersection(Point(5, 0)) + [Point2D(5, 0)] + >>> e.intersection(Line(Point(0,0), Point(0, 1))) + [Point2D(0, -7), Point2D(0, 7)] + >>> e.intersection(Line(Point(5,0), Point(5, 1))) + [Point2D(5, 0)] + >>> e.intersection(Line(Point(6,0), Point(6, 1))) + [] + >>> e = Ellipse(Point(-1, 0), 4, 3) + >>> e.intersection(Ellipse(Point(1, 0), 4, 3)) + [Point2D(0, -3*sqrt(15)/4), Point2D(0, 3*sqrt(15)/4)] + >>> e.intersection(Ellipse(Point(5, 0), 4, 3)) + [Point2D(2, -3*sqrt(7)/4), Point2D(2, 3*sqrt(7)/4)] + >>> e.intersection(Ellipse(Point(100500, 0), 4, 3)) + [] + >>> e.intersection(Ellipse(Point(0, 0), 3, 4)) + [Point2D(3, 0), Point2D(-363/175, -48*sqrt(111)/175), Point2D(-363/175, 48*sqrt(111)/175)] + >>> e.intersection(Ellipse(Point(-1, 0), 3, 4)) + [Point2D(-17/5, -12/5), Point2D(-17/5, 12/5), Point2D(7/5, -12/5), Point2D(7/5, 12/5)] + """ + # TODO: Replace solve with nonlinsolve, when nonlinsolve will be able to solve in real domain + + if isinstance(o, Point): + if o in self: + return [o] + else: + return [] + + elif isinstance(o, (Segment2D, Ray2D)): + ellipse_equation = self.equation(x, y) + result = solve([ellipse_equation, Line( + o.points[0], o.points[1]).equation(x, y)], [x, y], + set=True)[1] + return list(ordered([Point(i) for i in result if i in o])) + + elif isinstance(o, Polygon): + return o.intersection(self) + + elif isinstance(o, (Ellipse, Line2D)): + if o == self: + return self + else: + ellipse_equation = self.equation(x, y) + return list(ordered([Point(i) for i in solve( + [ellipse_equation, o.equation(x, y)], [x, y], + set=True)[1]])) + elif isinstance(o, LinearEntity3D): + raise TypeError('Entity must be two dimensional, not three dimensional') + else: + raise TypeError('Intersection not handled for %s' % func_name(o)) + + def is_tangent(self, o): + """Is `o` tangent to the ellipse? + + Parameters + ========== + + o : GeometryEntity + An Ellipse, LinearEntity or Polygon + + Raises + ====== + + NotImplementedError + When the wrong type of argument is supplied. + + Returns + ======= + + is_tangent: boolean + True if o is tangent to the ellipse, False otherwise. + + See Also + ======== + + tangent_lines + + Examples + ======== + + >>> from sympy import Point, Ellipse, Line + >>> p0, p1, p2 = Point(0, 0), Point(3, 0), Point(3, 3) + >>> e1 = Ellipse(p0, 3, 2) + >>> l1 = Line(p1, p2) + >>> e1.is_tangent(l1) + True + + """ + if isinstance(o, Point2D): + return False + elif isinstance(o, Ellipse): + intersect = self.intersection(o) + if isinstance(intersect, Ellipse): + return True + elif intersect: + return all((self.tangent_lines(i)[0]).equals(o.tangent_lines(i)[0]) for i in intersect) + else: + return False + elif isinstance(o, Line2D): + hit = self.intersection(o) + if not hit: + return False + if len(hit) == 1: + return True + # might return None if it can't decide + return hit[0].equals(hit[1]) + elif isinstance(o, (Segment2D, Ray2D)): + intersect = self.intersection(o) + if len(intersect) == 1: + return o in self.tangent_lines(intersect[0])[0] + else: + return False + elif isinstance(o, Polygon): + return all(self.is_tangent(s) for s in o.sides) + elif isinstance(o, (LinearEntity3D, Point3D)): + raise TypeError('Entity must be two dimensional, not three dimensional') + else: + raise TypeError('Is_tangent not handled for %s' % func_name(o)) + + @property + def major(self): + """Longer axis of the ellipse (if it can be determined) else hradius. + + Returns + ======= + + major : number or expression + + See Also + ======== + + hradius, vradius, minor + + Examples + ======== + + >>> from sympy import Point, Ellipse, Symbol + >>> p1 = Point(0, 0) + >>> e1 = Ellipse(p1, 3, 1) + >>> e1.major + 3 + + >>> a = Symbol('a') + >>> b = Symbol('b') + >>> Ellipse(p1, a, b).major + a + >>> Ellipse(p1, b, a).major + b + + >>> m = Symbol('m') + >>> M = m + 1 + >>> Ellipse(p1, m, M).major + m + 1 + + """ + ab = self.args[1:3] + if len(ab) == 1: + return ab[0] + a, b = ab + o = b - a < 0 + if o == True: + return a + elif o == False: + return b + return self.hradius + + @property + def minor(self): + """Shorter axis of the ellipse (if it can be determined) else vradius. + + Returns + ======= + + minor : number or expression + + See Also + ======== + + hradius, vradius, major + + Examples + ======== + + >>> from sympy import Point, Ellipse, Symbol + >>> p1 = Point(0, 0) + >>> e1 = Ellipse(p1, 3, 1) + >>> e1.minor + 1 + + >>> a = Symbol('a') + >>> b = Symbol('b') + >>> Ellipse(p1, a, b).minor + b + >>> Ellipse(p1, b, a).minor + a + + >>> m = Symbol('m') + >>> M = m + 1 + >>> Ellipse(p1, m, M).minor + m + + """ + ab = self.args[1:3] + if len(ab) == 1: + return ab[0] + a, b = ab + o = a - b < 0 + if o == True: + return a + elif o == False: + return b + return self.vradius + + def normal_lines(self, p, prec=None): + """Normal lines between `p` and the ellipse. + + Parameters + ========== + + p : Point + + Returns + ======= + + normal_lines : list with 1, 2 or 4 Lines + + Examples + ======== + + >>> from sympy import Point, Ellipse + >>> e = Ellipse((0, 0), 2, 3) + >>> c = e.center + >>> e.normal_lines(c + Point(1, 0)) + [Line2D(Point2D(0, 0), Point2D(1, 0))] + >>> e.normal_lines(c) + [Line2D(Point2D(0, 0), Point2D(0, 1)), Line2D(Point2D(0, 0), Point2D(1, 0))] + + Off-axis points require the solution of a quartic equation. This + often leads to very large expressions that may be of little practical + use. An approximate solution of `prec` digits can be obtained by + passing in the desired value: + + >>> e.normal_lines((3, 3), prec=2) + [Line2D(Point2D(-0.81, -2.7), Point2D(0.19, -1.2)), + Line2D(Point2D(1.5, -2.0), Point2D(2.5, -2.7))] + + Whereas the above solution has an operation count of 12, the exact + solution has an operation count of 2020. + """ + p = Point(p, dim=2) + + # XXX change True to something like self.angle == 0 if the arbitrarily + # rotated ellipse is introduced. + # https://github.com/sympy/sympy/issues/2815) + if True: + rv = [] + if p.x == self.center.x: + rv.append(Line(self.center, slope=oo)) + if p.y == self.center.y: + rv.append(Line(self.center, slope=0)) + if rv: + # at these special orientations of p either 1 or 2 normals + # exist and we are done + return rv + + # find the 4 normal points and construct lines through them with + # the corresponding slope + eq = self.equation(x, y) + dydx = idiff(eq, y, x) + norm = -1/dydx + slope = Line(p, (x, y)).slope + seq = slope - norm + + # TODO: Replace solve with solveset, when this line is tested + yis = solve(seq, y)[0] + xeq = eq.subs(y, yis).as_numer_denom()[0].expand() + if len(xeq.free_symbols) == 1: + try: + # this is so much faster, it's worth a try + xsol = Poly(xeq, x).real_roots() + except (DomainError, PolynomialError, NotImplementedError): + # TODO: Replace solve with solveset, when these lines are tested + xsol = _nsort(solve(xeq, x), separated=True)[0] + points = [Point(i, solve(eq.subs(x, i), y)[0]) for i in xsol] + else: + raise NotImplementedError( + 'intersections for the general ellipse are not supported') + slopes = [norm.subs(zip((x, y), pt.args)) for pt in points] + if prec is not None: + points = [pt.n(prec) for pt in points] + slopes = [i if _not_a_coeff(i) else i.n(prec) for i in slopes] + return [Line(pt, slope=s) for pt, s in zip(points, slopes)] + + @property + def periapsis(self): + """The periapsis of the ellipse. + + The shortest distance between the focus and the contour. + + Returns + ======= + + periapsis : number + + See Also + ======== + + apoapsis : Returns greatest distance between focus and contour + + Examples + ======== + + >>> from sympy import Point, Ellipse + >>> p1 = Point(0, 0) + >>> e1 = Ellipse(p1, 3, 1) + >>> e1.periapsis + 3 - 2*sqrt(2) + + """ + return self.major * (1 - self.eccentricity) + + @property + def semilatus_rectum(self): + """ + Calculates the semi-latus rectum of the Ellipse. + + Semi-latus rectum is defined as one half of the chord through a + focus parallel to the conic section directrix of a conic section. + + Returns + ======= + + semilatus_rectum : number + + See Also + ======== + + apoapsis : Returns greatest distance between focus and contour + + periapsis : The shortest distance between the focus and the contour + + Examples + ======== + + >>> from sympy import Point, Ellipse + >>> p1 = Point(0, 0) + >>> e1 = Ellipse(p1, 3, 1) + >>> e1.semilatus_rectum + 1/3 + + References + ========== + + .. [1] https://mathworld.wolfram.com/SemilatusRectum.html + .. [2] https://en.wikipedia.org/wiki/Ellipse#Semi-latus_rectum + + """ + return self.major * (1 - self.eccentricity ** 2) + + def auxiliary_circle(self): + """Returns a Circle whose diameter is the major axis of the ellipse. + + Examples + ======== + + >>> from sympy import Ellipse, Point, symbols + >>> c = Point(1, 2) + >>> Ellipse(c, 8, 7).auxiliary_circle() + Circle(Point2D(1, 2), 8) + >>> a, b = symbols('a b') + >>> Ellipse(c, a, b).auxiliary_circle() + Circle(Point2D(1, 2), Max(a, b)) + """ + return Circle(self.center, Max(self.hradius, self.vradius)) + + def director_circle(self): + """ + Returns a Circle consisting of all points where two perpendicular + tangent lines to the ellipse cross each other. + + Returns + ======= + + Circle + A director circle returned as a geometric object. + + Examples + ======== + + >>> from sympy import Ellipse, Point, symbols + >>> c = Point(3,8) + >>> Ellipse(c, 7, 9).director_circle() + Circle(Point2D(3, 8), sqrt(130)) + >>> a, b = symbols('a b') + >>> Ellipse(c, a, b).director_circle() + Circle(Point2D(3, 8), sqrt(a**2 + b**2)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Director_circle + + """ + return Circle(self.center, sqrt(self.hradius**2 + self.vradius**2)) + + def plot_interval(self, parameter='t'): + """The plot interval for the default geometric plot of the Ellipse. + + Parameters + ========== + + parameter : str, optional + Default value is 't'. + + Returns + ======= + + plot_interval : list + [parameter, lower_bound, upper_bound] + + Examples + ======== + + >>> from sympy import Point, Ellipse + >>> e1 = Ellipse(Point(0, 0), 3, 2) + >>> e1.plot_interval() + [t, -pi, pi] + + """ + t = _symbol(parameter, real=True) + return [t, -S.Pi, S.Pi] + + def random_point(self, seed=None): + """A random point on the ellipse. + + Returns + ======= + + point : Point + + Examples + ======== + + >>> from sympy import Point, Ellipse + >>> e1 = Ellipse(Point(0, 0), 3, 2) + >>> e1.random_point() # gives some random point + Point2D(...) + >>> p1 = e1.random_point(seed=0); p1.n(2) + Point2D(2.1, 1.4) + + Notes + ===== + + When creating a random point, one may simply replace the + parameter with a random number. When doing so, however, the + random number should be made a Rational or else the point + may not test as being in the ellipse: + + >>> from sympy.abc import t + >>> from sympy import Rational + >>> arb = e1.arbitrary_point(t); arb + Point2D(3*cos(t), 2*sin(t)) + >>> arb.subs(t, .1) in e1 + False + >>> arb.subs(t, Rational(.1)) in e1 + True + >>> arb.subs(t, Rational('.1')) in e1 + True + + See Also + ======== + sympy.geometry.point.Point + arbitrary_point : Returns parameterized point on ellipse + """ + t = _symbol('t', real=True) + x, y = self.arbitrary_point(t).args + # get a random value in [-1, 1) corresponding to cos(t) + # and confirm that it will test as being in the ellipse + if seed is not None: + rng = random.Random(seed) + else: + rng = random + # simplify this now or else the Float will turn s into a Float + r = Rational(rng.random()) + c = 2*r - 1 + s = sqrt(1 - c**2) + return Point(x.subs(cos(t), c), y.subs(sin(t), s)) + + def reflect(self, line): + """Override GeometryEntity.reflect since the radius + is not a GeometryEntity. + + Examples + ======== + + >>> from sympy import Circle, Line + >>> Circle((0, 1), 1).reflect(Line((0, 0), (1, 1))) + Circle(Point2D(1, 0), -1) + >>> from sympy import Ellipse, Line, Point + >>> Ellipse(Point(3, 4), 1, 3).reflect(Line(Point(0, -4), Point(5, 0))) + Traceback (most recent call last): + ... + NotImplementedError: + General Ellipse is not supported but the equation of the reflected + Ellipse is given by the zeros of: f(x, y) = (9*x/41 + 40*y/41 + + 37/41)**2 + (40*x/123 - 3*y/41 - 364/123)**2 - 1 + + Notes + ===== + + Until the general ellipse (with no axis parallel to the x-axis) is + supported a NotImplemented error is raised and the equation whose + zeros define the rotated ellipse is given. + + """ + + if line.slope in (0, oo): + c = self.center + c = c.reflect(line) + return self.func(c, -self.hradius, self.vradius) + else: + x, y = [uniquely_named_symbol( + name, (self, line), modify=lambda s: '_' + s, real=True) + for name in 'xy'] + expr = self.equation(x, y) + p = Point(x, y).reflect(line) + result = expr.subs(zip((x, y), p.args + ), simultaneous=True) + raise NotImplementedError(filldedent( + 'General Ellipse is not supported but the equation ' + 'of the reflected Ellipse is given by the zeros of: ' + + "f(%s, %s) = %s" % (str(x), str(y), str(result)))) + + def rotate(self, angle=0, pt=None): + """Rotate ``angle`` radians counterclockwise about Point ``pt``. + + Note: since the general ellipse is not supported, only rotations that + are integer multiples of pi/2 are allowed. + + Examples + ======== + + >>> from sympy import Ellipse, pi + >>> Ellipse((1, 0), 2, 1).rotate(pi/2) + Ellipse(Point2D(0, 1), 1, 2) + >>> Ellipse((1, 0), 2, 1).rotate(pi) + Ellipse(Point2D(-1, 0), 2, 1) + """ + if self.hradius == self.vradius: + return self.func(self.center.rotate(angle, pt), self.hradius) + if (angle/S.Pi).is_integer: + return super().rotate(angle, pt) + if (2*angle/S.Pi).is_integer: + return self.func(self.center.rotate(angle, pt), self.vradius, self.hradius) + # XXX see https://github.com/sympy/sympy/issues/2815 for general ellipes + raise NotImplementedError('Only rotations of pi/2 are currently supported for Ellipse.') + + def scale(self, x=1, y=1, pt=None): + """Override GeometryEntity.scale since it is the major and minor + axes which must be scaled and they are not GeometryEntities. + + Examples + ======== + + >>> from sympy import Ellipse + >>> Ellipse((0, 0), 2, 1).scale(2, 4) + Circle(Point2D(0, 0), 4) + >>> Ellipse((0, 0), 2, 1).scale(2) + Ellipse(Point2D(0, 0), 4, 1) + """ + c = self.center + if pt: + pt = Point(pt, dim=2) + return self.translate(*(-pt).args).scale(x, y).translate(*pt.args) + h = self.hradius + v = self.vradius + return self.func(c.scale(x, y), hradius=h*x, vradius=v*y) + + def tangent_lines(self, p): + """Tangent lines between `p` and the ellipse. + + If `p` is on the ellipse, returns the tangent line through point `p`. + Otherwise, returns the tangent line(s) from `p` to the ellipse, or + None if no tangent line is possible (e.g., `p` inside ellipse). + + Parameters + ========== + + p : Point + + Returns + ======= + + tangent_lines : list with 1 or 2 Lines + + Raises + ====== + + NotImplementedError + Can only find tangent lines for a point, `p`, on the ellipse. + + See Also + ======== + + sympy.geometry.point.Point, sympy.geometry.line.Line + + Examples + ======== + + >>> from sympy import Point, Ellipse + >>> e1 = Ellipse(Point(0, 0), 3, 2) + >>> e1.tangent_lines(Point(3, 0)) + [Line2D(Point2D(3, 0), Point2D(3, -12))] + + """ + p = Point(p, dim=2) + if self.encloses_point(p): + return [] + + if p in self: + delta = self.center - p + rise = (self.vradius**2)*delta.x + run = -(self.hradius**2)*delta.y + p2 = Point(simplify(p.x + run), + simplify(p.y + rise)) + return [Line(p, p2)] + else: + if len(self.foci) == 2: + f1, f2 = self.foci + maj = self.hradius + test = (2*maj - + Point.distance(f1, p) - + Point.distance(f2, p)) + else: + test = self.radius - Point.distance(self.center, p) + if test.is_number and test.is_positive: + return [] + # else p is outside the ellipse or we can't tell. In case of the + # latter, the solutions returned will only be valid if + # the point is not inside the ellipse; if it is, nan will result. + eq = self.equation(x, y) + dydx = idiff(eq, y, x) + slope = Line(p, Point(x, y)).slope + + # TODO: Replace solve with solveset, when this line is tested + tangent_points = solve([slope - dydx, eq], [x, y]) + + # handle horizontal and vertical tangent lines + if len(tangent_points) == 1: + if tangent_points[0][ + 0] == p.x or tangent_points[0][1] == p.y: + return [Line(p, p + Point(1, 0)), Line(p, p + Point(0, 1))] + else: + return [Line(p, p + Point(0, 1)), Line(p, tangent_points[0])] + + # others + return [Line(p, tangent_points[0]), Line(p, tangent_points[1])] + + @property + def vradius(self): + """The vertical radius of the ellipse. + + Returns + ======= + + vradius : number + + See Also + ======== + + hradius, major, minor + + Examples + ======== + + >>> from sympy import Point, Ellipse + >>> p1 = Point(0, 0) + >>> e1 = Ellipse(p1, 3, 1) + >>> e1.vradius + 1 + + """ + return self.args[2] + + + def second_moment_of_area(self, point=None): + """Returns the second moment and product moment area of an ellipse. + + Parameters + ========== + + point : Point, two-tuple of sympifiable objects, or None(default=None) + point is the point about which second moment of area is to be found. + If "point=None" it will be calculated about the axis passing through the + centroid of the ellipse. + + Returns + ======= + + I_xx, I_yy, I_xy : number or SymPy expression + I_xx, I_yy are second moment of area of an ellise. + I_xy is product moment of area of an ellipse. + + Examples + ======== + + >>> from sympy import Point, Ellipse + >>> p1 = Point(0, 0) + >>> e1 = Ellipse(p1, 3, 1) + >>> e1.second_moment_of_area() + (3*pi/4, 27*pi/4, 0) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/List_of_second_moments_of_area + + """ + + I_xx = (S.Pi*(self.hradius)*(self.vradius**3))/4 + I_yy = (S.Pi*(self.hradius**3)*(self.vradius))/4 + I_xy = 0 + + if point is None: + return I_xx, I_yy, I_xy + + # parallel axis theorem + I_xx = I_xx + self.area*((point[1] - self.center.y)**2) + I_yy = I_yy + self.area*((point[0] - self.center.x)**2) + I_xy = I_xy + self.area*(point[0] - self.center.x)*(point[1] - self.center.y) + + return I_xx, I_yy, I_xy + + + def polar_second_moment_of_area(self): + """Returns the polar second moment of area of an Ellipse + + It is a constituent of the second moment of area, linked through + the perpendicular axis theorem. While the planar second moment of + area describes an object's resistance to deflection (bending) when + subjected to a force applied to a plane parallel to the central + axis, the polar second moment of area describes an object's + resistance to deflection when subjected to a moment applied in a + plane perpendicular to the object's central axis (i.e. parallel to + the cross-section) + + Examples + ======== + + >>> from sympy import symbols, Circle, Ellipse + >>> c = Circle((5, 5), 4) + >>> c.polar_second_moment_of_area() + 128*pi + >>> a, b = symbols('a, b') + >>> e = Ellipse((0, 0), a, b) + >>> e.polar_second_moment_of_area() + pi*a**3*b/4 + pi*a*b**3/4 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Polar_moment_of_inertia + + """ + second_moment = self.second_moment_of_area() + return second_moment[0] + second_moment[1] + + + def section_modulus(self, point=None): + """Returns a tuple with the section modulus of an ellipse + + Section modulus is a geometric property of an ellipse defined as the + ratio of second moment of area to the distance of the extreme end of + the ellipse from the centroidal axis. + + Parameters + ========== + + point : Point, two-tuple of sympifyable objects, or None(default=None) + point is the point at which section modulus is to be found. + If "point=None" section modulus will be calculated for the + point farthest from the centroidal axis of the ellipse. + + Returns + ======= + + S_x, S_y: numbers or SymPy expressions + S_x is the section modulus with respect to the x-axis + S_y is the section modulus with respect to the y-axis + A negative sign indicates that the section modulus is + determined for a point below the centroidal axis. + + Examples + ======== + + >>> from sympy import Symbol, Ellipse, Circle, Point2D + >>> d = Symbol('d', positive=True) + >>> c = Circle((0, 0), d/2) + >>> c.section_modulus() + (pi*d**3/32, pi*d**3/32) + >>> e = Ellipse(Point2D(0, 0), 2, 4) + >>> e.section_modulus() + (8*pi, 4*pi) + >>> e.section_modulus((2, 2)) + (16*pi, 4*pi) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Section_modulus + + """ + x_c, y_c = self.center + if point is None: + # taking x and y as maximum distances from centroid + x_min, y_min, x_max, y_max = self.bounds + y = max(y_c - y_min, y_max - y_c) + x = max(x_c - x_min, x_max - x_c) + else: + # taking x and y as distances of the given point from the center + point = Point2D(point) + y = point.y - y_c + x = point.x - x_c + + second_moment = self.second_moment_of_area() + S_x = second_moment[0]/y + S_y = second_moment[1]/x + + return S_x, S_y + + +class Circle(Ellipse): + r"""A circle in space. + + Constructed simply from a center and a radius, from three + non-collinear points, or the equation of a circle. + + Parameters + ========== + + center : Point + radius : number or SymPy expression + points : sequence of three Points + equation : equation of a circle + + Attributes + ========== + + radius (synonymous with hradius, vradius, major and minor) + circumference + equation + + Raises + ====== + + GeometryError + When the given equation is not that of a circle. + When trying to construct circle from incorrect parameters. + + See Also + ======== + + Ellipse, sympy.geometry.point.Point + + Examples + ======== + + >>> from sympy import Point, Circle, Eq + >>> from sympy.abc import x, y, a, b + + A circle constructed from a center and radius: + + >>> c1 = Circle(Point(0, 0), 5) + >>> c1.hradius, c1.vradius, c1.radius + (5, 5, 5) + + A circle constructed from three points: + + >>> c2 = Circle(Point(0, 0), Point(1, 1), Point(1, 0)) + >>> c2.hradius, c2.vradius, c2.radius, c2.center + (sqrt(2)/2, sqrt(2)/2, sqrt(2)/2, Point2D(1/2, 1/2)) + + A circle can be constructed from an equation in the form + `ax^2 + by^2 + gx + hy + c = 0`, too: + + >>> Circle(x**2 + y**2 - 25) + Circle(Point2D(0, 0), 5) + + If the variables corresponding to x and y are named something + else, their name or symbol can be supplied: + + >>> Circle(Eq(a**2 + b**2, 25), x='a', y=b) + Circle(Point2D(0, 0), 5) + """ + + def __new__(cls, *args, **kwargs): + evaluate = kwargs.get('evaluate', global_parameters.evaluate) + if len(args) == 1 and isinstance(args[0], (Expr, Eq)): + x = kwargs.get('x', 'x') + y = kwargs.get('y', 'y') + equation = args[0].expand() + if isinstance(equation, Eq): + equation = equation.lhs - equation.rhs + x = find(x, equation) + y = find(y, equation) + + try: + a, b, c, d, e = linear_coeffs(equation, x**2, y**2, x, y) + except ValueError: + raise GeometryError("The given equation is not that of a circle.") + + if S.Zero in (a, b) or a != b: + raise GeometryError("The given equation is not that of a circle.") + + center_x = -c/a/2 + center_y = -d/b/2 + r2 = (center_x**2) + (center_y**2) - e/a + + return Circle((center_x, center_y), sqrt(r2), evaluate=evaluate) + + else: + c, r = None, None + if len(args) == 3: + args = [Point(a, dim=2, evaluate=evaluate) for a in args] + t = Triangle(*args) + if not isinstance(t, Triangle): + return t + c = t.circumcenter + r = t.circumradius + elif len(args) == 2: + # Assume (center, radius) pair + c = Point(args[0], dim=2, evaluate=evaluate) + r = args[1] + # this will prohibit imaginary radius + try: + r = Point(r, 0, evaluate=evaluate).x + except ValueError: + raise GeometryError("Circle with imaginary radius is not permitted") + + if not (c is None or r is None): + if r == 0: + return c + return GeometryEntity.__new__(cls, c, r, **kwargs) + + raise GeometryError("Circle.__new__ received unknown arguments") + + def _eval_evalf(self, prec=15, **options): + pt, r = self.args + dps = prec_to_dps(prec) + pt = pt.evalf(n=dps, **options) + r = r.evalf(n=dps, **options) + return self.func(pt, r, evaluate=False) + + @property + def circumference(self): + """The circumference of the circle. + + Returns + ======= + + circumference : number or SymPy expression + + Examples + ======== + + >>> from sympy import Point, Circle + >>> c1 = Circle(Point(3, 4), 6) + >>> c1.circumference + 12*pi + + """ + return 2 * S.Pi * self.radius + + def equation(self, x='x', y='y'): + """The equation of the circle. + + Parameters + ========== + + x : str or Symbol, optional + Default value is 'x'. + y : str or Symbol, optional + Default value is 'y'. + + Returns + ======= + + equation : SymPy expression + + Examples + ======== + + >>> from sympy import Point, Circle + >>> c1 = Circle(Point(0, 0), 5) + >>> c1.equation() + x**2 + y**2 - 25 + + """ + x = _symbol(x, real=True) + y = _symbol(y, real=True) + t1 = (x - self.center.x)**2 + t2 = (y - self.center.y)**2 + return t1 + t2 - self.major**2 + + def intersection(self, o): + """The intersection of this circle with another geometrical entity. + + Parameters + ========== + + o : GeometryEntity + + Returns + ======= + + intersection : list of GeometryEntities + + Examples + ======== + + >>> from sympy import Point, Circle, Line, Ray + >>> p1, p2, p3 = Point(0, 0), Point(5, 5), Point(6, 0) + >>> p4 = Point(5, 0) + >>> c1 = Circle(p1, 5) + >>> c1.intersection(p2) + [] + >>> c1.intersection(p4) + [Point2D(5, 0)] + >>> c1.intersection(Ray(p1, p2)) + [Point2D(5*sqrt(2)/2, 5*sqrt(2)/2)] + >>> c1.intersection(Line(p2, p3)) + [] + + """ + return Ellipse.intersection(self, o) + + @property + def radius(self): + """The radius of the circle. + + Returns + ======= + + radius : number or SymPy expression + + See Also + ======== + + Ellipse.major, Ellipse.minor, Ellipse.hradius, Ellipse.vradius + + Examples + ======== + + >>> from sympy import Point, Circle + >>> c1 = Circle(Point(3, 4), 6) + >>> c1.radius + 6 + + """ + return self.args[1] + + def reflect(self, line): + """Override GeometryEntity.reflect since the radius + is not a GeometryEntity. + + Examples + ======== + + >>> from sympy import Circle, Line + >>> Circle((0, 1), 1).reflect(Line((0, 0), (1, 1))) + Circle(Point2D(1, 0), -1) + """ + c = self.center + c = c.reflect(line) + return self.func(c, -self.radius) + + def scale(self, x=1, y=1, pt=None): + """Override GeometryEntity.scale since the radius + is not a GeometryEntity. + + Examples + ======== + + >>> from sympy import Circle + >>> Circle((0, 0), 1).scale(2, 2) + Circle(Point2D(0, 0), 2) + >>> Circle((0, 0), 1).scale(2, 4) + Ellipse(Point2D(0, 0), 2, 4) + """ + c = self.center + if pt: + pt = Point(pt, dim=2) + return self.translate(*(-pt).args).scale(x, y).translate(*pt.args) + c = c.scale(x, y) + x, y = [abs(i) for i in (x, y)] + if x == y: + return self.func(c, x*self.radius) + h = v = self.radius + return Ellipse(c, hradius=h*x, vradius=v*y) + + @property + def vradius(self): + """ + This Ellipse property is an alias for the Circle's radius. + + Whereas hradius, major and minor can use Ellipse's conventions, + the vradius does not exist for a circle. It is always a positive + value in order that the Circle, like Polygons, will have an + area that can be positive or negative as determined by the sign + of the hradius. + + Examples + ======== + + >>> from sympy import Point, Circle + >>> c1 = Circle(Point(3, 4), 6) + >>> c1.vradius + 6 + """ + return abs(self.radius) + + +from .polygon import Polygon, Triangle diff --git a/.venv/lib/python3.13/site-packages/sympy/geometry/entity.py b/.venv/lib/python3.13/site-packages/sympy/geometry/entity.py new file mode 100644 index 0000000000000000000000000000000000000000..5ea1e807542c43eb955c2d778cec0f101d78bdce --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/geometry/entity.py @@ -0,0 +1,641 @@ +"""The definition of the base geometrical entity with attributes common to +all derived geometrical entities. + +Contains +======== + +GeometryEntity +GeometricSet + +Notes +===== + +A GeometryEntity is any object that has special geometric properties. +A GeometrySet is a superclass of any GeometryEntity that can also +be viewed as a sympy.sets.Set. In particular, points are the only +GeometryEntity not considered a Set. + +Rn is a GeometrySet representing n-dimensional Euclidean space. R2 and +R3 are currently the only ambient spaces implemented. + +""" +from __future__ import annotations + +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.evalf import EvalfMixin, N +from sympy.core.numbers import oo +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify +from sympy.functions.elementary.trigonometric import cos, sin, atan +from sympy.matrices import eye +from sympy.multipledispatch import dispatch +from sympy.printing import sstr +from sympy.sets import Set, Union, FiniteSet +from sympy.sets.handlers.intersection import intersection_sets +from sympy.sets.handlers.union import union_sets +from sympy.solvers.solvers import solve +from sympy.utilities.misc import func_name +from sympy.utilities.iterables import is_sequence + + +# How entities are ordered; used by __cmp__ in GeometryEntity +ordering_of_classes = [ + "Point2D", + "Point3D", + "Point", + "Segment2D", + "Ray2D", + "Line2D", + "Segment3D", + "Line3D", + "Ray3D", + "Segment", + "Ray", + "Line", + "Plane", + "Triangle", + "RegularPolygon", + "Polygon", + "Circle", + "Ellipse", + "Curve", + "Parabola" +] + + +x, y = [Dummy('entity_dummy') for i in range(2)] +T = Dummy('entity_dummy', real=True) + + +class GeometryEntity(Basic, EvalfMixin): + """The base class for all geometrical entities. + + This class does not represent any particular geometric entity, it only + provides the implementation of some methods common to all subclasses. + + """ + + __slots__: tuple[str, ...] = () + + def __cmp__(self, other): + """Comparison of two GeometryEntities.""" + n1 = self.__class__.__name__ + n2 = other.__class__.__name__ + c = (n1 > n2) - (n1 < n2) + if not c: + return 0 + + i1 = -1 + for cls in self.__class__.__mro__: + try: + i1 = ordering_of_classes.index(cls.__name__) + break + except ValueError: + i1 = -1 + if i1 == -1: + return c + + i2 = -1 + for cls in other.__class__.__mro__: + try: + i2 = ordering_of_classes.index(cls.__name__) + break + except ValueError: + i2 = -1 + if i2 == -1: + return c + + return (i1 > i2) - (i1 < i2) + + def __contains__(self, other): + """Subclasses should implement this method for anything more complex than equality.""" + if type(self) is type(other): + return self == other + raise NotImplementedError() + + def __getnewargs__(self): + """Returns a tuple that will be passed to __new__ on unpickling.""" + return tuple(self.args) + + def __ne__(self, o): + """Test inequality of two geometrical entities.""" + return not self == o + + def __new__(cls, *args, **kwargs): + # Points are sequences, but they should not + # be converted to Tuples, so use this detection function instead. + def is_seq_and_not_point(a): + # we cannot use isinstance(a, Point) since we cannot import Point + if hasattr(a, 'is_Point') and a.is_Point: + return False + return is_sequence(a) + + args = [Tuple(*a) if is_seq_and_not_point(a) else sympify(a) for a in args] + return Basic.__new__(cls, *args) + + def __radd__(self, a): + """Implementation of reverse add method.""" + return a.__add__(self) + + def __rtruediv__(self, a): + """Implementation of reverse division method.""" + return a.__truediv__(self) + + def __repr__(self): + """String representation of a GeometryEntity that can be evaluated + by sympy.""" + return type(self).__name__ + repr(self.args) + + def __rmul__(self, a): + """Implementation of reverse multiplication method.""" + return a.__mul__(self) + + def __rsub__(self, a): + """Implementation of reverse subtraction method.""" + return a.__sub__(self) + + def __str__(self): + """String representation of a GeometryEntity.""" + return type(self).__name__ + sstr(self.args) + + def _eval_subs(self, old, new): + from sympy.geometry.point import Point, Point3D + if is_sequence(old) or is_sequence(new): + if isinstance(self, Point3D): + old = Point3D(old) + new = Point3D(new) + else: + old = Point(old) + new = Point(new) + return self._subs(old, new) + + def _repr_svg_(self): + """SVG representation of a GeometryEntity suitable for IPython""" + + try: + bounds = self.bounds + except (NotImplementedError, TypeError): + # if we have no SVG representation, return None so IPython + # will fall back to the next representation + return None + + if not all(x.is_number and x.is_finite for x in bounds): + return None + + svg_top = ''' + + + + + + + + + + + ''' + + # Establish SVG canvas that will fit all the data + small space + xmin, ymin, xmax, ymax = map(N, bounds) + if xmin == xmax and ymin == ymax: + # This is a point; buffer using an arbitrary size + xmin, ymin, xmax, ymax = xmin - .5, ymin -.5, xmax + .5, ymax + .5 + else: + # Expand bounds by a fraction of the data ranges + expand = 0.1 # or 10%; this keeps arrowheads in view (R plots use 4%) + widest_part = max([xmax - xmin, ymax - ymin]) + expand_amount = widest_part * expand + xmin -= expand_amount + ymin -= expand_amount + xmax += expand_amount + ymax += expand_amount + dx = xmax - xmin + dy = ymax - ymin + width = min([max([100., dx]), 300]) + height = min([max([100., dy]), 300]) + + scale_factor = 1. if max(width, height) == 0 else max(dx, dy) / max(width, height) + try: + svg = self._svg(scale_factor) + except (NotImplementedError, TypeError): + # if we have no SVG representation, return None so IPython + # will fall back to the next representation + return None + + view_box = "{} {} {} {}".format(xmin, ymin, dx, dy) + transform = "matrix(1,0,0,-1,0,{})".format(ymax + ymin) + svg_top = svg_top.format(view_box, width, height) + + return svg_top + ( + '{}' + ).format(transform, svg) + + def _svg(self, scale_factor=1., fill_color="#66cc99"): + """Returns SVG path element for the GeometryEntity. + + Parameters + ========== + + scale_factor : float + Multiplication factor for the SVG stroke-width. Default is 1. + fill_color : str, optional + Hex string for fill color. Default is "#66cc99". + """ + raise NotImplementedError() + + def _sympy_(self): + return self + + @property + def ambient_dimension(self): + """What is the dimension of the space that the object is contained in?""" + raise NotImplementedError() + + @property + def bounds(self): + """Return a tuple (xmin, ymin, xmax, ymax) representing the bounding + rectangle for the geometric figure. + + """ + + raise NotImplementedError() + + def encloses(self, o): + """ + Return True if o is inside (not on or outside) the boundaries of self. + + The object will be decomposed into Points and individual Entities need + only define an encloses_point method for their class. + + See Also + ======== + + sympy.geometry.ellipse.Ellipse.encloses_point + sympy.geometry.polygon.Polygon.encloses_point + + Examples + ======== + + >>> from sympy import RegularPolygon, Point, Polygon + >>> t = Polygon(*RegularPolygon(Point(0, 0), 1, 3).vertices) + >>> t2 = Polygon(*RegularPolygon(Point(0, 0), 2, 3).vertices) + >>> t2.encloses(t) + True + >>> t.encloses(t2) + False + + """ + + from sympy.geometry.point import Point + from sympy.geometry.line import Segment, Ray, Line + from sympy.geometry.ellipse import Ellipse + from sympy.geometry.polygon import Polygon, RegularPolygon + + if isinstance(o, Point): + return self.encloses_point(o) + elif isinstance(o, Segment): + return all(self.encloses_point(x) for x in o.points) + elif isinstance(o, (Ray, Line)): + return False + elif isinstance(o, Ellipse): + return self.encloses_point(o.center) and \ + self.encloses_point( + Point(o.center.x + o.hradius, o.center.y)) and \ + not self.intersection(o) + elif isinstance(o, Polygon): + if isinstance(o, RegularPolygon): + if not self.encloses_point(o.center): + return False + return all(self.encloses_point(v) for v in o.vertices) + raise NotImplementedError() + + def equals(self, o): + return self == o + + def intersection(self, o): + """ + Returns a list of all of the intersections of self with o. + + Notes + ===== + + An entity is not required to implement this method. + + If two different types of entities can intersect, the item with + higher index in ordering_of_classes should implement + intersections with anything having a lower index. + + See Also + ======== + + sympy.geometry.util.intersection + + """ + raise NotImplementedError() + + def is_similar(self, other): + """Is this geometrical entity similar to another geometrical entity? + + Two entities are similar if a uniform scaling (enlarging or + shrinking) of one of the entities will allow one to obtain the other. + + Notes + ===== + + This method is not intended to be used directly but rather + through the `are_similar` function found in util.py. + An entity is not required to implement this method. + If two different types of entities can be similar, it is only + required that one of them be able to determine this. + + See Also + ======== + + scale + + """ + raise NotImplementedError() + + def reflect(self, line): + """ + Reflects an object across a line. + + Parameters + ========== + + line: Line + + Examples + ======== + + >>> from sympy import pi, sqrt, Line, RegularPolygon + >>> l = Line((0, pi), slope=sqrt(2)) + >>> pent = RegularPolygon((1, 2), 1, 5) + >>> rpent = pent.reflect(l) + >>> rpent + RegularPolygon(Point2D(-2*sqrt(2)*pi/3 - 1/3 + 4*sqrt(2)/3, 2/3 + 2*sqrt(2)/3 + 2*pi/3), -1, 5, -atan(2*sqrt(2)) + 3*pi/5) + + >>> from sympy import pi, Line, Circle, Point + >>> l = Line((0, pi), slope=1) + >>> circ = Circle(Point(0, 0), 5) + >>> rcirc = circ.reflect(l) + >>> rcirc + Circle(Point2D(-pi, pi), -5) + + """ + from sympy.geometry.point import Point + + g = self + l = line + o = Point(0, 0) + if l.slope.is_zero: + v = l.args[0].y + if not v: # x-axis + return g.scale(y=-1) + reps = [(p, p.translate(y=2*(v - p.y))) for p in g.atoms(Point)] + elif l.slope is oo: + v = l.args[0].x + if not v: # y-axis + return g.scale(x=-1) + reps = [(p, p.translate(x=2*(v - p.x))) for p in g.atoms(Point)] + else: + if not hasattr(g, 'reflect') and not all( + isinstance(arg, Point) for arg in g.args): + raise NotImplementedError( + 'reflect undefined or non-Point args in %s' % g) + a = atan(l.slope) + c = l.coefficients + d = -c[-1]/c[1] # y-intercept + # apply the transform to a single point + xf = Point(x, y) + xf = xf.translate(y=-d).rotate(-a, o).scale(y=-1 + ).rotate(a, o).translate(y=d) + # replace every point using that transform + reps = [(p, xf.xreplace({x: p.x, y: p.y})) for p in g.atoms(Point)] + return g.xreplace(dict(reps)) + + def rotate(self, angle, pt=None): + """Rotate ``angle`` radians counterclockwise about Point ``pt``. + + The default pt is the origin, Point(0, 0) + + See Also + ======== + + scale, translate + + Examples + ======== + + >>> from sympy import Point, RegularPolygon, Polygon, pi + >>> t = Polygon(*RegularPolygon(Point(0, 0), 1, 3).vertices) + >>> t # vertex on x axis + Triangle(Point2D(1, 0), Point2D(-1/2, sqrt(3)/2), Point2D(-1/2, -sqrt(3)/2)) + >>> t.rotate(pi/2) # vertex on y axis now + Triangle(Point2D(0, 1), Point2D(-sqrt(3)/2, -1/2), Point2D(sqrt(3)/2, -1/2)) + + """ + newargs = [] + for a in self.args: + if isinstance(a, GeometryEntity): + newargs.append(a.rotate(angle, pt)) + else: + newargs.append(a) + return type(self)(*newargs) + + def scale(self, x=1, y=1, pt=None): + """Scale the object by multiplying the x,y-coordinates by x and y. + + If pt is given, the scaling is done relative to that point; the + object is shifted by -pt, scaled, and shifted by pt. + + See Also + ======== + + rotate, translate + + Examples + ======== + + >>> from sympy import RegularPolygon, Point, Polygon + >>> t = Polygon(*RegularPolygon(Point(0, 0), 1, 3).vertices) + >>> t + Triangle(Point2D(1, 0), Point2D(-1/2, sqrt(3)/2), Point2D(-1/2, -sqrt(3)/2)) + >>> t.scale(2) + Triangle(Point2D(2, 0), Point2D(-1, sqrt(3)/2), Point2D(-1, -sqrt(3)/2)) + >>> t.scale(2, 2) + Triangle(Point2D(2, 0), Point2D(-1, sqrt(3)), Point2D(-1, -sqrt(3))) + + """ + from sympy.geometry.point import Point + if pt: + pt = Point(pt, dim=2) + return self.translate(*(-pt).args).scale(x, y).translate(*pt.args) + return type(self)(*[a.scale(x, y) for a in self.args]) # if this fails, override this class + + def translate(self, x=0, y=0): + """Shift the object by adding to the x,y-coordinates the values x and y. + + See Also + ======== + + rotate, scale + + Examples + ======== + + >>> from sympy import RegularPolygon, Point, Polygon + >>> t = Polygon(*RegularPolygon(Point(0, 0), 1, 3).vertices) + >>> t + Triangle(Point2D(1, 0), Point2D(-1/2, sqrt(3)/2), Point2D(-1/2, -sqrt(3)/2)) + >>> t.translate(2) + Triangle(Point2D(3, 0), Point2D(3/2, sqrt(3)/2), Point2D(3/2, -sqrt(3)/2)) + >>> t.translate(2, 2) + Triangle(Point2D(3, 2), Point2D(3/2, sqrt(3)/2 + 2), Point2D(3/2, 2 - sqrt(3)/2)) + + """ + newargs = [] + for a in self.args: + if isinstance(a, GeometryEntity): + newargs.append(a.translate(x, y)) + else: + newargs.append(a) + return self.func(*newargs) + + def parameter_value(self, other, t): + """Return the parameter corresponding to the given point. + Evaluating an arbitrary point of the entity at this parameter + value will return the given point. + + Examples + ======== + + >>> from sympy import Line, Point + >>> from sympy.abc import t + >>> a = Point(0, 0) + >>> b = Point(2, 2) + >>> Line(a, b).parameter_value((1, 1), t) + {t: 1/2} + >>> Line(a, b).arbitrary_point(t).subs(_) + Point2D(1, 1) + """ + from sympy.geometry.point import Point + if not isinstance(other, GeometryEntity): + other = Point(other, dim=self.ambient_dimension) + if not isinstance(other, Point): + raise ValueError("other must be a point") + sol = solve(self.arbitrary_point(T) - other, T, dict=True) + if not sol: + raise ValueError("Given point is not on %s" % func_name(self)) + return {t: sol[0][T]} + + +class GeometrySet(GeometryEntity, Set): + """Parent class of all GeometryEntity that are also Sets + (compatible with sympy.sets) + """ + __slots__ = () + + def _contains(self, other): + """sympy.sets uses the _contains method, so include it for compatibility.""" + + if isinstance(other, Set) and other.is_FiniteSet: + return all(self.__contains__(i) for i in other) + + return self.__contains__(other) + +@dispatch(GeometrySet, Set) # type:ignore # noqa:F811 +def union_sets(self, o): # noqa:F811 + """ Returns the union of self and o + for use with sympy.sets.Set, if possible. """ + + + # if its a FiniteSet, merge any points + # we contain and return a union with the rest + if o.is_FiniteSet: + other_points = [p for p in o if not self._contains(p)] + if len(other_points) == len(o): + return None + return Union(self, FiniteSet(*other_points)) + if self._contains(o): + return self + return None + + +@dispatch(GeometrySet, Set) # type: ignore # noqa:F811 +def intersection_sets(self, o): # noqa:F811 + """ Returns a sympy.sets.Set of intersection objects, + if possible. """ + + from sympy.geometry.point import Point + + try: + # if o is a FiniteSet, find the intersection directly + # to avoid infinite recursion + if o.is_FiniteSet: + inter = FiniteSet(*(p for p in o if self.contains(p))) + else: + inter = self.intersection(o) + except NotImplementedError: + # sympy.sets.Set.reduce expects None if an object + # doesn't know how to simplify + return None + + # put the points in a FiniteSet + points = FiniteSet(*[p for p in inter if isinstance(p, Point)]) + non_points = [p for p in inter if not isinstance(p, Point)] + + return Union(*(non_points + [points])) + +def translate(x, y): + """Return the matrix to translate a 2-D point by x and y.""" + rv = eye(3) + rv[2, 0] = x + rv[2, 1] = y + return rv + + +def scale(x, y, pt=None): + """Return the matrix to multiply a 2-D point's coordinates by x and y. + + If pt is given, the scaling is done relative to that point.""" + rv = eye(3) + rv[0, 0] = x + rv[1, 1] = y + if pt: + from sympy.geometry.point import Point + pt = Point(pt, dim=2) + tr1 = translate(*(-pt).args) + tr2 = translate(*pt.args) + return tr1*rv*tr2 + return rv + + +def rotate(th): + """Return the matrix to rotate a 2-D point about the origin by ``angle``. + + The angle is measured in radians. To Point a point about a point other + then the origin, translate the Point, do the rotation, and + translate it back: + + >>> from sympy.geometry.entity import rotate, translate + >>> from sympy import Point, pi + >>> rot_about_11 = translate(-1, -1)*rotate(pi/2)*translate(1, 1) + >>> Point(1, 1).transform(rot_about_11) + Point2D(1, 1) + >>> Point(0, 0).transform(rot_about_11) + Point2D(2, 0) + """ + s = sin(th) + rv = eye(3)*cos(th) + rv[0, 1] = s + rv[1, 0] = -s + rv[2, 2] = 1 + return rv diff --git a/.venv/lib/python3.13/site-packages/sympy/geometry/exceptions.py b/.venv/lib/python3.13/site-packages/sympy/geometry/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..41d97af718de2cebad3accefcd60e43ccf74a3f6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/geometry/exceptions.py @@ -0,0 +1,5 @@ +"""Geometry Errors.""" + +class GeometryError(ValueError): + """An exception raised by classes in the geometry module.""" + pass diff --git a/.venv/lib/python3.13/site-packages/sympy/geometry/line.py b/.venv/lib/python3.13/site-packages/sympy/geometry/line.py new file mode 100644 index 0000000000000000000000000000000000000000..ed73d43d0c9581f9d51f299cf4425acb11958e57 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/geometry/line.py @@ -0,0 +1,2877 @@ +"""Line-like geometrical entities. + +Contains +======== +LinearEntity +Line +Ray +Segment +LinearEntity2D +Line2D +Ray2D +Segment2D +LinearEntity3D +Line3D +Ray3D +Segment3D + +""" + +from sympy.core.containers import Tuple +from sympy.core.evalf import N +from sympy.core.expr import Expr +from sympy.core.numbers import Rational, oo, Float +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.sorting import ordered +from sympy.core.symbol import _symbol, Dummy, uniquely_named_symbol +from sympy.core.sympify import sympify +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (_pi_coeff, acos, tan, atan2) +from .entity import GeometryEntity, GeometrySet +from .exceptions import GeometryError +from .point import Point, Point3D +from .util import find, intersection +from sympy.logic.boolalg import And +from sympy.matrices import Matrix +from sympy.sets.sets import Intersection +from sympy.simplify.simplify import simplify +from sympy.solvers.solvers import solve +from sympy.solvers.solveset import linear_coeffs +from sympy.utilities.misc import Undecidable, filldedent + + +import random + + +t, u = [Dummy('line_dummy') for i in range(2)] + + +class LinearEntity(GeometrySet): + """A base class for all linear entities (Line, Ray and Segment) + in n-dimensional Euclidean space. + + Attributes + ========== + + ambient_dimension + direction + length + p1 + p2 + points + + Notes + ===== + + This is an abstract class and is not meant to be instantiated. + + See Also + ======== + + sympy.geometry.entity.GeometryEntity + + """ + def __new__(cls, p1, p2=None, **kwargs): + p1, p2 = Point._normalize_dimension(p1, p2) + if p1 == p2: + # sometimes we return a single point if we are not given two unique + # points. This is done in the specific subclass + raise ValueError( + "%s.__new__ requires two unique Points." % cls.__name__) + if len(p1) != len(p2): + raise ValueError( + "%s.__new__ requires two Points of equal dimension." % cls.__name__) + + return GeometryEntity.__new__(cls, p1, p2, **kwargs) + + def __contains__(self, other): + """Return a definitive answer or else raise an error if it cannot + be determined that other is on the boundaries of self.""" + result = self.contains(other) + + if result is not None: + return result + else: + raise Undecidable( + "Cannot decide whether '%s' contains '%s'" % (self, other)) + + def _span_test(self, other): + """Test whether the point `other` lies in the positive span of `self`. + A point x is 'in front' of a point y if x.dot(y) >= 0. Return + -1 if `other` is behind `self.p1`, 0 if `other` is `self.p1` and + and 1 if `other` is in front of `self.p1`.""" + if self.p1 == other: + return 0 + + rel_pos = other - self.p1 + d = self.direction + if d.dot(rel_pos) > 0: + return 1 + return -1 + + @property + def ambient_dimension(self): + """A property method that returns the dimension of LinearEntity + object. + + Parameters + ========== + + p1 : LinearEntity + + Returns + ======= + + dimension : integer + + Examples + ======== + + >>> from sympy import Point, Line + >>> p1, p2 = Point(0, 0), Point(1, 1) + >>> l1 = Line(p1, p2) + >>> l1.ambient_dimension + 2 + + >>> from sympy import Point, Line + >>> p1, p2 = Point(0, 0, 0), Point(1, 1, 1) + >>> l1 = Line(p1, p2) + >>> l1.ambient_dimension + 3 + + """ + return len(self.p1) + + def angle_between(l1, l2): + """Return the non-reflex angle formed by rays emanating from + the origin with directions the same as the direction vectors + of the linear entities. + + Parameters + ========== + + l1 : LinearEntity + l2 : LinearEntity + + Returns + ======= + + angle : angle in radians + + Notes + ===== + + From the dot product of vectors v1 and v2 it is known that: + + ``dot(v1, v2) = |v1|*|v2|*cos(A)`` + + where A is the angle formed between the two vectors. We can + get the directional vectors of the two lines and readily + find the angle between the two using the above formula. + + See Also + ======== + + is_perpendicular, Ray2D.closing_angle + + Examples + ======== + + >>> from sympy import Line + >>> e = Line((0, 0), (1, 0)) + >>> ne = Line((0, 0), (1, 1)) + >>> sw = Line((1, 1), (0, 0)) + >>> ne.angle_between(e) + pi/4 + >>> sw.angle_between(e) + 3*pi/4 + + To obtain the non-obtuse angle at the intersection of lines, use + the ``smallest_angle_between`` method: + + >>> sw.smallest_angle_between(e) + pi/4 + + >>> from sympy import Point3D, Line3D + >>> p1, p2, p3 = Point3D(0, 0, 0), Point3D(1, 1, 1), Point3D(-1, 2, 0) + >>> l1, l2 = Line3D(p1, p2), Line3D(p2, p3) + >>> l1.angle_between(l2) + acos(-sqrt(2)/3) + >>> l1.smallest_angle_between(l2) + acos(sqrt(2)/3) + """ + if not isinstance(l1, LinearEntity) and not isinstance(l2, LinearEntity): + raise TypeError('Must pass only LinearEntity objects') + + v1, v2 = l1.direction, l2.direction + return acos(v1.dot(v2)/(abs(v1)*abs(v2))) + + def smallest_angle_between(l1, l2): + """Return the smallest angle formed at the intersection of the + lines containing the linear entities. + + Parameters + ========== + + l1 : LinearEntity + l2 : LinearEntity + + Returns + ======= + + angle : angle in radians + + Examples + ======== + + >>> from sympy import Point, Line + >>> p1, p2, p3 = Point(0, 0), Point(0, 4), Point(2, -2) + >>> l1, l2 = Line(p1, p2), Line(p1, p3) + >>> l1.smallest_angle_between(l2) + pi/4 + + See Also + ======== + + angle_between, is_perpendicular, Ray2D.closing_angle + """ + if not isinstance(l1, LinearEntity) and not isinstance(l2, LinearEntity): + raise TypeError('Must pass only LinearEntity objects') + + v1, v2 = l1.direction, l2.direction + return acos(abs(v1.dot(v2))/(abs(v1)*abs(v2))) + + def arbitrary_point(self, parameter='t'): + """A parameterized point on the Line. + + Parameters + ========== + + parameter : str, optional + The name of the parameter which will be used for the parametric + point. The default value is 't'. When this parameter is 0, the + first point used to define the line will be returned, and when + it is 1 the second point will be returned. + + Returns + ======= + + point : Point + + Raises + ====== + + ValueError + When ``parameter`` already appears in the Line's definition. + + See Also + ======== + + sympy.geometry.point.Point + + Examples + ======== + + >>> from sympy import Point, Line + >>> p1, p2 = Point(1, 0), Point(5, 3) + >>> l1 = Line(p1, p2) + >>> l1.arbitrary_point() + Point2D(4*t + 1, 3*t) + >>> from sympy import Point3D, Line3D + >>> p1, p2 = Point3D(1, 0, 0), Point3D(5, 3, 1) + >>> l1 = Line3D(p1, p2) + >>> l1.arbitrary_point() + Point3D(4*t + 1, 3*t, t) + + """ + t = _symbol(parameter, real=True) + if t.name in (f.name for f in self.free_symbols): + raise ValueError(filldedent(''' + Symbol %s already appears in object + and cannot be used as a parameter. + ''' % t.name)) + # multiply on the right so the variable gets + # combined with the coordinates of the point + return self.p1 + (self.p2 - self.p1)*t + + @staticmethod + def are_concurrent(*lines): + """Is a sequence of linear entities concurrent? + + Two or more linear entities are concurrent if they all + intersect at a single point. + + Parameters + ========== + + lines + A sequence of linear entities. + + Returns + ======= + + True : if the set of linear entities intersect in one point + False : otherwise. + + See Also + ======== + + sympy.geometry.util.intersection + + Examples + ======== + + >>> from sympy import Point, Line + >>> p1, p2 = Point(0, 0), Point(3, 5) + >>> p3, p4 = Point(-2, -2), Point(0, 2) + >>> l1, l2, l3 = Line(p1, p2), Line(p1, p3), Line(p1, p4) + >>> Line.are_concurrent(l1, l2, l3) + True + >>> l4 = Line(p2, p3) + >>> Line.are_concurrent(l2, l3, l4) + False + >>> from sympy import Point3D, Line3D + >>> p1, p2 = Point3D(0, 0, 0), Point3D(3, 5, 2) + >>> p3, p4 = Point3D(-2, -2, -2), Point3D(0, 2, 1) + >>> l1, l2, l3 = Line3D(p1, p2), Line3D(p1, p3), Line3D(p1, p4) + >>> Line3D.are_concurrent(l1, l2, l3) + True + >>> l4 = Line3D(p2, p3) + >>> Line3D.are_concurrent(l2, l3, l4) + False + + """ + common_points = Intersection(*lines) + if common_points.is_FiniteSet and len(common_points) == 1: + return True + return False + + def contains(self, other): + """Subclasses should implement this method and should return + True if other is on the boundaries of self; + False if not on the boundaries of self; + None if a determination cannot be made.""" + raise NotImplementedError() + + @property + def direction(self): + """The direction vector of the LinearEntity. + + Returns + ======= + + p : a Point; the ray from the origin to this point is the + direction of `self` + + Examples + ======== + + >>> from sympy import Line + >>> a, b = (1, 1), (1, 3) + >>> Line(a, b).direction + Point2D(0, 2) + >>> Line(b, a).direction + Point2D(0, -2) + + This can be reported so the distance from the origin is 1: + + >>> Line(b, a).direction.unit + Point2D(0, -1) + + See Also + ======== + + sympy.geometry.point.Point.unit + + """ + return self.p2 - self.p1 + + def intersection(self, other): + """The intersection with another geometrical entity. + + Parameters + ========== + + o : Point or LinearEntity + + Returns + ======= + + intersection : list of geometrical entities + + See Also + ======== + + sympy.geometry.point.Point + + Examples + ======== + + >>> from sympy import Point, Line, Segment + >>> p1, p2, p3 = Point(0, 0), Point(1, 1), Point(7, 7) + >>> l1 = Line(p1, p2) + >>> l1.intersection(p3) + [Point2D(7, 7)] + >>> p4, p5 = Point(5, 0), Point(0, 3) + >>> l2 = Line(p4, p5) + >>> l1.intersection(l2) + [Point2D(15/8, 15/8)] + >>> p6, p7 = Point(0, 5), Point(2, 6) + >>> s1 = Segment(p6, p7) + >>> l1.intersection(s1) + [] + >>> from sympy import Point3D, Line3D, Segment3D + >>> p1, p2, p3 = Point3D(0, 0, 0), Point3D(1, 1, 1), Point3D(7, 7, 7) + >>> l1 = Line3D(p1, p2) + >>> l1.intersection(p3) + [Point3D(7, 7, 7)] + >>> l1 = Line3D(Point3D(4,19,12), Point3D(5,25,17)) + >>> l2 = Line3D(Point3D(-3, -15, -19), direction_ratio=[2,8,8]) + >>> l1.intersection(l2) + [Point3D(1, 1, -3)] + >>> p6, p7 = Point3D(0, 5, 2), Point3D(2, 6, 3) + >>> s1 = Segment3D(p6, p7) + >>> l1.intersection(s1) + [] + + """ + def intersect_parallel_rays(ray1, ray2): + if ray1.direction.dot(ray2.direction) > 0: + # rays point in the same direction + # so return the one that is "in front" + return [ray2] if ray1._span_test(ray2.p1) >= 0 else [ray1] + else: + # rays point in opposite directions + st = ray1._span_test(ray2.p1) + if st < 0: + return [] + elif st == 0: + return [ray2.p1] + return [Segment(ray1.p1, ray2.p1)] + + def intersect_parallel_ray_and_segment(ray, seg): + st1, st2 = ray._span_test(seg.p1), ray._span_test(seg.p2) + if st1 < 0 and st2 < 0: + return [] + elif st1 >= 0 and st2 >= 0: + return [seg] + elif st1 >= 0: # st2 < 0: + return [Segment(ray.p1, seg.p1)] + else: # st1 < 0 and st2 >= 0: + return [Segment(ray.p1, seg.p2)] + + def intersect_parallel_segments(seg1, seg2): + if seg1.contains(seg2): + return [seg2] + if seg2.contains(seg1): + return [seg1] + + # direct the segments so they're oriented the same way + if seg1.direction.dot(seg2.direction) < 0: + seg2 = Segment(seg2.p2, seg2.p1) + # order the segments so seg1 is "behind" seg2 + if seg1._span_test(seg2.p1) < 0: + seg1, seg2 = seg2, seg1 + if seg2._span_test(seg1.p2) < 0: + return [] + return [Segment(seg2.p1, seg1.p2)] + + if not isinstance(other, GeometryEntity): + other = Point(other, dim=self.ambient_dimension) + if other.is_Point: + if self.contains(other): + return [other] + else: + return [] + elif isinstance(other, LinearEntity): + # break into cases based on whether + # the lines are parallel, non-parallel intersecting, or skew + pts = Point._normalize_dimension(self.p1, self.p2, other.p1, other.p2) + rank = Point.affine_rank(*pts) + + if rank == 1: + # we're collinear + if isinstance(self, Line): + return [other] + if isinstance(other, Line): + return [self] + + if isinstance(self, Ray) and isinstance(other, Ray): + return intersect_parallel_rays(self, other) + if isinstance(self, Ray) and isinstance(other, Segment): + return intersect_parallel_ray_and_segment(self, other) + if isinstance(self, Segment) and isinstance(other, Ray): + return intersect_parallel_ray_and_segment(other, self) + if isinstance(self, Segment) and isinstance(other, Segment): + return intersect_parallel_segments(self, other) + elif rank == 2: + # we're in the same plane + l1 = Line(*pts[:2]) + l2 = Line(*pts[2:]) + + # check to see if we're parallel. If we are, we can't + # be intersecting, since the collinear case was already + # handled + if l1.direction.is_scalar_multiple(l2.direction): + return [] + + # find the intersection as if everything were lines + # by solving the equation t*d + p1 == s*d' + p1' + m = Matrix([l1.direction, -l2.direction]).transpose() + v = Matrix([l2.p1 - l1.p1]).transpose() + + # we cannot use m.solve(v) because that only works for square matrices + m_rref, pivots = m.col_insert(2, v).rref(simplify=True) + # rank == 2 ensures we have 2 pivots, but let's check anyway + if len(pivots) != 2: + raise GeometryError("Failed when solving Mx=b when M={} and b={}".format(m, v)) + coeff = m_rref[0, 2] + line_intersection = l1.direction*coeff + self.p1 + + # if both are lines, skip a containment check + if isinstance(self, Line) and isinstance(other, Line): + return [line_intersection] + + if ((isinstance(self, Line) or + self.contains(line_intersection)) and + other.contains(line_intersection)): + return [line_intersection] + if not self.atoms(Float) and not other.atoms(Float): + # if it can fail when there are no Floats then + # maybe the following parametric check should be + # done + return [] + # floats may fail exact containment so check that the + # arbitrary points, when equal, both give a + # non-negative parameter when the arbitrary point + # coordinates are equated + tu = solve(self.arbitrary_point(t) - other.arbitrary_point(u), + t, u, dict=True)[0] + def ok(p, l): + if isinstance(l, Line): + # p > -oo + return True + if isinstance(l, Ray): + # p >= 0 + return p.is_nonnegative + if isinstance(l, Segment): + # 0 <= p <= 1 + return p.is_nonnegative and (1 - p).is_nonnegative + raise ValueError("unexpected line type") + if ok(tu[t], self) and ok(tu[u], other): + return [line_intersection] + return [] + else: + # we're skew + return [] + + return other.intersection(self) + + def is_parallel(l1, l2): + """Are two linear entities parallel? + + Parameters + ========== + + l1 : LinearEntity + l2 : LinearEntity + + Returns + ======= + + True : if l1 and l2 are parallel, + False : otherwise. + + See Also + ======== + + coefficients + + Examples + ======== + + >>> from sympy import Point, Line + >>> p1, p2 = Point(0, 0), Point(1, 1) + >>> p3, p4 = Point(3, 4), Point(6, 7) + >>> l1, l2 = Line(p1, p2), Line(p3, p4) + >>> Line.is_parallel(l1, l2) + True + >>> p5 = Point(6, 6) + >>> l3 = Line(p3, p5) + >>> Line.is_parallel(l1, l3) + False + >>> from sympy import Point3D, Line3D + >>> p1, p2 = Point3D(0, 0, 0), Point3D(3, 4, 5) + >>> p3, p4 = Point3D(2, 1, 1), Point3D(8, 9, 11) + >>> l1, l2 = Line3D(p1, p2), Line3D(p3, p4) + >>> Line3D.is_parallel(l1, l2) + True + >>> p5 = Point3D(6, 6, 6) + >>> l3 = Line3D(p3, p5) + >>> Line3D.is_parallel(l1, l3) + False + + """ + if not isinstance(l1, LinearEntity) and not isinstance(l2, LinearEntity): + raise TypeError('Must pass only LinearEntity objects') + + return l1.direction.is_scalar_multiple(l2.direction) + + def is_perpendicular(l1, l2): + """Are two linear entities perpendicular? + + Parameters + ========== + + l1 : LinearEntity + l2 : LinearEntity + + Returns + ======= + + True : if l1 and l2 are perpendicular, + False : otherwise. + + See Also + ======== + + coefficients + + Examples + ======== + + >>> from sympy import Point, Line + >>> p1, p2, p3 = Point(0, 0), Point(1, 1), Point(-1, 1) + >>> l1, l2 = Line(p1, p2), Line(p1, p3) + >>> l1.is_perpendicular(l2) + True + >>> p4 = Point(5, 3) + >>> l3 = Line(p1, p4) + >>> l1.is_perpendicular(l3) + False + >>> from sympy import Point3D, Line3D + >>> p1, p2, p3 = Point3D(0, 0, 0), Point3D(1, 1, 1), Point3D(-1, 2, 0) + >>> l1, l2 = Line3D(p1, p2), Line3D(p2, p3) + >>> l1.is_perpendicular(l2) + False + >>> p4 = Point3D(5, 3, 7) + >>> l3 = Line3D(p1, p4) + >>> l1.is_perpendicular(l3) + False + + """ + if not isinstance(l1, LinearEntity) and not isinstance(l2, LinearEntity): + raise TypeError('Must pass only LinearEntity objects') + + return S.Zero.equals(l1.direction.dot(l2.direction)) + + def is_similar(self, other): + """ + Return True if self and other are contained in the same line. + + Examples + ======== + + >>> from sympy import Point, Line + >>> p1, p2, p3 = Point(0, 1), Point(3, 4), Point(2, 3) + >>> l1 = Line(p1, p2) + >>> l2 = Line(p1, p3) + >>> l1.is_similar(l2) + True + """ + l = Line(self.p1, self.p2) + return l.contains(other) + + @property + def length(self): + """ + The length of the line. + + Examples + ======== + + >>> from sympy import Point, Line + >>> p1, p2 = Point(0, 0), Point(3, 5) + >>> l1 = Line(p1, p2) + >>> l1.length + oo + """ + return S.Infinity + + @property + def p1(self): + """The first defining point of a linear entity. + + See Also + ======== + + sympy.geometry.point.Point + + Examples + ======== + + >>> from sympy import Point, Line + >>> p1, p2 = Point(0, 0), Point(5, 3) + >>> l = Line(p1, p2) + >>> l.p1 + Point2D(0, 0) + + """ + return self.args[0] + + @property + def p2(self): + """The second defining point of a linear entity. + + See Also + ======== + + sympy.geometry.point.Point + + Examples + ======== + + >>> from sympy import Point, Line + >>> p1, p2 = Point(0, 0), Point(5, 3) + >>> l = Line(p1, p2) + >>> l.p2 + Point2D(5, 3) + + """ + return self.args[1] + + def parallel_line(self, p): + """Create a new Line parallel to this linear entity which passes + through the point `p`. + + Parameters + ========== + + p : Point + + Returns + ======= + + line : Line + + See Also + ======== + + is_parallel + + Examples + ======== + + >>> from sympy import Point, Line + >>> p1, p2, p3 = Point(0, 0), Point(2, 3), Point(-2, 2) + >>> l1 = Line(p1, p2) + >>> l2 = l1.parallel_line(p3) + >>> p3 in l2 + True + >>> l1.is_parallel(l2) + True + >>> from sympy import Point3D, Line3D + >>> p1, p2, p3 = Point3D(0, 0, 0), Point3D(2, 3, 4), Point3D(-2, 2, 0) + >>> l1 = Line3D(p1, p2) + >>> l2 = l1.parallel_line(p3) + >>> p3 in l2 + True + >>> l1.is_parallel(l2) + True + + """ + p = Point(p, dim=self.ambient_dimension) + return Line(p, p + self.direction) + + def perpendicular_line(self, p): + """Create a new Line perpendicular to this linear entity which passes + through the point `p`. + + Parameters + ========== + + p : Point + + Returns + ======= + + line : Line + + See Also + ======== + + sympy.geometry.line.LinearEntity.is_perpendicular, perpendicular_segment + + Examples + ======== + + >>> from sympy import Point3D, Line3D + >>> p1, p2, p3 = Point3D(0, 0, 0), Point3D(2, 3, 4), Point3D(-2, 2, 0) + >>> L = Line3D(p1, p2) + >>> P = L.perpendicular_line(p3); P + Line3D(Point3D(-2, 2, 0), Point3D(4/29, 6/29, 8/29)) + >>> L.is_perpendicular(P) + True + + In 3D the, the first point used to define the line is the point + through which the perpendicular was required to pass; the + second point is (arbitrarily) contained in the given line: + + >>> P.p2 in L + True + """ + p = Point(p, dim=self.ambient_dimension) + if p in self: + p = p + self.direction.orthogonal_direction + return Line(p, self.projection(p)) + + def perpendicular_segment(self, p): + """Create a perpendicular line segment from `p` to this line. + + The endpoints of the segment are ``p`` and the closest point in + the line containing self. (If self is not a line, the point might + not be in self.) + + Parameters + ========== + + p : Point + + Returns + ======= + + segment : Segment + + Notes + ===== + + Returns `p` itself if `p` is on this linear entity. + + See Also + ======== + + perpendicular_line + + Examples + ======== + + >>> from sympy import Point, Line + >>> p1, p2, p3 = Point(0, 0), Point(1, 1), Point(0, 2) + >>> l1 = Line(p1, p2) + >>> s1 = l1.perpendicular_segment(p3) + >>> l1.is_perpendicular(s1) + True + >>> p3 in s1 + True + >>> l1.perpendicular_segment(Point(4, 0)) + Segment2D(Point2D(4, 0), Point2D(2, 2)) + >>> from sympy import Point3D, Line3D + >>> p1, p2, p3 = Point3D(0, 0, 0), Point3D(1, 1, 1), Point3D(0, 2, 0) + >>> l1 = Line3D(p1, p2) + >>> s1 = l1.perpendicular_segment(p3) + >>> l1.is_perpendicular(s1) + True + >>> p3 in s1 + True + >>> l1.perpendicular_segment(Point3D(4, 0, 0)) + Segment3D(Point3D(4, 0, 0), Point3D(4/3, 4/3, 4/3)) + + """ + p = Point(p, dim=self.ambient_dimension) + if p in self: + return p + l = self.perpendicular_line(p) + # The intersection should be unique, so unpack the singleton + p2, = Intersection(Line(self.p1, self.p2), l) + + return Segment(p, p2) + + @property + def points(self): + """The two points used to define this linear entity. + + Returns + ======= + + points : tuple of Points + + See Also + ======== + + sympy.geometry.point.Point + + Examples + ======== + + >>> from sympy import Point, Line + >>> p1, p2 = Point(0, 0), Point(5, 11) + >>> l1 = Line(p1, p2) + >>> l1.points + (Point2D(0, 0), Point2D(5, 11)) + + """ + return (self.p1, self.p2) + + def projection(self, other): + """Project a point, line, ray, or segment onto this linear entity. + + Parameters + ========== + + other : Point or LinearEntity (Line, Ray, Segment) + + Returns + ======= + + projection : Point or LinearEntity (Line, Ray, Segment) + The return type matches the type of the parameter ``other``. + + Raises + ====== + + GeometryError + When method is unable to perform projection. + + Notes + ===== + + A projection involves taking the two points that define + the linear entity and projecting those points onto a + Line and then reforming the linear entity using these + projections. + A point P is projected onto a line L by finding the point + on L that is closest to P. This point is the intersection + of L and the line perpendicular to L that passes through P. + + See Also + ======== + + sympy.geometry.point.Point, perpendicular_line + + Examples + ======== + + >>> from sympy import Point, Line, Segment, Rational + >>> p1, p2, p3 = Point(0, 0), Point(1, 1), Point(Rational(1, 2), 0) + >>> l1 = Line(p1, p2) + >>> l1.projection(p3) + Point2D(1/4, 1/4) + >>> p4, p5 = Point(10, 0), Point(12, 1) + >>> s1 = Segment(p4, p5) + >>> l1.projection(s1) + Segment2D(Point2D(5, 5), Point2D(13/2, 13/2)) + >>> p1, p2, p3 = Point(0, 0, 1), Point(1, 1, 2), Point(2, 0, 1) + >>> l1 = Line(p1, p2) + >>> l1.projection(p3) + Point3D(2/3, 2/3, 5/3) + >>> p4, p5 = Point(10, 0, 1), Point(12, 1, 3) + >>> s1 = Segment(p4, p5) + >>> l1.projection(s1) + Segment3D(Point3D(10/3, 10/3, 13/3), Point3D(5, 5, 6)) + + """ + if not isinstance(other, GeometryEntity): + other = Point(other, dim=self.ambient_dimension) + + def proj_point(p): + return Point.project(p - self.p1, self.direction) + self.p1 + + if isinstance(other, Point): + return proj_point(other) + elif isinstance(other, LinearEntity): + p1, p2 = proj_point(other.p1), proj_point(other.p2) + # test to see if we're degenerate + if p1 == p2: + return p1 + projected = other.__class__(p1, p2) + projected = Intersection(self, projected) + if projected.is_empty: + return projected + # if we happen to have intersected in only a point, return that + if projected.is_FiniteSet and len(projected) == 1: + # projected is a set of size 1, so unpack it in `a` + a, = projected + return a + # order args so projection is in the same direction as self + if self.direction.dot(projected.direction) < 0: + p1, p2 = projected.args + projected = projected.func(p2, p1) + return projected + + raise GeometryError( + "Do not know how to project %s onto %s" % (other, self)) + + def random_point(self, seed=None): + """A random point on a LinearEntity. + + Returns + ======= + + point : Point + + See Also + ======== + + sympy.geometry.point.Point + + Examples + ======== + + >>> from sympy import Point, Line, Ray, Segment + >>> p1, p2 = Point(0, 0), Point(5, 3) + >>> line = Line(p1, p2) + >>> r = line.random_point(seed=42) # seed value is optional + >>> r.n(3) + Point2D(-0.72, -0.432) + >>> r in line + True + >>> Ray(p1, p2).random_point(seed=42).n(3) + Point2D(0.72, 0.432) + >>> Segment(p1, p2).random_point(seed=42).n(3) + Point2D(3.2, 1.92) + + """ + if seed is not None: + rng = random.Random(seed) + else: + rng = random + pt = self.arbitrary_point(t) + if isinstance(self, Ray): + v = abs(rng.gauss(0, 1)) + elif isinstance(self, Segment): + v = rng.random() + elif isinstance(self, Line): + v = rng.gauss(0, 1) + else: + raise NotImplementedError('unhandled line type') + return pt.subs(t, Rational(v)) + + def bisectors(self, other): + """Returns the perpendicular lines which pass through the intersections + of self and other that are in the same plane. + + Parameters + ========== + + line : Line3D + + Returns + ======= + + list: two Line instances + + Examples + ======== + + >>> from sympy import Point3D, Line3D + >>> r1 = Line3D(Point3D(0, 0, 0), Point3D(1, 0, 0)) + >>> r2 = Line3D(Point3D(0, 0, 0), Point3D(0, 1, 0)) + >>> r1.bisectors(r2) + [Line3D(Point3D(0, 0, 0), Point3D(1, 1, 0)), Line3D(Point3D(0, 0, 0), Point3D(1, -1, 0))] + + """ + if not isinstance(other, LinearEntity): + raise GeometryError("Expecting LinearEntity, not %s" % other) + + l1, l2 = self, other + + # make sure dimensions match or else a warning will rise from + # intersection calculation + if l1.p1.ambient_dimension != l2.p1.ambient_dimension: + if isinstance(l1, Line2D): + l1, l2 = l2, l1 + _, p1 = Point._normalize_dimension(l1.p1, l2.p1, on_morph='ignore') + _, p2 = Point._normalize_dimension(l1.p2, l2.p2, on_morph='ignore') + l2 = Line(p1, p2) + + point = intersection(l1, l2) + + # Three cases: Lines may intersect in a point, may be equal or may not intersect. + if not point: + raise GeometryError("The lines do not intersect") + else: + pt = point[0] + if isinstance(pt, Line): + # Intersection is a line because both lines are coincident + return [self] + + + d1 = l1.direction.unit + d2 = l2.direction.unit + + bis1 = Line(pt, pt + d1 + d2) + bis2 = Line(pt, pt + d1 - d2) + + return [bis1, bis2] + + +class Line(LinearEntity): + """An infinite line in space. + + A 2D line is declared with two distinct points, point and slope, or + an equation. A 3D line may be defined with a point and a direction ratio. + + Parameters + ========== + + p1 : Point + p2 : Point + slope : SymPy expression + direction_ratio : list + equation : equation of a line + + Notes + ===== + + `Line` will automatically subclass to `Line2D` or `Line3D` based + on the dimension of `p1`. The `slope` argument is only relevant + for `Line2D` and the `direction_ratio` argument is only relevant + for `Line3D`. + + The order of the points will define the direction of the line + which is used when calculating the angle between lines. + + See Also + ======== + + sympy.geometry.point.Point + sympy.geometry.line.Line2D + sympy.geometry.line.Line3D + + Examples + ======== + + >>> from sympy import Line, Segment, Point, Eq + >>> from sympy.abc import x, y, a, b + + >>> L = Line(Point(2,3), Point(3,5)) + >>> L + Line2D(Point2D(2, 3), Point2D(3, 5)) + >>> L.points + (Point2D(2, 3), Point2D(3, 5)) + >>> L.equation() + -2*x + y + 1 + >>> L.coefficients + (-2, 1, 1) + + Instantiate with keyword ``slope``: + + >>> Line(Point(0, 0), slope=0) + Line2D(Point2D(0, 0), Point2D(1, 0)) + + Instantiate with another linear object + + >>> s = Segment((0, 0), (0, 1)) + >>> Line(s).equation() + x + + The line corresponding to an equation in the for `ax + by + c = 0`, + can be entered: + + >>> Line(3*x + y + 18) + Line2D(Point2D(0, -18), Point2D(1, -21)) + + If `x` or `y` has a different name, then they can be specified, too, + as a string (to match the name) or symbol: + + >>> Line(Eq(3*a + b, -18), x='a', y=b) + Line2D(Point2D(0, -18), Point2D(1, -21)) + """ + def __new__(cls, *args, **kwargs): + if len(args) == 1 and isinstance(args[0], (Expr, Eq)): + missing = uniquely_named_symbol('?', args) + if not kwargs: + x = 'x' + y = 'y' + else: + x = kwargs.pop('x', missing) + y = kwargs.pop('y', missing) + if kwargs: + raise ValueError('expecting only x and y as keywords') + + equation = args[0] + if isinstance(equation, Eq): + equation = equation.lhs - equation.rhs + + def find_or_missing(x): + try: + return find(x, equation) + except ValueError: + return missing + x = find_or_missing(x) + y = find_or_missing(y) + + a, b, c = linear_coeffs(equation, x, y) + + if b: + return Line((0, -c/b), slope=-a/b) + if a: + return Line((-c/a, 0), slope=oo) + + raise ValueError('not found in equation: %s' % (set('xy') - {x, y})) + + else: + if len(args) > 0: + p1 = args[0] + if len(args) > 1: + p2 = args[1] + else: + p2 = None + + if isinstance(p1, LinearEntity): + if p2: + raise ValueError('If p1 is a LinearEntity, p2 must be None.') + dim = len(p1.p1) + else: + p1 = Point(p1) + dim = len(p1) + if p2 is not None or isinstance(p2, Point) and p2.ambient_dimension != dim: + p2 = Point(p2) + + if dim == 2: + return Line2D(p1, p2, **kwargs) + elif dim == 3: + return Line3D(p1, p2, **kwargs) + return LinearEntity.__new__(cls, p1, p2, **kwargs) + + def contains(self, other): + """ + Return True if `other` is on this Line, or False otherwise. + + Examples + ======== + + >>> from sympy import Line,Point + >>> p1, p2 = Point(0, 1), Point(3, 4) + >>> l = Line(p1, p2) + >>> l.contains(p1) + True + >>> l.contains((0, 1)) + True + >>> l.contains((0, 0)) + False + >>> a = (0, 0, 0) + >>> b = (1, 1, 1) + >>> c = (2, 2, 2) + >>> l1 = Line(a, b) + >>> l2 = Line(b, a) + >>> l1 == l2 + False + >>> l1 in l2 + True + + """ + if not isinstance(other, GeometryEntity): + other = Point(other, dim=self.ambient_dimension) + if isinstance(other, Point): + return Point.is_collinear(other, self.p1, self.p2) + if isinstance(other, LinearEntity): + return Point.is_collinear(self.p1, self.p2, other.p1, other.p2) + return False + + def distance(self, other): + """ + Finds the shortest distance between a line and a point. + + Raises + ====== + + NotImplementedError is raised if `other` is not a Point + + Examples + ======== + + >>> from sympy import Point, Line + >>> p1, p2 = Point(0, 0), Point(1, 1) + >>> s = Line(p1, p2) + >>> s.distance(Point(-1, 1)) + sqrt(2) + >>> s.distance((-1, 2)) + 3*sqrt(2)/2 + >>> p1, p2 = Point(0, 0, 0), Point(1, 1, 1) + >>> s = Line(p1, p2) + >>> s.distance(Point(-1, 1, 1)) + 2*sqrt(6)/3 + >>> s.distance((-1, 1, 1)) + 2*sqrt(6)/3 + + """ + if not isinstance(other, GeometryEntity): + other = Point(other, dim=self.ambient_dimension) + if self.contains(other): + return S.Zero + return self.perpendicular_segment(other).length + + def equals(self, other): + """Returns True if self and other are the same mathematical entities""" + if not isinstance(other, Line): + return False + return Point.is_collinear(self.p1, other.p1, self.p2, other.p2) + + def plot_interval(self, parameter='t'): + """The plot interval for the default geometric plot of line. Gives + values that will produce a line that is +/- 5 units long (where a + unit is the distance between the two points that define the line). + + Parameters + ========== + + parameter : str, optional + Default value is 't'. + + Returns + ======= + + plot_interval : list (plot interval) + [parameter, lower_bound, upper_bound] + + Examples + ======== + + >>> from sympy import Point, Line + >>> p1, p2 = Point(0, 0), Point(5, 3) + >>> l1 = Line(p1, p2) + >>> l1.plot_interval() + [t, -5, 5] + + """ + t = _symbol(parameter, real=True) + return [t, -5, 5] + + +class Ray(LinearEntity): + """A Ray is a semi-line in the space with a source point and a direction. + + Parameters + ========== + + p1 : Point + The source of the Ray + p2 : Point or radian value + This point determines the direction in which the Ray propagates. + If given as an angle it is interpreted in radians with the positive + direction being ccw. + + Attributes + ========== + + source + + See Also + ======== + + sympy.geometry.line.Ray2D + sympy.geometry.line.Ray3D + sympy.geometry.point.Point + sympy.geometry.line.Line + + Notes + ===== + + `Ray` will automatically subclass to `Ray2D` or `Ray3D` based on the + dimension of `p1`. + + Examples + ======== + + >>> from sympy import Ray, Point, pi + >>> r = Ray(Point(2, 3), Point(3, 5)) + >>> r + Ray2D(Point2D(2, 3), Point2D(3, 5)) + >>> r.points + (Point2D(2, 3), Point2D(3, 5)) + >>> r.source + Point2D(2, 3) + >>> r.xdirection + oo + >>> r.ydirection + oo + >>> r.slope + 2 + >>> Ray(Point(0, 0), angle=pi/4).slope + 1 + + """ + def __new__(cls, p1, p2=None, **kwargs): + p1 = Point(p1) + if p2 is not None: + p1, p2 = Point._normalize_dimension(p1, Point(p2)) + dim = len(p1) + + if dim == 2: + return Ray2D(p1, p2, **kwargs) + elif dim == 3: + return Ray3D(p1, p2, **kwargs) + return LinearEntity.__new__(cls, p1, p2, **kwargs) + + def _svg(self, scale_factor=1., fill_color="#66cc99"): + """Returns SVG path element for the LinearEntity. + + Parameters + ========== + + scale_factor : float + Multiplication factor for the SVG stroke-width. Default is 1. + fill_color : str, optional + Hex string for fill color. Default is "#66cc99". + """ + verts = (N(self.p1), N(self.p2)) + coords = ["{},{}".format(p.x, p.y) for p in verts] + path = "M {} L {}".format(coords[0], " L ".join(coords[1:])) + + return ( + '' + ).format(2.*scale_factor, path, fill_color) + + def contains(self, other): + """ + Is other GeometryEntity contained in this Ray? + + Examples + ======== + + >>> from sympy import Ray,Point,Segment + >>> p1, p2 = Point(0, 0), Point(4, 4) + >>> r = Ray(p1, p2) + >>> r.contains(p1) + True + >>> r.contains((1, 1)) + True + >>> r.contains((1, 3)) + False + >>> s = Segment((1, 1), (2, 2)) + >>> r.contains(s) + True + >>> s = Segment((1, 2), (2, 5)) + >>> r.contains(s) + False + >>> r1 = Ray((2, 2), (3, 3)) + >>> r.contains(r1) + True + >>> r1 = Ray((2, 2), (3, 5)) + >>> r.contains(r1) + False + """ + if not isinstance(other, GeometryEntity): + other = Point(other, dim=self.ambient_dimension) + if isinstance(other, Point): + if Point.is_collinear(self.p1, self.p2, other): + # if we're in the direction of the ray, our + # direction vector dot the ray's direction vector + # should be non-negative + return bool((self.p2 - self.p1).dot(other - self.p1) >= S.Zero) + return False + elif isinstance(other, Ray): + if Point.is_collinear(self.p1, self.p2, other.p1, other.p2): + return bool((self.p2 - self.p1).dot(other.p2 - other.p1) > S.Zero) + return False + elif isinstance(other, Segment): + return other.p1 in self and other.p2 in self + + # No other known entity can be contained in a Ray + return False + + def distance(self, other): + """ + Finds the shortest distance between the ray and a point. + + Raises + ====== + + NotImplementedError is raised if `other` is not a Point + + Examples + ======== + + >>> from sympy import Point, Ray + >>> p1, p2 = Point(0, 0), Point(1, 1) + >>> s = Ray(p1, p2) + >>> s.distance(Point(-1, -1)) + sqrt(2) + >>> s.distance((-1, 2)) + 3*sqrt(2)/2 + >>> p1, p2 = Point(0, 0, 0), Point(1, 1, 2) + >>> s = Ray(p1, p2) + >>> s + Ray3D(Point3D(0, 0, 0), Point3D(1, 1, 2)) + >>> s.distance(Point(-1, -1, 2)) + 4*sqrt(3)/3 + >>> s.distance((-1, -1, 2)) + 4*sqrt(3)/3 + + """ + if not isinstance(other, GeometryEntity): + other = Point(other, dim=self.ambient_dimension) + if self.contains(other): + return S.Zero + + proj = Line(self.p1, self.p2).projection(other) + if self.contains(proj): + return abs(other - proj) + else: + return abs(other - self.source) + + def equals(self, other): + """Returns True if self and other are the same mathematical entities""" + if not isinstance(other, Ray): + return False + return self.source == other.source and other.p2 in self + + def plot_interval(self, parameter='t'): + """The plot interval for the default geometric plot of the Ray. Gives + values that will produce a ray that is 10 units long (where a unit is + the distance between the two points that define the ray). + + Parameters + ========== + + parameter : str, optional + Default value is 't'. + + Returns + ======= + + plot_interval : list + [parameter, lower_bound, upper_bound] + + Examples + ======== + + >>> from sympy import Ray, pi + >>> r = Ray((0, 0), angle=pi/4) + >>> r.plot_interval() + [t, 0, 10] + + """ + t = _symbol(parameter, real=True) + return [t, 0, 10] + + @property + def source(self): + """The point from which the ray emanates. + + See Also + ======== + + sympy.geometry.point.Point + + Examples + ======== + + >>> from sympy import Point, Ray + >>> p1, p2 = Point(0, 0), Point(4, 1) + >>> r1 = Ray(p1, p2) + >>> r1.source + Point2D(0, 0) + >>> p1, p2 = Point(0, 0, 0), Point(4, 1, 5) + >>> r1 = Ray(p2, p1) + >>> r1.source + Point3D(4, 1, 5) + + """ + return self.p1 + + +class Segment(LinearEntity): + """A line segment in space. + + Parameters + ========== + + p1 : Point + p2 : Point + + Attributes + ========== + + length : number or SymPy expression + midpoint : Point + + See Also + ======== + + sympy.geometry.line.Segment2D + sympy.geometry.line.Segment3D + sympy.geometry.point.Point + sympy.geometry.line.Line + + Notes + ===== + + If 2D or 3D points are used to define `Segment`, it will + be automatically subclassed to `Segment2D` or `Segment3D`. + + Examples + ======== + + >>> from sympy import Point, Segment + >>> Segment((1, 0), (1, 1)) # tuples are interpreted as pts + Segment2D(Point2D(1, 0), Point2D(1, 1)) + >>> s = Segment(Point(4, 3), Point(1, 1)) + >>> s.points + (Point2D(4, 3), Point2D(1, 1)) + >>> s.slope + 2/3 + >>> s.length + sqrt(13) + >>> s.midpoint + Point2D(5/2, 2) + >>> Segment((1, 0, 0), (1, 1, 1)) # tuples are interpreted as pts + Segment3D(Point3D(1, 0, 0), Point3D(1, 1, 1)) + >>> s = Segment(Point(4, 3, 9), Point(1, 1, 7)); s + Segment3D(Point3D(4, 3, 9), Point3D(1, 1, 7)) + >>> s.points + (Point3D(4, 3, 9), Point3D(1, 1, 7)) + >>> s.length + sqrt(17) + >>> s.midpoint + Point3D(5/2, 2, 8) + + """ + def __new__(cls, p1, p2, **kwargs): + p1, p2 = Point._normalize_dimension(Point(p1), Point(p2)) + dim = len(p1) + + if dim == 2: + return Segment2D(p1, p2, **kwargs) + elif dim == 3: + return Segment3D(p1, p2, **kwargs) + return LinearEntity.__new__(cls, p1, p2, **kwargs) + + def contains(self, other): + """ + Is the other GeometryEntity contained within this Segment? + + Examples + ======== + + >>> from sympy import Point, Segment + >>> p1, p2 = Point(0, 1), Point(3, 4) + >>> s = Segment(p1, p2) + >>> s2 = Segment(p2, p1) + >>> s.contains(s2) + True + >>> from sympy import Point3D, Segment3D + >>> p1, p2 = Point3D(0, 1, 1), Point3D(3, 4, 5) + >>> s = Segment3D(p1, p2) + >>> s2 = Segment3D(p2, p1) + >>> s.contains(s2) + True + >>> s.contains((p1 + p2)/2) + True + """ + if not isinstance(other, GeometryEntity): + other = Point(other, dim=self.ambient_dimension) + if isinstance(other, Point): + if Point.is_collinear(other, self.p1, self.p2): + if isinstance(self, Segment2D): + # if it is collinear and is in the bounding box of the + # segment then it must be on the segment + vert = (1/self.slope).equals(0) + if vert is False: + isin = (self.p1.x - other.x)*(self.p2.x - other.x) <= 0 + if isin in (True, False): + return isin + if vert is True: + isin = (self.p1.y - other.y)*(self.p2.y - other.y) <= 0 + if isin in (True, False): + return isin + # use the triangle inequality + d1, d2 = other - self.p1, other - self.p2 + d = self.p2 - self.p1 + # without the call to simplify, SymPy cannot tell that an expression + # like (a+b)*(a/2+b/2) is always non-negative. If it cannot be + # determined, raise an Undecidable error + try: + # the triangle inequality says that |d1|+|d2| >= |d| and is strict + # only if other lies in the line segment + return bool(simplify(Eq(abs(d1) + abs(d2) - abs(d), 0))) + except TypeError: + raise Undecidable("Cannot determine if {} is in {}".format(other, self)) + if isinstance(other, Segment): + return other.p1 in self and other.p2 in self + + return False + + def equals(self, other): + """Returns True if self and other are the same mathematical entities""" + return isinstance(other, self.func) and list( + ordered(self.args)) == list(ordered(other.args)) + + def distance(self, other): + """ + Finds the shortest distance between a line segment and a point. + + Raises + ====== + + NotImplementedError is raised if `other` is not a Point + + Examples + ======== + + >>> from sympy import Point, Segment + >>> p1, p2 = Point(0, 1), Point(3, 4) + >>> s = Segment(p1, p2) + >>> s.distance(Point(10, 15)) + sqrt(170) + >>> s.distance((0, 12)) + sqrt(73) + >>> from sympy import Point3D, Segment3D + >>> p1, p2 = Point3D(0, 0, 3), Point3D(1, 1, 4) + >>> s = Segment3D(p1, p2) + >>> s.distance(Point3D(10, 15, 12)) + sqrt(341) + >>> s.distance((10, 15, 12)) + sqrt(341) + """ + if not isinstance(other, GeometryEntity): + other = Point(other, dim=self.ambient_dimension) + if isinstance(other, Point): + vp1 = other - self.p1 + vp2 = other - self.p2 + + dot_prod_sign_1 = self.direction.dot(vp1) >= 0 + dot_prod_sign_2 = self.direction.dot(vp2) <= 0 + if dot_prod_sign_1 and dot_prod_sign_2: + return Line(self.p1, self.p2).distance(other) + if dot_prod_sign_1 and not dot_prod_sign_2: + return abs(vp2) + if not dot_prod_sign_1 and dot_prod_sign_2: + return abs(vp1) + raise NotImplementedError() + + @property + def length(self): + """The length of the line segment. + + See Also + ======== + + sympy.geometry.point.Point.distance + + Examples + ======== + + >>> from sympy import Point, Segment + >>> p1, p2 = Point(0, 0), Point(4, 3) + >>> s1 = Segment(p1, p2) + >>> s1.length + 5 + >>> from sympy import Point3D, Segment3D + >>> p1, p2 = Point3D(0, 0, 0), Point3D(4, 3, 3) + >>> s1 = Segment3D(p1, p2) + >>> s1.length + sqrt(34) + + """ + return Point.distance(self.p1, self.p2) + + @property + def midpoint(self): + """The midpoint of the line segment. + + See Also + ======== + + sympy.geometry.point.Point.midpoint + + Examples + ======== + + >>> from sympy import Point, Segment + >>> p1, p2 = Point(0, 0), Point(4, 3) + >>> s1 = Segment(p1, p2) + >>> s1.midpoint + Point2D(2, 3/2) + >>> from sympy import Point3D, Segment3D + >>> p1, p2 = Point3D(0, 0, 0), Point3D(4, 3, 3) + >>> s1 = Segment3D(p1, p2) + >>> s1.midpoint + Point3D(2, 3/2, 3/2) + + """ + return Point.midpoint(self.p1, self.p2) + + def perpendicular_bisector(self, p=None): + """The perpendicular bisector of this segment. + + If no point is specified or the point specified is not on the + bisector then the bisector is returned as a Line. Otherwise a + Segment is returned that joins the point specified and the + intersection of the bisector and the segment. + + Parameters + ========== + + p : Point + + Returns + ======= + + bisector : Line or Segment + + See Also + ======== + + LinearEntity.perpendicular_segment + + Examples + ======== + + >>> from sympy import Point, Segment + >>> p1, p2, p3 = Point(0, 0), Point(6, 6), Point(5, 1) + >>> s1 = Segment(p1, p2) + >>> s1.perpendicular_bisector() + Line2D(Point2D(3, 3), Point2D(-3, 9)) + + >>> s1.perpendicular_bisector(p3) + Segment2D(Point2D(5, 1), Point2D(3, 3)) + + """ + l = self.perpendicular_line(self.midpoint) + if p is not None: + p2 = Point(p, dim=self.ambient_dimension) + if p2 in l: + return Segment(p2, self.midpoint) + return l + + def plot_interval(self, parameter='t'): + """The plot interval for the default geometric plot of the Segment gives + values that will produce the full segment in a plot. + + Parameters + ========== + + parameter : str, optional + Default value is 't'. + + Returns + ======= + + plot_interval : list + [parameter, lower_bound, upper_bound] + + Examples + ======== + + >>> from sympy import Point, Segment + >>> p1, p2 = Point(0, 0), Point(5, 3) + >>> s1 = Segment(p1, p2) + >>> s1.plot_interval() + [t, 0, 1] + + """ + t = _symbol(parameter, real=True) + return [t, 0, 1] + + +class LinearEntity2D(LinearEntity): + """A base class for all linear entities (line, ray and segment) + in a 2-dimensional Euclidean space. + + Attributes + ========== + + p1 + p2 + coefficients + slope + points + + Notes + ===== + + This is an abstract class and is not meant to be instantiated. + + See Also + ======== + + sympy.geometry.entity.GeometryEntity + + """ + @property + def bounds(self): + """Return a tuple (xmin, ymin, xmax, ymax) representing the bounding + rectangle for the geometric figure. + + """ + verts = self.points + xs = [p.x for p in verts] + ys = [p.y for p in verts] + return (min(xs), min(ys), max(xs), max(ys)) + + def perpendicular_line(self, p): + """Create a new Line perpendicular to this linear entity which passes + through the point `p`. + + Parameters + ========== + + p : Point + + Returns + ======= + + line : Line + + See Also + ======== + + sympy.geometry.line.LinearEntity.is_perpendicular, perpendicular_segment + + Examples + ======== + + >>> from sympy import Point, Line + >>> p1, p2, p3 = Point(0, 0), Point(2, 3), Point(-2, 2) + >>> L = Line(p1, p2) + >>> P = L.perpendicular_line(p3); P + Line2D(Point2D(-2, 2), Point2D(-5, 4)) + >>> L.is_perpendicular(P) + True + + In 2D, the first point of the perpendicular line is the + point through which was required to pass; the second + point is arbitrarily chosen. To get a line that explicitly + uses a point in the line, create a line from the perpendicular + segment from the line to the point: + + >>> Line(L.perpendicular_segment(p3)) + Line2D(Point2D(-2, 2), Point2D(4/13, 6/13)) + """ + p = Point(p, dim=self.ambient_dimension) + # any two lines in R^2 intersect, so blindly making + # a line through p in an orthogonal direction will work + # and is faster than finding the projection point as in 3D + return Line(p, p + self.direction.orthogonal_direction) + + @property + def slope(self): + """The slope of this linear entity, or infinity if vertical. + + Returns + ======= + + slope : number or SymPy expression + + See Also + ======== + + coefficients + + Examples + ======== + + >>> from sympy import Point, Line + >>> p1, p2 = Point(0, 0), Point(3, 5) + >>> l1 = Line(p1, p2) + >>> l1.slope + 5/3 + + >>> p3 = Point(0, 4) + >>> l2 = Line(p1, p3) + >>> l2.slope + oo + + """ + d1, d2 = (self.p1 - self.p2).args + if d1 == 0: + return S.Infinity + return simplify(d2/d1) + + +class Line2D(LinearEntity2D, Line): + """An infinite line in space 2D. + + A line is declared with two distinct points or a point and slope + as defined using keyword `slope`. + + Parameters + ========== + + p1 : Point + pt : Point + slope : SymPy expression + + See Also + ======== + + sympy.geometry.point.Point + + Examples + ======== + + >>> from sympy import Line, Segment, Point + >>> L = Line(Point(2,3), Point(3,5)) + >>> L + Line2D(Point2D(2, 3), Point2D(3, 5)) + >>> L.points + (Point2D(2, 3), Point2D(3, 5)) + >>> L.equation() + -2*x + y + 1 + >>> L.coefficients + (-2, 1, 1) + + Instantiate with keyword ``slope``: + + >>> Line(Point(0, 0), slope=0) + Line2D(Point2D(0, 0), Point2D(1, 0)) + + Instantiate with another linear object + + >>> s = Segment((0, 0), (0, 1)) + >>> Line(s).equation() + x + """ + def __new__(cls, p1, pt=None, slope=None, **kwargs): + if isinstance(p1, LinearEntity): + if pt is not None: + raise ValueError('When p1 is a LinearEntity, pt should be None') + p1, pt = Point._normalize_dimension(*p1.args, dim=2) + else: + p1 = Point(p1, dim=2) + if pt is not None and slope is None: + try: + p2 = Point(pt, dim=2) + except (NotImplementedError, TypeError, ValueError): + raise ValueError(filldedent(''' + The 2nd argument was not a valid Point. + If it was a slope, enter it with keyword "slope". + ''')) + elif slope is not None and pt is None: + slope = sympify(slope) + if slope.is_finite is False: + # when infinite slope, don't change x + dx = 0 + dy = 1 + else: + # go over 1 up slope + dx = 1 + dy = slope + # XXX avoiding simplification by adding to coords directly + p2 = Point(p1.x + dx, p1.y + dy, evaluate=False) + else: + raise ValueError('A 2nd Point or keyword "slope" must be used.') + return LinearEntity2D.__new__(cls, p1, p2, **kwargs) + + def _svg(self, scale_factor=1., fill_color="#66cc99"): + """Returns SVG path element for the LinearEntity. + + Parameters + ========== + + scale_factor : float + Multiplication factor for the SVG stroke-width. Default is 1. + fill_color : str, optional + Hex string for fill color. Default is "#66cc99". + """ + verts = (N(self.p1), N(self.p2)) + coords = ["{},{}".format(p.x, p.y) for p in verts] + path = "M {} L {}".format(coords[0], " L ".join(coords[1:])) + + return ( + '' + ).format(2.*scale_factor, path, fill_color) + + @property + def coefficients(self): + """The coefficients (`a`, `b`, `c`) for `ax + by + c = 0`. + + See Also + ======== + + sympy.geometry.line.Line2D.equation + + Examples + ======== + + >>> from sympy import Point, Line + >>> from sympy.abc import x, y + >>> p1, p2 = Point(0, 0), Point(5, 3) + >>> l = Line(p1, p2) + >>> l.coefficients + (-3, 5, 0) + + >>> p3 = Point(x, y) + >>> l2 = Line(p1, p3) + >>> l2.coefficients + (-y, x, 0) + + """ + p1, p2 = self.points + if p1.x == p2.x: + return (S.One, S.Zero, -p1.x) + elif p1.y == p2.y: + return (S.Zero, S.One, -p1.y) + return tuple([simplify(i) for i in + (self.p1.y - self.p2.y, + self.p2.x - self.p1.x, + self.p1.x*self.p2.y - self.p1.y*self.p2.x)]) + + def equation(self, x='x', y='y'): + """The equation of the line: ax + by + c. + + Parameters + ========== + + x : str, optional + The name to use for the x-axis, default value is 'x'. + y : str, optional + The name to use for the y-axis, default value is 'y'. + + Returns + ======= + + equation : SymPy expression + + See Also + ======== + + sympy.geometry.line.Line2D.coefficients + + Examples + ======== + + >>> from sympy import Point, Line + >>> p1, p2 = Point(1, 0), Point(5, 3) + >>> l1 = Line(p1, p2) + >>> l1.equation() + -3*x + 4*y + 3 + + """ + x = _symbol(x, real=True) + y = _symbol(y, real=True) + p1, p2 = self.points + if p1.x == p2.x: + return x - p1.x + elif p1.y == p2.y: + return y - p1.y + + a, b, c = self.coefficients + return a*x + b*y + c + + +class Ray2D(LinearEntity2D, Ray): + """ + A Ray is a semi-line in the space with a source point and a direction. + + Parameters + ========== + + p1 : Point + The source of the Ray + p2 : Point or radian value + This point determines the direction in which the Ray propagates. + If given as an angle it is interpreted in radians with the positive + direction being ccw. + + Attributes + ========== + + source + xdirection + ydirection + + See Also + ======== + + sympy.geometry.point.Point, Line + + Examples + ======== + + >>> from sympy import Point, pi, Ray + >>> r = Ray(Point(2, 3), Point(3, 5)) + >>> r + Ray2D(Point2D(2, 3), Point2D(3, 5)) + >>> r.points + (Point2D(2, 3), Point2D(3, 5)) + >>> r.source + Point2D(2, 3) + >>> r.xdirection + oo + >>> r.ydirection + oo + >>> r.slope + 2 + >>> Ray(Point(0, 0), angle=pi/4).slope + 1 + + """ + def __new__(cls, p1, pt=None, angle=None, **kwargs): + p1 = Point(p1, dim=2) + if pt is not None and angle is None: + try: + p2 = Point(pt, dim=2) + except (NotImplementedError, TypeError, ValueError): + raise ValueError(filldedent(''' + The 2nd argument was not a valid Point; if + it was meant to be an angle it should be + given with keyword "angle".''')) + if p1 == p2: + raise ValueError('A Ray requires two distinct points.') + elif angle is not None and pt is None: + # we need to know if the angle is an odd multiple of pi/2 + angle = sympify(angle) + c = _pi_coeff(angle) + p2 = None + if c is not None: + if c.is_Rational: + if c.q == 2: + if c.p == 1: + p2 = p1 + Point(0, 1) + elif c.p == 3: + p2 = p1 + Point(0, -1) + elif c.q == 1: + if c.p == 0: + p2 = p1 + Point(1, 0) + elif c.p == 1: + p2 = p1 + Point(-1, 0) + if p2 is None: + c *= S.Pi + else: + c = angle % (2*S.Pi) + if not p2: + m = 2*c/S.Pi + left = And(1 < m, m < 3) # is it in quadrant 2 or 3? + x = Piecewise((-1, left), (Piecewise((0, Eq(m % 1, 0)), (1, True)), True)) + y = Piecewise((-tan(c), left), (Piecewise((1, Eq(m, 1)), (-1, Eq(m, 3)), (tan(c), True)), True)) + p2 = p1 + Point(x, y) + else: + raise ValueError('A 2nd point or keyword "angle" must be used.') + + return LinearEntity2D.__new__(cls, p1, p2, **kwargs) + + @property + def xdirection(self): + """The x direction of the ray. + + Positive infinity if the ray points in the positive x direction, + negative infinity if the ray points in the negative x direction, + or 0 if the ray is vertical. + + See Also + ======== + + ydirection + + Examples + ======== + + >>> from sympy import Point, Ray + >>> p1, p2, p3 = Point(0, 0), Point(1, 1), Point(0, -1) + >>> r1, r2 = Ray(p1, p2), Ray(p1, p3) + >>> r1.xdirection + oo + >>> r2.xdirection + 0 + + """ + if self.p1.x < self.p2.x: + return S.Infinity + elif self.p1.x == self.p2.x: + return S.Zero + else: + return S.NegativeInfinity + + @property + def ydirection(self): + """The y direction of the ray. + + Positive infinity if the ray points in the positive y direction, + negative infinity if the ray points in the negative y direction, + or 0 if the ray is horizontal. + + See Also + ======== + + xdirection + + Examples + ======== + + >>> from sympy import Point, Ray + >>> p1, p2, p3 = Point(0, 0), Point(-1, -1), Point(-1, 0) + >>> r1, r2 = Ray(p1, p2), Ray(p1, p3) + >>> r1.ydirection + -oo + >>> r2.ydirection + 0 + + """ + if self.p1.y < self.p2.y: + return S.Infinity + elif self.p1.y == self.p2.y: + return S.Zero + else: + return S.NegativeInfinity + + def closing_angle(r1, r2): + """Return the angle by which r2 must be rotated so it faces the same + direction as r1. + + Parameters + ========== + + r1 : Ray2D + r2 : Ray2D + + Returns + ======= + + angle : angle in radians (ccw angle is positive) + + See Also + ======== + + LinearEntity.angle_between + + Examples + ======== + + >>> from sympy import Ray, pi + >>> r1 = Ray((0, 0), (1, 0)) + >>> r2 = r1.rotate(-pi/2) + >>> angle = r1.closing_angle(r2); angle + pi/2 + >>> r2.rotate(angle).direction.unit == r1.direction.unit + True + >>> r2.closing_angle(r1) + -pi/2 + """ + if not all(isinstance(r, Ray2D) for r in (r1, r2)): + # although the direction property is defined for + # all linear entities, only the Ray is truly a + # directed object + raise TypeError('Both arguments must be Ray2D objects.') + + a1 = atan2(*list(reversed(r1.direction.args))) + a2 = atan2(*list(reversed(r2.direction.args))) + if a1*a2 < 0: + a1 = 2*S.Pi + a1 if a1 < 0 else a1 + a2 = 2*S.Pi + a2 if a2 < 0 else a2 + return a1 - a2 + + +class Segment2D(LinearEntity2D, Segment): + """A line segment in 2D space. + + Parameters + ========== + + p1 : Point + p2 : Point + + Attributes + ========== + + length : number or SymPy expression + midpoint : Point + + See Also + ======== + + sympy.geometry.point.Point, Line + + Examples + ======== + + >>> from sympy import Point, Segment + >>> Segment((1, 0), (1, 1)) # tuples are interpreted as pts + Segment2D(Point2D(1, 0), Point2D(1, 1)) + >>> s = Segment(Point(4, 3), Point(1, 1)); s + Segment2D(Point2D(4, 3), Point2D(1, 1)) + >>> s.points + (Point2D(4, 3), Point2D(1, 1)) + >>> s.slope + 2/3 + >>> s.length + sqrt(13) + >>> s.midpoint + Point2D(5/2, 2) + + """ + def __new__(cls, p1, p2, **kwargs): + p1 = Point(p1, dim=2) + p2 = Point(p2, dim=2) + + if p1 == p2: + return p1 + + return LinearEntity2D.__new__(cls, p1, p2, **kwargs) + + def _svg(self, scale_factor=1., fill_color="#66cc99"): + """Returns SVG path element for the LinearEntity. + + Parameters + ========== + + scale_factor : float + Multiplication factor for the SVG stroke-width. Default is 1. + fill_color : str, optional + Hex string for fill color. Default is "#66cc99". + """ + verts = (N(self.p1), N(self.p2)) + coords = ["{},{}".format(p.x, p.y) for p in verts] + path = "M {} L {}".format(coords[0], " L ".join(coords[1:])) + return ( + '' + ).format(2.*scale_factor, path, fill_color) + + +class LinearEntity3D(LinearEntity): + """An base class for all linear entities (line, ray and segment) + in a 3-dimensional Euclidean space. + + Attributes + ========== + + p1 + p2 + direction_ratio + direction_cosine + points + + Notes + ===== + + This is a base class and is not meant to be instantiated. + """ + def __new__(cls, p1, p2, **kwargs): + p1 = Point3D(p1, dim=3) + p2 = Point3D(p2, dim=3) + if p1 == p2: + # if it makes sense to return a Point, handle in subclass + raise ValueError( + "%s.__new__ requires two unique Points." % cls.__name__) + + return GeometryEntity.__new__(cls, p1, p2, **kwargs) + + ambient_dimension = 3 + + @property + def direction_ratio(self): + """The direction ratio of a given line in 3D. + + See Also + ======== + + sympy.geometry.line.Line3D.equation + + Examples + ======== + + >>> from sympy import Point3D, Line3D + >>> p1, p2 = Point3D(0, 0, 0), Point3D(5, 3, 1) + >>> l = Line3D(p1, p2) + >>> l.direction_ratio + [5, 3, 1] + """ + p1, p2 = self.points + return p1.direction_ratio(p2) + + @property + def direction_cosine(self): + """The normalized direction ratio of a given line in 3D. + + See Also + ======== + + sympy.geometry.line.Line3D.equation + + Examples + ======== + + >>> from sympy import Point3D, Line3D + >>> p1, p2 = Point3D(0, 0, 0), Point3D(5, 3, 1) + >>> l = Line3D(p1, p2) + >>> l.direction_cosine + [sqrt(35)/7, 3*sqrt(35)/35, sqrt(35)/35] + >>> sum(i**2 for i in _) + 1 + """ + p1, p2 = self.points + return p1.direction_cosine(p2) + + +class Line3D(LinearEntity3D, Line): + """An infinite 3D line in space. + + A line is declared with two distinct points or a point and direction_ratio + as defined using keyword `direction_ratio`. + + Parameters + ========== + + p1 : Point3D + pt : Point3D + direction_ratio : list + + See Also + ======== + + sympy.geometry.point.Point3D + sympy.geometry.line.Line + sympy.geometry.line.Line2D + + Examples + ======== + + >>> from sympy import Line3D, Point3D + >>> L = Line3D(Point3D(2, 3, 4), Point3D(3, 5, 1)) + >>> L + Line3D(Point3D(2, 3, 4), Point3D(3, 5, 1)) + >>> L.points + (Point3D(2, 3, 4), Point3D(3, 5, 1)) + """ + def __new__(cls, p1, pt=None, direction_ratio=(), **kwargs): + if isinstance(p1, LinearEntity3D): + if pt is not None: + raise ValueError('if p1 is a LinearEntity, pt must be None.') + p1, pt = p1.args + else: + p1 = Point(p1, dim=3) + if pt is not None and len(direction_ratio) == 0: + pt = Point(pt, dim=3) + elif len(direction_ratio) == 3 and pt is None: + pt = Point3D(p1.x + direction_ratio[0], p1.y + direction_ratio[1], + p1.z + direction_ratio[2]) + else: + raise ValueError('A 2nd Point or keyword "direction_ratio" must ' + 'be used.') + + return LinearEntity3D.__new__(cls, p1, pt, **kwargs) + + def equation(self, x='x', y='y', z='z'): + """Return the equations that define the line in 3D. + + Parameters + ========== + + x : str, optional + The name to use for the x-axis, default value is 'x'. + y : str, optional + The name to use for the y-axis, default value is 'y'. + z : str, optional + The name to use for the z-axis, default value is 'z'. + + Returns + ======= + + equation : Tuple of simultaneous equations + + Examples + ======== + + >>> from sympy import Point3D, Line3D, solve + >>> from sympy.abc import x, y, z + >>> p1, p2 = Point3D(1, 0, 0), Point3D(5, 3, 0) + >>> l1 = Line3D(p1, p2) + >>> eq = l1.equation(x, y, z); eq + (-3*x + 4*y + 3, z) + >>> solve(eq.subs(z, 0), (x, y, z)) + {x: 4*y/3 + 1} + """ + x, y, z, k = [_symbol(i, real=True) for i in (x, y, z, 'k')] + p1, p2 = self.points + d1, d2, d3 = p1.direction_ratio(p2) + x1, y1, z1 = p1 + eqs = [-d1*k + x - x1, -d2*k + y - y1, -d3*k + z - z1] + # eliminate k from equations by solving first eq with k for k + for i, e in enumerate(eqs): + if e.has(k): + kk = solve(e, k)[0] + eqs.pop(i) + break + return Tuple(*[i.subs(k, kk).as_numer_denom()[0] for i in eqs]) + + def distance(self, other): + """ + Finds the shortest distance between a line and another object. + + Parameters + ========== + + Point3D, Line3D, Plane, tuple, list + + Returns + ======= + + distance + + Notes + ===== + + This method accepts only 3D entities as it's parameter + + Tuples and lists are converted to Point3D and therefore must be of + length 3, 2 or 1. + + NotImplementedError is raised if `other` is not an instance of one + of the specified classes: Point3D, Line3D, or Plane. + + Examples + ======== + + >>> from sympy.geometry import Line3D + >>> l1 = Line3D((0, 0, 0), (0, 0, 1)) + >>> l2 = Line3D((0, 1, 0), (1, 1, 1)) + >>> l1.distance(l2) + 1 + + The computed distance may be symbolic, too: + + >>> from sympy.abc import x, y + >>> l1 = Line3D((0, 0, 0), (0, 0, 1)) + >>> l2 = Line3D((0, x, 0), (y, x, 1)) + >>> l1.distance(l2) + Abs(x*y)/Abs(sqrt(y**2)) + + """ + + from .plane import Plane # Avoid circular import + + if isinstance(other, (tuple, list)): + try: + other = Point3D(other) + except ValueError: + pass + + if isinstance(other, Point3D): + return super().distance(other) + + if isinstance(other, Line3D): + if self == other: + return S.Zero + if self.is_parallel(other): + return super().distance(other.p1) + + # Skew lines + self_direction = Matrix(self.direction_ratio) + other_direction = Matrix(other.direction_ratio) + normal = self_direction.cross(other_direction) + plane_through_self = Plane(p1=self.p1, normal_vector=normal) + return other.p1.distance(plane_through_self) + + if isinstance(other, Plane): + return other.distance(self) + + msg = f"{other} has type {type(other)}, which is unsupported" + raise NotImplementedError(msg) + + +class Ray3D(LinearEntity3D, Ray): + """ + A Ray is a semi-line in the space with a source point and a direction. + + Parameters + ========== + + p1 : Point3D + The source of the Ray + p2 : Point or a direction vector + direction_ratio: Determines the direction in which the Ray propagates. + + + Attributes + ========== + + source + xdirection + ydirection + zdirection + + See Also + ======== + + sympy.geometry.point.Point3D, Line3D + + + Examples + ======== + + >>> from sympy import Point3D, Ray3D + >>> r = Ray3D(Point3D(2, 3, 4), Point3D(3, 5, 0)) + >>> r + Ray3D(Point3D(2, 3, 4), Point3D(3, 5, 0)) + >>> r.points + (Point3D(2, 3, 4), Point3D(3, 5, 0)) + >>> r.source + Point3D(2, 3, 4) + >>> r.xdirection + oo + >>> r.ydirection + oo + >>> r.direction_ratio + [1, 2, -4] + + """ + def __new__(cls, p1, pt=None, direction_ratio=(), **kwargs): + if isinstance(p1, LinearEntity3D): + if pt is not None: + raise ValueError('If p1 is a LinearEntity, pt must be None') + p1, pt = p1.args + else: + p1 = Point(p1, dim=3) + if pt is not None and len(direction_ratio) == 0: + pt = Point(pt, dim=3) + elif len(direction_ratio) == 3 and pt is None: + pt = Point3D(p1.x + direction_ratio[0], p1.y + direction_ratio[1], + p1.z + direction_ratio[2]) + else: + raise ValueError(filldedent(''' + A 2nd Point or keyword "direction_ratio" must be used. + ''')) + + return LinearEntity3D.__new__(cls, p1, pt, **kwargs) + + @property + def xdirection(self): + """The x direction of the ray. + + Positive infinity if the ray points in the positive x direction, + negative infinity if the ray points in the negative x direction, + or 0 if the ray is vertical. + + See Also + ======== + + ydirection + + Examples + ======== + + >>> from sympy import Point3D, Ray3D + >>> p1, p2, p3 = Point3D(0, 0, 0), Point3D(1, 1, 1), Point3D(0, -1, 0) + >>> r1, r2 = Ray3D(p1, p2), Ray3D(p1, p3) + >>> r1.xdirection + oo + >>> r2.xdirection + 0 + + """ + if self.p1.x < self.p2.x: + return S.Infinity + elif self.p1.x == self.p2.x: + return S.Zero + else: + return S.NegativeInfinity + + @property + def ydirection(self): + """The y direction of the ray. + + Positive infinity if the ray points in the positive y direction, + negative infinity if the ray points in the negative y direction, + or 0 if the ray is horizontal. + + See Also + ======== + + xdirection + + Examples + ======== + + >>> from sympy import Point3D, Ray3D + >>> p1, p2, p3 = Point3D(0, 0, 0), Point3D(-1, -1, -1), Point3D(-1, 0, 0) + >>> r1, r2 = Ray3D(p1, p2), Ray3D(p1, p3) + >>> r1.ydirection + -oo + >>> r2.ydirection + 0 + + """ + if self.p1.y < self.p2.y: + return S.Infinity + elif self.p1.y == self.p2.y: + return S.Zero + else: + return S.NegativeInfinity + + @property + def zdirection(self): + """The z direction of the ray. + + Positive infinity if the ray points in the positive z direction, + negative infinity if the ray points in the negative z direction, + or 0 if the ray is horizontal. + + See Also + ======== + + xdirection + + Examples + ======== + + >>> from sympy import Point3D, Ray3D + >>> p1, p2, p3 = Point3D(0, 0, 0), Point3D(-1, -1, -1), Point3D(-1, 0, 0) + >>> r1, r2 = Ray3D(p1, p2), Ray3D(p1, p3) + >>> r1.ydirection + -oo + >>> r2.ydirection + 0 + >>> r2.zdirection + 0 + + """ + if self.p1.z < self.p2.z: + return S.Infinity + elif self.p1.z == self.p2.z: + return S.Zero + else: + return S.NegativeInfinity + + +class Segment3D(LinearEntity3D, Segment): + """A line segment in a 3D space. + + Parameters + ========== + + p1 : Point3D + p2 : Point3D + + Attributes + ========== + + length : number or SymPy expression + midpoint : Point3D + + See Also + ======== + + sympy.geometry.point.Point3D, Line3D + + Examples + ======== + + >>> from sympy import Point3D, Segment3D + >>> Segment3D((1, 0, 0), (1, 1, 1)) # tuples are interpreted as pts + Segment3D(Point3D(1, 0, 0), Point3D(1, 1, 1)) + >>> s = Segment3D(Point3D(4, 3, 9), Point3D(1, 1, 7)); s + Segment3D(Point3D(4, 3, 9), Point3D(1, 1, 7)) + >>> s.points + (Point3D(4, 3, 9), Point3D(1, 1, 7)) + >>> s.length + sqrt(17) + >>> s.midpoint + Point3D(5/2, 2, 8) + + """ + def __new__(cls, p1, p2, **kwargs): + p1 = Point(p1, dim=3) + p2 = Point(p2, dim=3) + + if p1 == p2: + return p1 + + return LinearEntity3D.__new__(cls, p1, p2, **kwargs) diff --git a/.venv/lib/python3.13/site-packages/sympy/geometry/parabola.py b/.venv/lib/python3.13/site-packages/sympy/geometry/parabola.py new file mode 100644 index 0000000000000000000000000000000000000000..183c593785bb610e6f451a0c87abb2aa34d22494 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/geometry/parabola.py @@ -0,0 +1,422 @@ +"""Parabolic geometrical entity. + +Contains +* Parabola + +""" + +from sympy.core import S +from sympy.core.sorting import ordered +from sympy.core.symbol import _symbol, symbols +from sympy.geometry.entity import GeometryEntity, GeometrySet +from sympy.geometry.point import Point, Point2D +from sympy.geometry.line import Line, Line2D, Ray2D, Segment2D, LinearEntity3D +from sympy.geometry.ellipse import Ellipse +from sympy.functions import sign +from sympy.simplify.simplify import simplify +from sympy.solvers.solvers import solve + + +class Parabola(GeometrySet): + """A parabolic GeometryEntity. + + A parabola is declared with a point, that is called 'focus', and + a line, that is called 'directrix'. + Only vertical or horizontal parabolas are currently supported. + + Parameters + ========== + + focus : Point + Default value is Point(0, 0) + directrix : Line + + Attributes + ========== + + focus + directrix + axis of symmetry + focal length + p parameter + vertex + eccentricity + + Raises + ====== + ValueError + When `focus` is not a two dimensional point. + When `focus` is a point of directrix. + NotImplementedError + When `directrix` is neither horizontal nor vertical. + + Examples + ======== + + >>> from sympy import Parabola, Point, Line + >>> p1 = Parabola(Point(0, 0), Line(Point(5, 8), Point(7,8))) + >>> p1.focus + Point2D(0, 0) + >>> p1.directrix + Line2D(Point2D(5, 8), Point2D(7, 8)) + + """ + + def __new__(cls, focus=None, directrix=None, **kwargs): + + if focus: + focus = Point(focus, dim=2) + else: + focus = Point(0, 0) + + directrix = Line(directrix) + + if directrix.contains(focus): + raise ValueError('The focus must not be a point of directrix') + + return GeometryEntity.__new__(cls, focus, directrix, **kwargs) + + @property + def ambient_dimension(self): + """Returns the ambient dimension of parabola. + + Returns + ======= + + ambient_dimension : integer + + Examples + ======== + + >>> from sympy import Parabola, Point, Line + >>> f1 = Point(0, 0) + >>> p1 = Parabola(f1, Line(Point(5, 8), Point(7, 8))) + >>> p1.ambient_dimension + 2 + + """ + return 2 + + @property + def axis_of_symmetry(self): + """Return the axis of symmetry of the parabola: a line + perpendicular to the directrix passing through the focus. + + Returns + ======= + + axis_of_symmetry : Line + + See Also + ======== + + sympy.geometry.line.Line + + Examples + ======== + + >>> from sympy import Parabola, Point, Line + >>> p1 = Parabola(Point(0, 0), Line(Point(5, 8), Point(7, 8))) + >>> p1.axis_of_symmetry + Line2D(Point2D(0, 0), Point2D(0, 1)) + + """ + return self.directrix.perpendicular_line(self.focus) + + @property + def directrix(self): + """The directrix of the parabola. + + Returns + ======= + + directrix : Line + + See Also + ======== + + sympy.geometry.line.Line + + Examples + ======== + + >>> from sympy import Parabola, Point, Line + >>> l1 = Line(Point(5, 8), Point(7, 8)) + >>> p1 = Parabola(Point(0, 0), l1) + >>> p1.directrix + Line2D(Point2D(5, 8), Point2D(7, 8)) + + """ + return self.args[1] + + @property + def eccentricity(self): + """The eccentricity of the parabola. + + Returns + ======= + + eccentricity : number + + A parabola may also be characterized as a conic section with an + eccentricity of 1. As a consequence of this, all parabolas are + similar, meaning that while they can be different sizes, + they are all the same shape. + + See Also + ======== + + https://en.wikipedia.org/wiki/Parabola + + + Examples + ======== + + >>> from sympy import Parabola, Point, Line + >>> p1 = Parabola(Point(0, 0), Line(Point(5, 8), Point(7, 8))) + >>> p1.eccentricity + 1 + + Notes + ----- + The eccentricity for every Parabola is 1 by definition. + + """ + return S.One + + def equation(self, x='x', y='y'): + """The equation of the parabola. + + Parameters + ========== + x : str, optional + Label for the x-axis. Default value is 'x'. + y : str, optional + Label for the y-axis. Default value is 'y'. + + Returns + ======= + equation : SymPy expression + + Examples + ======== + + >>> from sympy import Parabola, Point, Line + >>> p1 = Parabola(Point(0, 0), Line(Point(5, 8), Point(7, 8))) + >>> p1.equation() + -x**2 - 16*y + 64 + >>> p1.equation('f') + -f**2 - 16*y + 64 + >>> p1.equation(y='z') + -x**2 - 16*z + 64 + + """ + x = _symbol(x, real=True) + y = _symbol(y, real=True) + + m = self.directrix.slope + if m is S.Infinity: + t1 = 4 * (self.p_parameter) * (x - self.vertex.x) + t2 = (y - self.vertex.y)**2 + elif m == 0: + t1 = 4 * (self.p_parameter) * (y - self.vertex.y) + t2 = (x - self.vertex.x)**2 + else: + a, b = self.focus + c, d = self.directrix.coefficients[:2] + t1 = (x - a)**2 + (y - b)**2 + t2 = self.directrix.equation(x, y)**2/(c**2 + d**2) + return t1 - t2 + + @property + def focal_length(self): + """The focal length of the parabola. + + Returns + ======= + + focal_lenght : number or symbolic expression + + Notes + ===== + + The distance between the vertex and the focus + (or the vertex and directrix), measured along the axis + of symmetry, is the "focal length". + + See Also + ======== + + https://en.wikipedia.org/wiki/Parabola + + Examples + ======== + + >>> from sympy import Parabola, Point, Line + >>> p1 = Parabola(Point(0, 0), Line(Point(5, 8), Point(7, 8))) + >>> p1.focal_length + 4 + + """ + distance = self.directrix.distance(self.focus) + focal_length = distance/2 + + return focal_length + + @property + def focus(self): + """The focus of the parabola. + + Returns + ======= + + focus : Point + + See Also + ======== + + sympy.geometry.point.Point + + Examples + ======== + + >>> from sympy import Parabola, Point, Line + >>> f1 = Point(0, 0) + >>> p1 = Parabola(f1, Line(Point(5, 8), Point(7, 8))) + >>> p1.focus + Point2D(0, 0) + + """ + return self.args[0] + + def intersection(self, o): + """The intersection of the parabola and another geometrical entity `o`. + + Parameters + ========== + + o : GeometryEntity, LinearEntity + + Returns + ======= + + intersection : list of GeometryEntity objects + + Examples + ======== + + >>> from sympy import Parabola, Point, Ellipse, Line, Segment + >>> p1 = Point(0,0) + >>> l1 = Line(Point(1, -2), Point(-1,-2)) + >>> parabola1 = Parabola(p1, l1) + >>> parabola1.intersection(Ellipse(Point(0, 0), 2, 5)) + [Point2D(-2, 0), Point2D(2, 0)] + >>> parabola1.intersection(Line(Point(-7, 3), Point(12, 3))) + [Point2D(-4, 3), Point2D(4, 3)] + >>> parabola1.intersection(Segment((-12, -65), (14, -68))) + [] + + """ + x, y = symbols('x y', real=True) + parabola_eq = self.equation() + if isinstance(o, Parabola): + if o in self: + return [o] + else: + return list(ordered([Point(i) for i in solve( + [parabola_eq, o.equation()], [x, y], set=True)[1]])) + elif isinstance(o, Point2D): + if simplify(parabola_eq.subs([(x, o._args[0]), (y, o._args[1])])) == 0: + return [o] + else: + return [] + elif isinstance(o, (Segment2D, Ray2D)): + result = solve([parabola_eq, + Line2D(o.points[0], o.points[1]).equation()], + [x, y], set=True)[1] + return list(ordered([Point2D(i) for i in result if i in o])) + elif isinstance(o, (Line2D, Ellipse)): + return list(ordered([Point2D(i) for i in solve( + [parabola_eq, o.equation()], [x, y], set=True)[1]])) + elif isinstance(o, LinearEntity3D): + raise TypeError('Entity must be two dimensional, not three dimensional') + else: + raise TypeError('Wrong type of argument were put') + + @property + def p_parameter(self): + """P is a parameter of parabola. + + Returns + ======= + + p : number or symbolic expression + + Notes + ===== + + The absolute value of p is the focal length. The sign on p tells + which way the parabola faces. Vertical parabolas that open up + and horizontal that open right, give a positive value for p. + Vertical parabolas that open down and horizontal that open left, + give a negative value for p. + + + See Also + ======== + + https://www.sparknotes.com/math/precalc/conicsections/section2/ + + Examples + ======== + + >>> from sympy import Parabola, Point, Line + >>> p1 = Parabola(Point(0, 0), Line(Point(5, 8), Point(7, 8))) + >>> p1.p_parameter + -4 + + """ + m = self.directrix.slope + if m is S.Infinity: + x = self.directrix.coefficients[2] + p = sign(self.focus.args[0] + x) + elif m == 0: + y = self.directrix.coefficients[2] + p = sign(self.focus.args[1] + y) + else: + d = self.directrix.projection(self.focus) + p = sign(self.focus.x - d.x) + return p * self.focal_length + + @property + def vertex(self): + """The vertex of the parabola. + + Returns + ======= + + vertex : Point + + See Also + ======== + + sympy.geometry.point.Point + + Examples + ======== + + >>> from sympy import Parabola, Point, Line + >>> p1 = Parabola(Point(0, 0), Line(Point(5, 8), Point(7, 8))) + >>> p1.vertex + Point2D(0, 4) + + """ + focus = self.focus + m = self.directrix.slope + if m is S.Infinity: + vertex = Point(focus.args[0] - self.p_parameter, focus.args[1]) + elif m == 0: + vertex = Point(focus.args[0], focus.args[1] - self.p_parameter) + else: + vertex = self.axis_of_symmetry.intersection(self)[0] + return vertex diff --git a/.venv/lib/python3.13/site-packages/sympy/geometry/plane.py b/.venv/lib/python3.13/site-packages/sympy/geometry/plane.py new file mode 100644 index 0000000000000000000000000000000000000000..509dc4be5dc41c5df7c33561fdbe5bb0b6620352 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/geometry/plane.py @@ -0,0 +1,878 @@ +"""Geometrical Planes. + +Contains +======== +Plane + +""" + +from sympy.core import Dummy, Rational, S, Symbol +from sympy.core.symbol import _symbol +from sympy.functions.elementary.trigonometric import cos, sin, acos, asin, sqrt +from .entity import GeometryEntity +from .line import (Line, Ray, Segment, Line3D, LinearEntity, LinearEntity3D, + Ray3D, Segment3D) +from .point import Point, Point3D +from sympy.matrices import Matrix +from sympy.polys.polytools import cancel +from sympy.solvers import solve, linsolve +from sympy.utilities.iterables import uniq, is_sequence +from sympy.utilities.misc import filldedent, func_name, Undecidable + +from mpmath.libmp.libmpf import prec_to_dps + +import random + + +x, y, z, t = [Dummy('plane_dummy') for i in range(4)] + + +class Plane(GeometryEntity): + """ + A plane is a flat, two-dimensional surface. A plane is the two-dimensional + analogue of a point (zero-dimensions), a line (one-dimension) and a solid + (three-dimensions). A plane can generally be constructed by two types of + inputs. They are: + - three non-collinear points + - a point and the plane's normal vector + + Attributes + ========== + + p1 + normal_vector + + Examples + ======== + + >>> from sympy import Plane, Point3D + >>> Plane(Point3D(1, 1, 1), Point3D(2, 3, 4), Point3D(2, 2, 2)) + Plane(Point3D(1, 1, 1), (-1, 2, -1)) + >>> Plane((1, 1, 1), (2, 3, 4), (2, 2, 2)) + Plane(Point3D(1, 1, 1), (-1, 2, -1)) + >>> Plane(Point3D(1, 1, 1), normal_vector=(1,4,7)) + Plane(Point3D(1, 1, 1), (1, 4, 7)) + + """ + def __new__(cls, p1, a=None, b=None, **kwargs): + p1 = Point3D(p1, dim=3) + if a and b: + p2 = Point(a, dim=3) + p3 = Point(b, dim=3) + if Point3D.are_collinear(p1, p2, p3): + raise ValueError('Enter three non-collinear points') + a = p1.direction_ratio(p2) + b = p1.direction_ratio(p3) + normal_vector = tuple(Matrix(a).cross(Matrix(b))) + else: + a = kwargs.pop('normal_vector', a) + evaluate = kwargs.get('evaluate', True) + if is_sequence(a) and len(a) == 3: + normal_vector = Point3D(a).args if evaluate else a + else: + raise ValueError(filldedent(''' + Either provide 3 3D points or a point with a + normal vector expressed as a sequence of length 3''')) + if all(coord.is_zero for coord in normal_vector): + raise ValueError('Normal vector cannot be zero vector') + return GeometryEntity.__new__(cls, p1, normal_vector, **kwargs) + + def __contains__(self, o): + k = self.equation(x, y, z) + if isinstance(o, (LinearEntity, LinearEntity3D)): + d = Point3D(o.arbitrary_point(t)) + e = k.subs([(x, d.x), (y, d.y), (z, d.z)]) + return e.equals(0) + try: + o = Point(o, dim=3, strict=True) + d = k.xreplace(dict(zip((x, y, z), o.args))) + return d.equals(0) + except TypeError: + return False + + def _eval_evalf(self, prec=15, **options): + pt, tup = self.args + dps = prec_to_dps(prec) + pt = pt.evalf(n=dps, **options) + tup = tuple([i.evalf(n=dps, **options) for i in tup]) + return self.func(pt, normal_vector=tup, evaluate=False) + + def angle_between(self, o): + """Angle between the plane and other geometric entity. + + Parameters + ========== + + LinearEntity3D, Plane. + + Returns + ======= + + angle : angle in radians + + Notes + ===== + + This method accepts only 3D entities as it's parameter, but if you want + to calculate the angle between a 2D entity and a plane you should + first convert to a 3D entity by projecting onto a desired plane and + then proceed to calculate the angle. + + Examples + ======== + + >>> from sympy import Point3D, Line3D, Plane + >>> a = Plane(Point3D(1, 2, 2), normal_vector=(1, 2, 3)) + >>> b = Line3D(Point3D(1, 3, 4), Point3D(2, 2, 2)) + >>> a.angle_between(b) + -asin(sqrt(21)/6) + + """ + if isinstance(o, LinearEntity3D): + a = Matrix(self.normal_vector) + b = Matrix(o.direction_ratio) + c = a.dot(b) + d = sqrt(sum(i**2 for i in self.normal_vector)) + e = sqrt(sum(i**2 for i in o.direction_ratio)) + return asin(c/(d*e)) + if isinstance(o, Plane): + a = Matrix(self.normal_vector) + b = Matrix(o.normal_vector) + c = a.dot(b) + d = sqrt(sum(i**2 for i in self.normal_vector)) + e = sqrt(sum(i**2 for i in o.normal_vector)) + return acos(c/(d*e)) + + + def arbitrary_point(self, u=None, v=None): + """ Returns an arbitrary point on the Plane. If given two + parameters, the point ranges over the entire plane. If given 1 + or no parameters, returns a point with one parameter which, + when varying from 0 to 2*pi, moves the point in a circle of + radius 1 about p1 of the Plane. + + Examples + ======== + + >>> from sympy import Plane, Ray + >>> from sympy.abc import u, v, t, r + >>> p = Plane((1, 1, 1), normal_vector=(1, 0, 0)) + >>> p.arbitrary_point(u, v) + Point3D(1, u + 1, v + 1) + >>> p.arbitrary_point(t) + Point3D(1, cos(t) + 1, sin(t) + 1) + + While arbitrary values of u and v can move the point anywhere in + the plane, the single-parameter point can be used to construct a + ray whose arbitrary point can be located at angle t and radius + r from p.p1: + + >>> Ray(p.p1, _).arbitrary_point(r) + Point3D(1, r*cos(t) + 1, r*sin(t) + 1) + + Returns + ======= + + Point3D + + """ + circle = v is None + if circle: + u = _symbol(u or 't', real=True) + else: + u = _symbol(u or 'u', real=True) + v = _symbol(v or 'v', real=True) + x, y, z = self.normal_vector + a, b, c = self.p1.args + # x1, y1, z1 is a nonzero vector parallel to the plane + if x.is_zero and y.is_zero: + x1, y1, z1 = S.One, S.Zero, S.Zero + else: + x1, y1, z1 = -y, x, S.Zero + # x2, y2, z2 is also parallel to the plane, and orthogonal to x1, y1, z1 + x2, y2, z2 = tuple(Matrix((x, y, z)).cross(Matrix((x1, y1, z1)))) + if circle: + x1, y1, z1 = (w/sqrt(x1**2 + y1**2 + z1**2) for w in (x1, y1, z1)) + x2, y2, z2 = (w/sqrt(x2**2 + y2**2 + z2**2) for w in (x2, y2, z2)) + p = Point3D(a + x1*cos(u) + x2*sin(u), \ + b + y1*cos(u) + y2*sin(u), \ + c + z1*cos(u) + z2*sin(u)) + else: + p = Point3D(a + x1*u + x2*v, b + y1*u + y2*v, c + z1*u + z2*v) + return p + + + @staticmethod + def are_concurrent(*planes): + """Is a sequence of Planes concurrent? + + Two or more Planes are concurrent if their intersections + are a common line. + + Parameters + ========== + + planes: list + + Returns + ======= + + Boolean + + Examples + ======== + + >>> from sympy import Plane, Point3D + >>> a = Plane(Point3D(5, 0, 0), normal_vector=(1, -1, 1)) + >>> b = Plane(Point3D(0, -2, 0), normal_vector=(3, 1, 1)) + >>> c = Plane(Point3D(0, -1, 0), normal_vector=(5, -1, 9)) + >>> Plane.are_concurrent(a, b) + True + >>> Plane.are_concurrent(a, b, c) + False + + """ + planes = list(uniq(planes)) + for i in planes: + if not isinstance(i, Plane): + raise ValueError('All objects should be Planes but got %s' % i.func) + if len(planes) < 2: + return False + planes = list(planes) + first = planes.pop(0) + sol = first.intersection(planes[0]) + if sol == []: + return False + else: + line = sol[0] + for i in planes[1:]: + l = first.intersection(i) + if not l or l[0] not in line: + return False + return True + + + def distance(self, o): + """Distance between the plane and another geometric entity. + + Parameters + ========== + + Point3D, LinearEntity3D, Plane. + + Returns + ======= + + distance + + Notes + ===== + + This method accepts only 3D entities as it's parameter, but if you want + to calculate the distance between a 2D entity and a plane you should + first convert to a 3D entity by projecting onto a desired plane and + then proceed to calculate the distance. + + Examples + ======== + + >>> from sympy import Point3D, Line3D, Plane + >>> a = Plane(Point3D(1, 1, 1), normal_vector=(1, 1, 1)) + >>> b = Point3D(1, 2, 3) + >>> a.distance(b) + sqrt(3) + >>> c = Line3D(Point3D(2, 3, 1), Point3D(1, 2, 2)) + >>> a.distance(c) + 0 + + """ + if self.intersection(o) != []: + return S.Zero + + if isinstance(o, (Segment3D, Ray3D)): + a, b = o.p1, o.p2 + pi, = self.intersection(Line3D(a, b)) + if pi in o: + return self.distance(pi) + elif a in Segment3D(pi, b): + return self.distance(a) + else: + assert isinstance(o, Segment3D) is True + return self.distance(b) + + # following code handles `Point3D`, `LinearEntity3D`, `Plane` + a = o if isinstance(o, Point3D) else o.p1 + n = Point3D(self.normal_vector).unit + d = (a - self.p1).dot(n) + return abs(d) + + + def equals(self, o): + """ + Returns True if self and o are the same mathematical entities. + + Examples + ======== + + >>> from sympy import Plane, Point3D + >>> a = Plane(Point3D(1, 2, 3), normal_vector=(1, 1, 1)) + >>> b = Plane(Point3D(1, 2, 3), normal_vector=(2, 2, 2)) + >>> c = Plane(Point3D(1, 2, 3), normal_vector=(-1, 4, 6)) + >>> a.equals(a) + True + >>> a.equals(b) + True + >>> a.equals(c) + False + """ + if isinstance(o, Plane): + a = self.equation() + b = o.equation() + return cancel(a/b).is_constant() + else: + return False + + + def equation(self, x=None, y=None, z=None): + """The equation of the Plane. + + Examples + ======== + + >>> from sympy import Point3D, Plane + >>> a = Plane(Point3D(1, 1, 2), Point3D(2, 4, 7), Point3D(3, 5, 1)) + >>> a.equation() + -23*x + 11*y - 2*z + 16 + >>> a = Plane(Point3D(1, 4, 2), normal_vector=(6, 6, 6)) + >>> a.equation() + 6*x + 6*y + 6*z - 42 + + """ + x, y, z = [i if i else Symbol(j, real=True) for i, j in zip((x, y, z), 'xyz')] + a = Point3D(x, y, z) + b = self.p1.direction_ratio(a) + c = self.normal_vector + return (sum(i*j for i, j in zip(b, c))) + + + def intersection(self, o): + """ The intersection with other geometrical entity. + + Parameters + ========== + + Point, Point3D, LinearEntity, LinearEntity3D, Plane + + Returns + ======= + + List + + Examples + ======== + + >>> from sympy import Point3D, Line3D, Plane + >>> a = Plane(Point3D(1, 2, 3), normal_vector=(1, 1, 1)) + >>> b = Point3D(1, 2, 3) + >>> a.intersection(b) + [Point3D(1, 2, 3)] + >>> c = Line3D(Point3D(1, 4, 7), Point3D(2, 2, 2)) + >>> a.intersection(c) + [Point3D(2, 2, 2)] + >>> d = Plane(Point3D(6, 0, 0), normal_vector=(2, -5, 3)) + >>> e = Plane(Point3D(2, 0, 0), normal_vector=(3, 4, -3)) + >>> d.intersection(e) + [Line3D(Point3D(78/23, -24/23, 0), Point3D(147/23, 321/23, 23))] + + """ + if not isinstance(o, GeometryEntity): + o = Point(o, dim=3) + if isinstance(o, Point): + if o in self: + return [o] + else: + return [] + if isinstance(o, (LinearEntity, LinearEntity3D)): + # recast to 3D + p1, p2 = o.p1, o.p2 + if isinstance(o, Segment): + o = Segment3D(p1, p2) + elif isinstance(o, Ray): + o = Ray3D(p1, p2) + elif isinstance(o, Line): + o = Line3D(p1, p2) + else: + raise ValueError('unhandled linear entity: %s' % o.func) + if o in self: + return [o] + else: + a = Point3D(o.arbitrary_point(t)) + p1, n = self.p1, Point3D(self.normal_vector) + + # TODO: Replace solve with solveset, when this line is tested + c = solve((a - p1).dot(n), t) + if not c: + return [] + else: + c = [i for i in c if i.is_real is not False] + if len(c) > 1: + c = [i for i in c if i.is_real] + if len(c) != 1: + raise Undecidable("not sure which point is real") + p = a.subs(t, c[0]) + if p not in o: + return [] # e.g. a segment might not intersect a plane + return [p] + if isinstance(o, Plane): + if self.equals(o): + return [self] + if self.is_parallel(o): + return [] + else: + x, y, z = map(Dummy, 'xyz') + a, b = Matrix([self.normal_vector]), Matrix([o.normal_vector]) + c = list(a.cross(b)) + d = self.equation(x, y, z) + e = o.equation(x, y, z) + result = list(linsolve([d, e], x, y, z))[0] + for i in (x, y, z): result = result.subs(i, 0) + return [Line3D(Point3D(result), direction_ratio=c)] + + + def is_coplanar(self, o): + """ Returns True if `o` is coplanar with self, else False. + + Examples + ======== + + >>> from sympy import Plane + >>> o = (0, 0, 0) + >>> p = Plane(o, (1, 1, 1)) + >>> p2 = Plane(o, (2, 2, 2)) + >>> p == p2 + False + >>> p.is_coplanar(p2) + True + """ + if isinstance(o, Plane): + return not cancel(self.equation(x, y, z)/o.equation(x, y, z)).has(x, y, z) + if isinstance(o, Point3D): + return o in self + elif isinstance(o, LinearEntity3D): + return all(i in self for i in self) + elif isinstance(o, GeometryEntity): # XXX should only be handling 2D objects now + return all(i == 0 for i in self.normal_vector[:2]) + + + def is_parallel(self, l): + """Is the given geometric entity parallel to the plane? + + Parameters + ========== + + LinearEntity3D or Plane + + Returns + ======= + + Boolean + + Examples + ======== + + >>> from sympy import Plane, Point3D + >>> a = Plane(Point3D(1,4,6), normal_vector=(2, 4, 6)) + >>> b = Plane(Point3D(3,1,3), normal_vector=(4, 8, 12)) + >>> a.is_parallel(b) + True + + """ + if isinstance(l, LinearEntity3D): + a = l.direction_ratio + b = self.normal_vector + return sum(i*j for i, j in zip(a, b)) == 0 + if isinstance(l, Plane): + a = Matrix(l.normal_vector) + b = Matrix(self.normal_vector) + return bool(a.cross(b).is_zero_matrix) + + + def is_perpendicular(self, l): + """Is the given geometric entity perpendicualar to the given plane? + + Parameters + ========== + + LinearEntity3D or Plane + + Returns + ======= + + Boolean + + Examples + ======== + + >>> from sympy import Plane, Point3D + >>> a = Plane(Point3D(1,4,6), normal_vector=(2, 4, 6)) + >>> b = Plane(Point3D(2, 2, 2), normal_vector=(-1, 2, -1)) + >>> a.is_perpendicular(b) + True + + """ + if isinstance(l, LinearEntity3D): + a = Matrix(l.direction_ratio) + b = Matrix(self.normal_vector) + if a.cross(b).is_zero_matrix: + return True + else: + return False + elif isinstance(l, Plane): + a = Matrix(l.normal_vector) + b = Matrix(self.normal_vector) + if a.dot(b) == 0: + return True + else: + return False + else: + return False + + @property + def normal_vector(self): + """Normal vector of the given plane. + + Examples + ======== + + >>> from sympy import Point3D, Plane + >>> a = Plane(Point3D(1, 1, 1), Point3D(2, 3, 4), Point3D(2, 2, 2)) + >>> a.normal_vector + (-1, 2, -1) + >>> a = Plane(Point3D(1, 1, 1), normal_vector=(1, 4, 7)) + >>> a.normal_vector + (1, 4, 7) + + """ + return self.args[1] + + @property + def p1(self): + """The only defining point of the plane. Others can be obtained from the + arbitrary_point method. + + See Also + ======== + + sympy.geometry.point.Point3D + + Examples + ======== + + >>> from sympy import Point3D, Plane + >>> a = Plane(Point3D(1, 1, 1), Point3D(2, 3, 4), Point3D(2, 2, 2)) + >>> a.p1 + Point3D(1, 1, 1) + + """ + return self.args[0] + + def parallel_plane(self, pt): + """ + Plane parallel to the given plane and passing through the point pt. + + Parameters + ========== + + pt: Point3D + + Returns + ======= + + Plane + + Examples + ======== + + >>> from sympy import Plane, Point3D + >>> a = Plane(Point3D(1, 4, 6), normal_vector=(2, 4, 6)) + >>> a.parallel_plane(Point3D(2, 3, 5)) + Plane(Point3D(2, 3, 5), (2, 4, 6)) + + """ + a = self.normal_vector + return Plane(pt, normal_vector=a) + + def perpendicular_line(self, pt): + """A line perpendicular to the given plane. + + Parameters + ========== + + pt: Point3D + + Returns + ======= + + Line3D + + Examples + ======== + + >>> from sympy import Plane, Point3D + >>> a = Plane(Point3D(1,4,6), normal_vector=(2, 4, 6)) + >>> a.perpendicular_line(Point3D(9, 8, 7)) + Line3D(Point3D(9, 8, 7), Point3D(11, 12, 13)) + + """ + a = self.normal_vector + return Line3D(pt, direction_ratio=a) + + def perpendicular_plane(self, *pts): + """ + Return a perpendicular passing through the given points. If the + direction ratio between the points is the same as the Plane's normal + vector then, to select from the infinite number of possible planes, + a third point will be chosen on the z-axis (or the y-axis + if the normal vector is already parallel to the z-axis). If less than + two points are given they will be supplied as follows: if no point is + given then pt1 will be self.p1; if a second point is not given it will + be a point through pt1 on a line parallel to the z-axis (if the normal + is not already the z-axis, otherwise on the line parallel to the + y-axis). + + Parameters + ========== + + pts: 0, 1 or 2 Point3D + + Returns + ======= + + Plane + + Examples + ======== + + >>> from sympy import Plane, Point3D + >>> a, b = Point3D(0, 0, 0), Point3D(0, 1, 0) + >>> Z = (0, 0, 1) + >>> p = Plane(a, normal_vector=Z) + >>> p.perpendicular_plane(a, b) + Plane(Point3D(0, 0, 0), (1, 0, 0)) + """ + if len(pts) > 2: + raise ValueError('No more than 2 pts should be provided.') + + pts = list(pts) + if len(pts) == 0: + pts.append(self.p1) + if len(pts) == 1: + x, y, z = self.normal_vector + if x == y == 0: + dir = (0, 1, 0) + else: + dir = (0, 0, 1) + pts.append(pts[0] + Point3D(*dir)) + + p1, p2 = [Point(i, dim=3) for i in pts] + l = Line3D(p1, p2) + n = Line3D(p1, direction_ratio=self.normal_vector) + if l in n: # XXX should an error be raised instead? + # there are infinitely many perpendicular planes; + x, y, z = self.normal_vector + if x == y == 0: + # the z axis is the normal so pick a pt on the y-axis + p3 = Point3D(0, 1, 0) # case 1 + else: + # else pick a pt on the z axis + p3 = Point3D(0, 0, 1) # case 2 + # in case that point is already given, move it a bit + if p3 in l: + p3 *= 2 # case 3 + else: + p3 = p1 + Point3D(*self.normal_vector) # case 4 + return Plane(p1, p2, p3) + + def projection_line(self, line): + """Project the given line onto the plane through the normal plane + containing the line. + + Parameters + ========== + + LinearEntity or LinearEntity3D + + Returns + ======= + + Point3D, Line3D, Ray3D or Segment3D + + Notes + ===== + + For the interaction between 2D and 3D lines(segments, rays), you should + convert the line to 3D by using this method. For example for finding the + intersection between a 2D and a 3D line, convert the 2D line to a 3D line + by projecting it on a required plane and then proceed to find the + intersection between those lines. + + Examples + ======== + + >>> from sympy import Plane, Line, Line3D, Point3D + >>> a = Plane(Point3D(1, 1, 1), normal_vector=(1, 1, 1)) + >>> b = Line(Point3D(1, 1), Point3D(2, 2)) + >>> a.projection_line(b) + Line3D(Point3D(4/3, 4/3, 1/3), Point3D(5/3, 5/3, -1/3)) + >>> c = Line3D(Point3D(1, 1, 1), Point3D(2, 2, 2)) + >>> a.projection_line(c) + Point3D(1, 1, 1) + + """ + if not isinstance(line, (LinearEntity, LinearEntity3D)): + raise NotImplementedError('Enter a linear entity only') + a, b = self.projection(line.p1), self.projection(line.p2) + if a == b: + # projection does not imply intersection so for + # this case (line parallel to plane's normal) we + # return the projection point + return a + if isinstance(line, (Line, Line3D)): + return Line3D(a, b) + if isinstance(line, (Ray, Ray3D)): + return Ray3D(a, b) + if isinstance(line, (Segment, Segment3D)): + return Segment3D(a, b) + + def projection(self, pt): + """Project the given point onto the plane along the plane normal. + + Parameters + ========== + + Point or Point3D + + Returns + ======= + + Point3D + + Examples + ======== + + >>> from sympy import Plane, Point3D + >>> A = Plane(Point3D(1, 1, 2), normal_vector=(1, 1, 1)) + + The projection is along the normal vector direction, not the z + axis, so (1, 1) does not project to (1, 1, 2) on the plane A: + + >>> b = Point3D(1, 1) + >>> A.projection(b) + Point3D(5/3, 5/3, 2/3) + >>> _ in A + True + + But the point (1, 1, 2) projects to (1, 1) on the XY-plane: + + >>> XY = Plane((0, 0, 0), (0, 0, 1)) + >>> XY.projection((1, 1, 2)) + Point3D(1, 1, 0) + """ + rv = Point(pt, dim=3) + if rv in self: + return rv + return self.intersection(Line3D(rv, rv + Point3D(self.normal_vector)))[0] + + def random_point(self, seed=None): + """ Returns a random point on the Plane. + + Returns + ======= + + Point3D + + Examples + ======== + + >>> from sympy import Plane + >>> p = Plane((1, 0, 0), normal_vector=(0, 1, 0)) + >>> r = p.random_point(seed=42) # seed value is optional + >>> r.n(3) + Point3D(2.29, 0, -1.35) + + The random point can be moved to lie on the circle of radius + 1 centered on p1: + + >>> c = p.p1 + (r - p.p1).unit + >>> c.distance(p.p1).equals(1) + True + """ + if seed is not None: + rng = random.Random(seed) + else: + rng = random + params = { + x: 2*Rational(rng.gauss(0, 1)) - 1, + y: 2*Rational(rng.gauss(0, 1)) - 1} + return self.arbitrary_point(x, y).subs(params) + + def parameter_value(self, other, u, v=None): + """Return the parameter(s) corresponding to the given point. + + Examples + ======== + + >>> from sympy import pi, Plane + >>> from sympy.abc import t, u, v + >>> p = Plane((2, 0, 0), (0, 0, 1), (0, 1, 0)) + + By default, the parameter value returned defines a point + that is a distance of 1 from the Plane's p1 value and + in line with the given point: + + >>> on_circle = p.arbitrary_point(t).subs(t, pi/4) + >>> on_circle.distance(p.p1) + 1 + >>> p.parameter_value(on_circle, t) + {t: pi/4} + + Moving the point twice as far from p1 does not change + the parameter value: + + >>> off_circle = p.p1 + (on_circle - p.p1)*2 + >>> off_circle.distance(p.p1) + 2 + >>> p.parameter_value(off_circle, t) + {t: pi/4} + + If the 2-value parameter is desired, supply the two + parameter symbols and a replacement dictionary will + be returned: + + >>> p.parameter_value(on_circle, u, v) + {u: sqrt(10)/10, v: sqrt(10)/30} + >>> p.parameter_value(off_circle, u, v) + {u: sqrt(10)/5, v: sqrt(10)/15} + """ + if not isinstance(other, GeometryEntity): + other = Point(other, dim=self.ambient_dimension) + if not isinstance(other, Point): + raise ValueError("other must be a point") + if other == self.p1: + return other + if isinstance(u, Symbol) and v is None: + delta = self.arbitrary_point(u) - self.p1 + eq = delta - (other - self.p1).unit + sol = solve(eq, u, dict=True) + elif isinstance(u, Symbol) and isinstance(v, Symbol): + pt = self.arbitrary_point(u, v) + sol = solve(pt - other, (u, v), dict=True) + else: + raise ValueError('expecting 1 or 2 symbols') + if not sol: + raise ValueError("Given point is not on %s" % func_name(self)) + return sol[0] # {t: tval} or {u: uval, v: vval} + + @property + def ambient_dimension(self): + return self.p1.ambient_dimension diff --git a/.venv/lib/python3.13/site-packages/sympy/geometry/point.py b/.venv/lib/python3.13/site-packages/sympy/geometry/point.py new file mode 100644 index 0000000000000000000000000000000000000000..19e6c566f06de4df086912470dc35d0f4af3bd38 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/geometry/point.py @@ -0,0 +1,1378 @@ +"""Geometrical Points. + +Contains +======== +Point +Point2D +Point3D + +When methods of Point require 1 or more points as arguments, they +can be passed as a sequence of coordinates or Points: + +>>> from sympy import Point +>>> Point(1, 1).is_collinear((2, 2), (3, 4)) +False +>>> Point(1, 1).is_collinear(Point(2, 2), Point(3, 4)) +False + +""" + +import warnings + +from sympy.core import S, sympify, Expr +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.numbers import Float +from sympy.core.parameters import global_parameters +from sympy.simplify.simplify import nsimplify, simplify +from sympy.geometry.exceptions import GeometryError +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.complexes import im +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.matrices import Matrix +from sympy.matrices.expressions import Transpose +from sympy.utilities.iterables import uniq, is_sequence +from sympy.utilities.misc import filldedent, func_name, Undecidable + +from .entity import GeometryEntity + +from mpmath.libmp.libmpf import prec_to_dps + + +class Point(GeometryEntity): + """A point in a n-dimensional Euclidean space. + + Parameters + ========== + + coords : sequence of n-coordinate values. In the special + case where n=2 or 3, a Point2D or Point3D will be created + as appropriate. + evaluate : if `True` (default), all floats are turn into + exact types. + dim : number of coordinates the point should have. If coordinates + are unspecified, they are padded with zeros. + on_morph : indicates what should happen when the number of + coordinates of a point need to be changed by adding or + removing zeros. Possible values are `'warn'`, `'error'`, or + `ignore` (default). No warning or error is given when `*args` + is empty and `dim` is given. An error is always raised when + trying to remove nonzero coordinates. + + + Attributes + ========== + + length + origin: A `Point` representing the origin of the + appropriately-dimensioned space. + + Raises + ====== + + TypeError : When instantiating with anything but a Point or sequence + ValueError : when instantiating with a sequence with length < 2 or + when trying to reduce dimensions if keyword `on_morph='error'` is + set. + + See Also + ======== + + sympy.geometry.line.Segment : Connects two Points + + Examples + ======== + + >>> from sympy import Point + >>> from sympy.abc import x + >>> Point(1, 2, 3) + Point3D(1, 2, 3) + >>> Point([1, 2]) + Point2D(1, 2) + >>> Point(0, x) + Point2D(0, x) + >>> Point(dim=4) + Point(0, 0, 0, 0) + + Floats are automatically converted to Rational unless the + evaluate flag is False: + + >>> Point(0.5, 0.25) + Point2D(1/2, 1/4) + >>> Point(0.5, 0.25, evaluate=False) + Point2D(0.5, 0.25) + + """ + + is_Point = True + + def __new__(cls, *args, **kwargs): + evaluate = kwargs.get('evaluate', global_parameters.evaluate) + on_morph = kwargs.get('on_morph', 'ignore') + + # unpack into coords + coords = args[0] if len(args) == 1 else args + + # check args and handle quickly handle Point instances + if isinstance(coords, Point): + # even if we're mutating the dimension of a point, we + # don't reevaluate its coordinates + evaluate = False + if len(coords) == kwargs.get('dim', len(coords)): + return coords + + if not is_sequence(coords): + raise TypeError(filldedent(''' + Expecting sequence of coordinates, not `{}`''' + .format(func_name(coords)))) + # A point where only `dim` is specified is initialized + # to zeros. + if len(coords) == 0 and kwargs.get('dim', None): + coords = (S.Zero,)*kwargs.get('dim') + + coords = Tuple(*coords) + dim = kwargs.get('dim', len(coords)) + + if len(coords) < 2: + raise ValueError(filldedent(''' + Point requires 2 or more coordinates or + keyword `dim` > 1.''')) + if len(coords) != dim: + message = ("Dimension of {} needs to be changed " + "from {} to {}.").format(coords, len(coords), dim) + if on_morph == 'ignore': + pass + elif on_morph == "error": + raise ValueError(message) + elif on_morph == 'warn': + warnings.warn(message, stacklevel=2) + else: + raise ValueError(filldedent(''' + on_morph value should be 'error', + 'warn' or 'ignore'.''')) + if any(coords[dim:]): + raise ValueError('Nonzero coordinates cannot be removed.') + if any(a.is_number and im(a).is_zero is False for a in coords): + raise ValueError('Imaginary coordinates are not permitted.') + if not all(isinstance(a, Expr) for a in coords): + raise TypeError('Coordinates must be valid SymPy expressions.') + + # pad with zeros appropriately + coords = coords[:dim] + (S.Zero,)*(dim - len(coords)) + + # Turn any Floats into rationals and simplify + # any expressions before we instantiate + if evaluate: + coords = coords.xreplace({ + f: simplify(nsimplify(f, rational=True)) + for f in coords.atoms(Float)}) + + # return 2D or 3D instances + if len(coords) == 2: + kwargs['_nocheck'] = True + return Point2D(*coords, **kwargs) + elif len(coords) == 3: + kwargs['_nocheck'] = True + return Point3D(*coords, **kwargs) + + # the general Point + return GeometryEntity.__new__(cls, *coords) + + def __abs__(self): + """Returns the distance between this point and the origin.""" + origin = Point([0]*len(self)) + return Point.distance(origin, self) + + def __add__(self, other): + """Add other to self by incrementing self's coordinates by + those of other. + + Notes + ===== + + >>> from sympy import Point + + When sequences of coordinates are passed to Point methods, they + are converted to a Point internally. This __add__ method does + not do that so if floating point values are used, a floating + point result (in terms of SymPy Floats) will be returned. + + >>> Point(1, 2) + (.1, .2) + Point2D(1.1, 2.2) + + If this is not desired, the `translate` method can be used or + another Point can be added: + + >>> Point(1, 2).translate(.1, .2) + Point2D(11/10, 11/5) + >>> Point(1, 2) + Point(.1, .2) + Point2D(11/10, 11/5) + + See Also + ======== + + sympy.geometry.point.Point.translate + + """ + try: + s, o = Point._normalize_dimension(self, Point(other, evaluate=False)) + except TypeError: + raise GeometryError("Don't know how to add {} and a Point object".format(other)) + + coords = [simplify(a + b) for a, b in zip(s, o)] + return Point(coords, evaluate=False) + + def __contains__(self, item): + return item in self.args + + def __truediv__(self, divisor): + """Divide point's coordinates by a factor.""" + divisor = sympify(divisor) + coords = [simplify(x/divisor) for x in self.args] + return Point(coords, evaluate=False) + + def __eq__(self, other): + if not isinstance(other, Point) or len(self.args) != len(other.args): + return False + return self.args == other.args + + def __getitem__(self, key): + return self.args[key] + + def __hash__(self): + return hash(self.args) + + def __iter__(self): + return self.args.__iter__() + + def __len__(self): + return len(self.args) + + def __mul__(self, factor): + """Multiply point's coordinates by a factor. + + Notes + ===== + + >>> from sympy import Point + + When multiplying a Point by a floating point number, + the coordinates of the Point will be changed to Floats: + + >>> Point(1, 2)*0.1 + Point2D(0.1, 0.2) + + If this is not desired, the `scale` method can be used or + else only multiply or divide by integers: + + >>> Point(1, 2).scale(1.1, 1.1) + Point2D(11/10, 11/5) + >>> Point(1, 2)*11/10 + Point2D(11/10, 11/5) + + See Also + ======== + + sympy.geometry.point.Point.scale + """ + factor = sympify(factor) + coords = [simplify(x*factor) for x in self.args] + return Point(coords, evaluate=False) + + def __rmul__(self, factor): + """Multiply a factor by point's coordinates.""" + return self.__mul__(factor) + + def __neg__(self): + """Negate the point.""" + coords = [-x for x in self.args] + return Point(coords, evaluate=False) + + def __sub__(self, other): + """Subtract two points, or subtract a factor from this point's + coordinates.""" + return self + [-x for x in other] + + @classmethod + def _normalize_dimension(cls, *points, **kwargs): + """Ensure that points have the same dimension. + By default `on_morph='warn'` is passed to the + `Point` constructor.""" + # if we have a built-in ambient dimension, use it + dim = getattr(cls, '_ambient_dimension', None) + # override if we specified it + dim = kwargs.get('dim', dim) + # if no dim was given, use the highest dimensional point + if dim is None: + dim = max(i.ambient_dimension for i in points) + if all(i.ambient_dimension == dim for i in points): + return list(points) + kwargs['dim'] = dim + kwargs['on_morph'] = kwargs.get('on_morph', 'warn') + return [Point(i, **kwargs) for i in points] + + @staticmethod + def affine_rank(*args): + """The affine rank of a set of points is the dimension + of the smallest affine space containing all the points. + For example, if the points lie on a line (and are not all + the same) their affine rank is 1. If the points lie on a plane + but not a line, their affine rank is 2. By convention, the empty + set has affine rank -1.""" + + if len(args) == 0: + return -1 + # make sure we're genuinely points + # and translate every point to the origin + points = Point._normalize_dimension(*[Point(i) for i in args]) + origin = points[0] + points = [i - origin for i in points[1:]] + + m = Matrix([i.args for i in points]) + # XXX fragile -- what is a better way? + return m.rank(iszerofunc = lambda x: + abs(x.n(2)) < 1e-12 if x.is_number else x.is_zero) + + @property + def ambient_dimension(self): + """Number of components this point has.""" + return getattr(self, '_ambient_dimension', len(self)) + + @classmethod + def are_coplanar(cls, *points): + """Return True if there exists a plane in which all the points + lie. A trivial True value is returned if `len(points) < 3` or + all Points are 2-dimensional. + + Parameters + ========== + + A set of points + + Raises + ====== + + ValueError : if less than 3 unique points are given + + Returns + ======= + + boolean + + Examples + ======== + + >>> from sympy import Point3D + >>> p1 = Point3D(1, 2, 2) + >>> p2 = Point3D(2, 7, 2) + >>> p3 = Point3D(0, 0, 2) + >>> p4 = Point3D(1, 1, 2) + >>> Point3D.are_coplanar(p1, p2, p3, p4) + True + >>> p5 = Point3D(0, 1, 3) + >>> Point3D.are_coplanar(p1, p2, p3, p5) + False + + """ + if len(points) <= 1: + return True + + points = cls._normalize_dimension(*[Point(i) for i in points]) + # quick exit if we are in 2D + if points[0].ambient_dimension == 2: + return True + points = list(uniq(points)) + return Point.affine_rank(*points) <= 2 + + def distance(self, other): + """The Euclidean distance between self and another GeometricEntity. + + Returns + ======= + + distance : number or symbolic expression. + + Raises + ====== + + TypeError : if other is not recognized as a GeometricEntity or is a + GeometricEntity for which distance is not defined. + + See Also + ======== + + sympy.geometry.line.Segment.length + sympy.geometry.point.Point.taxicab_distance + + Examples + ======== + + >>> from sympy import Point, Line + >>> p1, p2 = Point(1, 1), Point(4, 5) + >>> l = Line((3, 1), (2, 2)) + >>> p1.distance(p2) + 5 + >>> p1.distance(l) + sqrt(2) + + The computed distance may be symbolic, too: + + >>> from sympy.abc import x, y + >>> p3 = Point(x, y) + >>> p3.distance((0, 0)) + sqrt(x**2 + y**2) + + """ + if not isinstance(other, GeometryEntity): + try: + other = Point(other, dim=self.ambient_dimension) + except TypeError: + raise TypeError("not recognized as a GeometricEntity: %s" % type(other)) + if isinstance(other, Point): + s, p = Point._normalize_dimension(self, Point(other)) + return sqrt(Add(*((a - b)**2 for a, b in zip(s, p)))) + distance = getattr(other, 'distance', None) + if distance is None: + raise TypeError("distance between Point and %s is not defined" % type(other)) + return distance(self) + + def dot(self, p): + """Return dot product of self with another Point.""" + if not is_sequence(p): + p = Point(p) # raise the error via Point + return Add(*(a*b for a, b in zip(self, p))) + + def equals(self, other): + """Returns whether the coordinates of self and other agree.""" + # a point is equal to another point if all its components are equal + if not isinstance(other, Point) or len(self) != len(other): + return False + return all(a.equals(b) for a, b in zip(self, other)) + + def _eval_evalf(self, prec=15, **options): + """Evaluate the coordinates of the point. + + This method will, where possible, create and return a new Point + where the coordinates are evaluated as floating point numbers to + the precision indicated (default=15). + + Parameters + ========== + + prec : int + + Returns + ======= + + point : Point + + Examples + ======== + + >>> from sympy import Point, Rational + >>> p1 = Point(Rational(1, 2), Rational(3, 2)) + >>> p1 + Point2D(1/2, 3/2) + >>> p1.evalf() + Point2D(0.5, 1.5) + + """ + dps = prec_to_dps(prec) + coords = [x.evalf(n=dps, **options) for x in self.args] + return Point(*coords, evaluate=False) + + def intersection(self, other): + """The intersection between this point and another GeometryEntity. + + Parameters + ========== + + other : GeometryEntity or sequence of coordinates + + Returns + ======= + + intersection : list of Points + + Notes + ===== + + The return value will either be an empty list if there is no + intersection, otherwise it will contain this point. + + Examples + ======== + + >>> from sympy import Point + >>> p1, p2, p3 = Point(0, 0), Point(1, 1), Point(0, 0) + >>> p1.intersection(p2) + [] + >>> p1.intersection(p3) + [Point2D(0, 0)] + + """ + if not isinstance(other, GeometryEntity): + other = Point(other) + if isinstance(other, Point): + if self == other: + return [self] + p1, p2 = Point._normalize_dimension(self, other) + if p1 == self and p1 == p2: + return [self] + return [] + return other.intersection(self) + + def is_collinear(self, *args): + """Returns `True` if there exists a line + that contains `self` and `points`. Returns `False` otherwise. + A trivially True value is returned if no points are given. + + Parameters + ========== + + args : sequence of Points + + Returns + ======= + + is_collinear : boolean + + See Also + ======== + + sympy.geometry.line.Line + + Examples + ======== + + >>> from sympy import Point + >>> from sympy.abc import x + >>> p1, p2 = Point(0, 0), Point(1, 1) + >>> p3, p4, p5 = Point(2, 2), Point(x, x), Point(1, 2) + >>> Point.is_collinear(p1, p2, p3, p4) + True + >>> Point.is_collinear(p1, p2, p3, p5) + False + + """ + points = (self,) + args + points = Point._normalize_dimension(*[Point(i) for i in points]) + points = list(uniq(points)) + return Point.affine_rank(*points) <= 1 + + def is_concyclic(self, *args): + """Do `self` and the given sequence of points lie in a circle? + + Returns True if the set of points are concyclic and + False otherwise. A trivial value of True is returned + if there are fewer than 2 other points. + + Parameters + ========== + + args : sequence of Points + + Returns + ======= + + is_concyclic : boolean + + + Examples + ======== + + >>> from sympy import Point + + Define 4 points that are on the unit circle: + + >>> p1, p2, p3, p4 = Point(1, 0), (0, 1), (-1, 0), (0, -1) + + >>> p1.is_concyclic() == p1.is_concyclic(p2, p3, p4) == True + True + + Define a point not on that circle: + + >>> p = Point(1, 1) + + >>> p.is_concyclic(p1, p2, p3) + False + + """ + points = (self,) + args + points = Point._normalize_dimension(*[Point(i) for i in points]) + points = list(uniq(points)) + if not Point.affine_rank(*points) <= 2: + return False + origin = points[0] + points = [p - origin for p in points] + # points are concyclic if they are coplanar and + # there is a point c so that ||p_i-c|| == ||p_j-c|| for all + # i and j. Rearranging this equation gives us the following + # condition: the matrix `mat` must not a pivot in the last + # column. + mat = Matrix([list(i) + [i.dot(i)] for i in points]) + rref, pivots = mat.rref() + if len(origin) not in pivots: + return True + return False + + @property + def is_nonzero(self): + """True if any coordinate is nonzero, False if every coordinate is zero, + and None if it cannot be determined.""" + is_zero = self.is_zero + if is_zero is None: + return None + return not is_zero + + def is_scalar_multiple(self, p): + """Returns whether each coordinate of `self` is a scalar + multiple of the corresponding coordinate in point p. + """ + s, o = Point._normalize_dimension(self, Point(p)) + # 2d points happen a lot, so optimize this function call + if s.ambient_dimension == 2: + (x1, y1), (x2, y2) = s.args, o.args + rv = (x1*y2 - x2*y1).equals(0) + if rv is None: + raise Undecidable(filldedent( + '''Cannot determine if %s is a scalar multiple of + %s''' % (s, o))) + + # if the vectors p1 and p2 are linearly dependent, then they must + # be scalar multiples of each other + m = Matrix([s.args, o.args]) + return m.rank() < 2 + + @property + def is_zero(self): + """True if every coordinate is zero, False if any coordinate is not zero, + and None if it cannot be determined.""" + nonzero = [x.is_nonzero for x in self.args] + if any(nonzero): + return False + if any(x is None for x in nonzero): + return None + return True + + @property + def length(self): + """ + Treating a Point as a Line, this returns 0 for the length of a Point. + + Examples + ======== + + >>> from sympy import Point + >>> p = Point(0, 1) + >>> p.length + 0 + """ + return S.Zero + + def midpoint(self, p): + """The midpoint between self and point p. + + Parameters + ========== + + p : Point + + Returns + ======= + + midpoint : Point + + See Also + ======== + + sympy.geometry.line.Segment.midpoint + + Examples + ======== + + >>> from sympy import Point + >>> p1, p2 = Point(1, 1), Point(13, 5) + >>> p1.midpoint(p2) + Point2D(7, 3) + + """ + s, p = Point._normalize_dimension(self, Point(p)) + return Point([simplify((a + b)*S.Half) for a, b in zip(s, p)]) + + @property + def origin(self): + """A point of all zeros of the same ambient dimension + as the current point""" + return Point([0]*len(self), evaluate=False) + + @property + def orthogonal_direction(self): + """Returns a non-zero point that is orthogonal to the + line containing `self` and the origin. + + Examples + ======== + + >>> from sympy import Line, Point + >>> a = Point(1, 2, 3) + >>> a.orthogonal_direction + Point3D(-2, 1, 0) + >>> b = _ + >>> Line(b, b.origin).is_perpendicular(Line(a, a.origin)) + True + """ + dim = self.ambient_dimension + # if a coordinate is zero, we can put a 1 there and zeros elsewhere + if self[0].is_zero: + return Point([1] + (dim - 1)*[0]) + if self[1].is_zero: + return Point([0,1] + (dim - 2)*[0]) + # if the first two coordinates aren't zero, we can create a non-zero + # orthogonal vector by swapping them, negating one, and padding with zeros + return Point([-self[1], self[0]] + (dim - 2)*[0]) + + @staticmethod + def project(a, b): + """Project the point `a` onto the line between the origin + and point `b` along the normal direction. + + Parameters + ========== + + a : Point + b : Point + + Returns + ======= + + p : Point + + See Also + ======== + + sympy.geometry.line.LinearEntity.projection + + Examples + ======== + + >>> from sympy import Line, Point + >>> a = Point(1, 2) + >>> b = Point(2, 5) + >>> z = a.origin + >>> p = Point.project(a, b) + >>> Line(p, a).is_perpendicular(Line(p, b)) + True + >>> Point.is_collinear(z, p, b) + True + """ + a, b = Point._normalize_dimension(Point(a), Point(b)) + if b.is_zero: + raise ValueError("Cannot project to the zero vector.") + return b*(a.dot(b) / b.dot(b)) + + def taxicab_distance(self, p): + """The Taxicab Distance from self to point p. + + Returns the sum of the horizontal and vertical distances to point p. + + Parameters + ========== + + p : Point + + Returns + ======= + + taxicab_distance : The sum of the horizontal + and vertical distances to point p. + + See Also + ======== + + sympy.geometry.point.Point.distance + + Examples + ======== + + >>> from sympy import Point + >>> p1, p2 = Point(1, 1), Point(4, 5) + >>> p1.taxicab_distance(p2) + 7 + + """ + s, p = Point._normalize_dimension(self, Point(p)) + return Add(*(abs(a - b) for a, b in zip(s, p))) + + def canberra_distance(self, p): + """The Canberra Distance from self to point p. + + Returns the weighted sum of horizontal and vertical distances to + point p. + + Parameters + ========== + + p : Point + + Returns + ======= + + canberra_distance : The weighted sum of horizontal and vertical + distances to point p. The weight used is the sum of absolute values + of the coordinates. + + Examples + ======== + + >>> from sympy import Point + >>> p1, p2 = Point(1, 1), Point(3, 3) + >>> p1.canberra_distance(p2) + 1 + >>> p1, p2 = Point(0, 0), Point(3, 3) + >>> p1.canberra_distance(p2) + 2 + + Raises + ====== + + ValueError when both vectors are zero. + + See Also + ======== + + sympy.geometry.point.Point.distance + + """ + + s, p = Point._normalize_dimension(self, Point(p)) + if self.is_zero and p.is_zero: + raise ValueError("Cannot project to the zero vector.") + return Add(*((abs(a - b)/(abs(a) + abs(b))) for a, b in zip(s, p))) + + @property + def unit(self): + """Return the Point that is in the same direction as `self` + and a distance of 1 from the origin""" + return self / abs(self) + + +class Point2D(Point): + """A point in a 2-dimensional Euclidean space. + + Parameters + ========== + + coords + A sequence of 2 coordinate values. + + Attributes + ========== + + x + y + length + + Raises + ====== + + TypeError + When trying to add or subtract points with different dimensions. + When trying to create a point with more than two dimensions. + When `intersection` is called with object other than a Point. + + See Also + ======== + + sympy.geometry.line.Segment : Connects two Points + + Examples + ======== + + >>> from sympy import Point2D + >>> from sympy.abc import x + >>> Point2D(1, 2) + Point2D(1, 2) + >>> Point2D([1, 2]) + Point2D(1, 2) + >>> Point2D(0, x) + Point2D(0, x) + + Floats are automatically converted to Rational unless the + evaluate flag is False: + + >>> Point2D(0.5, 0.25) + Point2D(1/2, 1/4) + >>> Point2D(0.5, 0.25, evaluate=False) + Point2D(0.5, 0.25) + + """ + + _ambient_dimension = 2 + + def __new__(cls, *args, _nocheck=False, **kwargs): + if not _nocheck: + kwargs['dim'] = 2 + args = Point(*args, **kwargs) + return GeometryEntity.__new__(cls, *args) + + def __contains__(self, item): + return item == self + + @property + def bounds(self): + """Return a tuple (xmin, ymin, xmax, ymax) representing the bounding + rectangle for the geometric figure. + + """ + + return (self.x, self.y, self.x, self.y) + + def rotate(self, angle, pt=None): + """Rotate ``angle`` radians counterclockwise about Point ``pt``. + + See Also + ======== + + translate, scale + + Examples + ======== + + >>> from sympy import Point2D, pi + >>> t = Point2D(1, 0) + >>> t.rotate(pi/2) + Point2D(0, 1) + >>> t.rotate(pi/2, (2, 0)) + Point2D(2, -1) + + """ + c = cos(angle) + s = sin(angle) + + rv = self + if pt is not None: + pt = Point(pt, dim=2) + rv -= pt + x, y = rv.args + rv = Point(c*x - s*y, s*x + c*y) + if pt is not None: + rv += pt + return rv + + def scale(self, x=1, y=1, pt=None): + """Scale the coordinates of the Point by multiplying by + ``x`` and ``y`` after subtracting ``pt`` -- default is (0, 0) -- + and then adding ``pt`` back again (i.e. ``pt`` is the point of + reference for the scaling). + + See Also + ======== + + rotate, translate + + Examples + ======== + + >>> from sympy import Point2D + >>> t = Point2D(1, 1) + >>> t.scale(2) + Point2D(2, 1) + >>> t.scale(2, 2) + Point2D(2, 2) + + """ + if pt: + pt = Point(pt, dim=2) + return self.translate(*(-pt).args).scale(x, y).translate(*pt.args) + return Point(self.x*x, self.y*y) + + def transform(self, matrix): + """Return the point after applying the transformation described + by the 3x3 Matrix, ``matrix``. + + See Also + ======== + sympy.geometry.point.Point2D.rotate + sympy.geometry.point.Point2D.scale + sympy.geometry.point.Point2D.translate + """ + if not (matrix.is_Matrix and matrix.shape == (3, 3)): + raise ValueError("matrix must be a 3x3 matrix") + x, y = self.args + return Point(*(Matrix(1, 3, [x, y, 1])*matrix).tolist()[0][:2]) + + def translate(self, x=0, y=0): + """Shift the Point by adding x and y to the coordinates of the Point. + + See Also + ======== + + sympy.geometry.point.Point2D.rotate, scale + + Examples + ======== + + >>> from sympy import Point2D + >>> t = Point2D(0, 1) + >>> t.translate(2) + Point2D(2, 1) + >>> t.translate(2, 2) + Point2D(2, 3) + >>> t + Point2D(2, 2) + Point2D(2, 3) + + """ + return Point(self.x + x, self.y + y) + + @property + def coordinates(self): + """ + Returns the two coordinates of the Point. + + Examples + ======== + + >>> from sympy import Point2D + >>> p = Point2D(0, 1) + >>> p.coordinates + (0, 1) + """ + return self.args + + @property + def x(self): + """ + Returns the X coordinate of the Point. + + Examples + ======== + + >>> from sympy import Point2D + >>> p = Point2D(0, 1) + >>> p.x + 0 + """ + return self.args[0] + + @property + def y(self): + """ + Returns the Y coordinate of the Point. + + Examples + ======== + + >>> from sympy import Point2D + >>> p = Point2D(0, 1) + >>> p.y + 1 + """ + return self.args[1] + +class Point3D(Point): + """A point in a 3-dimensional Euclidean space. + + Parameters + ========== + + coords + A sequence of 3 coordinate values. + + Attributes + ========== + + x + y + z + length + + Raises + ====== + + TypeError + When trying to add or subtract points with different dimensions. + When `intersection` is called with object other than a Point. + + Examples + ======== + + >>> from sympy import Point3D + >>> from sympy.abc import x + >>> Point3D(1, 2, 3) + Point3D(1, 2, 3) + >>> Point3D([1, 2, 3]) + Point3D(1, 2, 3) + >>> Point3D(0, x, 3) + Point3D(0, x, 3) + + Floats are automatically converted to Rational unless the + evaluate flag is False: + + >>> Point3D(0.5, 0.25, 2) + Point3D(1/2, 1/4, 2) + >>> Point3D(0.5, 0.25, 3, evaluate=False) + Point3D(0.5, 0.25, 3) + + """ + + _ambient_dimension = 3 + + def __new__(cls, *args, _nocheck=False, **kwargs): + if not _nocheck: + kwargs['dim'] = 3 + args = Point(*args, **kwargs) + return GeometryEntity.__new__(cls, *args) + + def __contains__(self, item): + return item == self + + @staticmethod + def are_collinear(*points): + """Is a sequence of points collinear? + + Test whether or not a set of points are collinear. Returns True if + the set of points are collinear, or False otherwise. + + Parameters + ========== + + points : sequence of Point + + Returns + ======= + + are_collinear : boolean + + See Also + ======== + + sympy.geometry.line.Line3D + + Examples + ======== + + >>> from sympy import Point3D + >>> from sympy.abc import x + >>> p1, p2 = Point3D(0, 0, 0), Point3D(1, 1, 1) + >>> p3, p4, p5 = Point3D(2, 2, 2), Point3D(x, x, x), Point3D(1, 2, 6) + >>> Point3D.are_collinear(p1, p2, p3, p4) + True + >>> Point3D.are_collinear(p1, p2, p3, p5) + False + """ + return Point.is_collinear(*points) + + def direction_cosine(self, point): + """ + Gives the direction cosine between 2 points + + Parameters + ========== + + p : Point3D + + Returns + ======= + + list + + Examples + ======== + + >>> from sympy import Point3D + >>> p1 = Point3D(1, 2, 3) + >>> p1.direction_cosine(Point3D(2, 3, 5)) + [sqrt(6)/6, sqrt(6)/6, sqrt(6)/3] + """ + a = self.direction_ratio(point) + b = sqrt(Add(*(i**2 for i in a))) + return [(point.x - self.x) / b,(point.y - self.y) / b, + (point.z - self.z) / b] + + def direction_ratio(self, point): + """ + Gives the direction ratio between 2 points + + Parameters + ========== + + p : Point3D + + Returns + ======= + + list + + Examples + ======== + + >>> from sympy import Point3D + >>> p1 = Point3D(1, 2, 3) + >>> p1.direction_ratio(Point3D(2, 3, 5)) + [1, 1, 2] + """ + return [(point.x - self.x),(point.y - self.y),(point.z - self.z)] + + def intersection(self, other): + """The intersection between this point and another GeometryEntity. + + Parameters + ========== + + other : GeometryEntity or sequence of coordinates + + Returns + ======= + + intersection : list of Points + + Notes + ===== + + The return value will either be an empty list if there is no + intersection, otherwise it will contain this point. + + Examples + ======== + + >>> from sympy import Point3D + >>> p1, p2, p3 = Point3D(0, 0, 0), Point3D(1, 1, 1), Point3D(0, 0, 0) + >>> p1.intersection(p2) + [] + >>> p1.intersection(p3) + [Point3D(0, 0, 0)] + + """ + if not isinstance(other, GeometryEntity): + other = Point(other, dim=3) + if isinstance(other, Point3D): + if self == other: + return [self] + return [] + return other.intersection(self) + + def scale(self, x=1, y=1, z=1, pt=None): + """Scale the coordinates of the Point by multiplying by + ``x`` and ``y`` after subtracting ``pt`` -- default is (0, 0) -- + and then adding ``pt`` back again (i.e. ``pt`` is the point of + reference for the scaling). + + See Also + ======== + + translate + + Examples + ======== + + >>> from sympy import Point3D + >>> t = Point3D(1, 1, 1) + >>> t.scale(2) + Point3D(2, 1, 1) + >>> t.scale(2, 2) + Point3D(2, 2, 1) + + """ + if pt: + pt = Point3D(pt) + return self.translate(*(-pt).args).scale(x, y, z).translate(*pt.args) + return Point3D(self.x*x, self.y*y, self.z*z) + + def transform(self, matrix): + """Return the point after applying the transformation described + by the 4x4 Matrix, ``matrix``. + + See Also + ======== + sympy.geometry.point.Point3D.scale + sympy.geometry.point.Point3D.translate + """ + if not (matrix.is_Matrix and matrix.shape == (4, 4)): + raise ValueError("matrix must be a 4x4 matrix") + x, y, z = self.args + m = Transpose(matrix) + return Point3D(*(Matrix(1, 4, [x, y, z, 1])*m).tolist()[0][:3]) + + def translate(self, x=0, y=0, z=0): + """Shift the Point by adding x and y to the coordinates of the Point. + + See Also + ======== + + scale + + Examples + ======== + + >>> from sympy import Point3D + >>> t = Point3D(0, 1, 1) + >>> t.translate(2) + Point3D(2, 1, 1) + >>> t.translate(2, 2) + Point3D(2, 3, 1) + >>> t + Point3D(2, 2, 2) + Point3D(2, 3, 3) + + """ + return Point3D(self.x + x, self.y + y, self.z + z) + + @property + def coordinates(self): + """ + Returns the three coordinates of the Point. + + Examples + ======== + + >>> from sympy import Point3D + >>> p = Point3D(0, 1, 2) + >>> p.coordinates + (0, 1, 2) + """ + return self.args + + @property + def x(self): + """ + Returns the X coordinate of the Point. + + Examples + ======== + + >>> from sympy import Point3D + >>> p = Point3D(0, 1, 3) + >>> p.x + 0 + """ + return self.args[0] + + @property + def y(self): + """ + Returns the Y coordinate of the Point. + + Examples + ======== + + >>> from sympy import Point3D + >>> p = Point3D(0, 1, 2) + >>> p.y + 1 + """ + return self.args[1] + + @property + def z(self): + """ + Returns the Z coordinate of the Point. + + Examples + ======== + + >>> from sympy import Point3D + >>> p = Point3D(0, 1, 1) + >>> p.z + 1 + """ + return self.args[2] diff --git a/.venv/lib/python3.13/site-packages/sympy/geometry/polygon.py b/.venv/lib/python3.13/site-packages/sympy/geometry/polygon.py new file mode 100644 index 0000000000000000000000000000000000000000..63031183438e2d228f881fd82e1b0ecca04ec534 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/geometry/polygon.py @@ -0,0 +1,2891 @@ +from sympy.core import Expr, S, oo, pi, sympify +from sympy.core.evalf import N +from sympy.core.sorting import default_sort_key, ordered +from sympy.core.symbol import _symbol, Dummy, Symbol +from sympy.functions.elementary.complexes import sign +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import cos, sin, tan +from .ellipse import Circle +from .entity import GeometryEntity, GeometrySet +from .exceptions import GeometryError +from .line import Line, Segment, Ray +from .point import Point +from sympy.logic import And +from sympy.matrices import Matrix +from sympy.simplify.simplify import simplify +from sympy.solvers.solvers import solve +from sympy.utilities.iterables import has_dups, has_variety, uniq, rotate_left, least_rotation +from sympy.utilities.misc import as_int, func_name + +from mpmath.libmp.libmpf import prec_to_dps + +import warnings + + +x, y, T = [Dummy('polygon_dummy', real=True) for i in range(3)] + + +class Polygon(GeometrySet): + """A two-dimensional polygon. + + A simple polygon in space. Can be constructed from a sequence of points + or from a center, radius, number of sides and rotation angle. + + Parameters + ========== + + vertices + A sequence of points. + + n : int, optional + If $> 0$, an n-sided RegularPolygon is created. + Default value is $0$. + + Attributes + ========== + + area + angles + perimeter + vertices + centroid + sides + + Raises + ====== + + GeometryError + If all parameters are not Points. + + See Also + ======== + + sympy.geometry.point.Point, sympy.geometry.line.Segment, Triangle + + Notes + ===== + + Polygons are treated as closed paths rather than 2D areas so + some calculations can be be negative or positive (e.g., area) + based on the orientation of the points. + + Any consecutive identical points are reduced to a single point + and any points collinear and between two points will be removed + unless they are needed to define an explicit intersection (see examples). + + A Triangle, Segment or Point will be returned when there are 3 or + fewer points provided. + + Examples + ======== + + >>> from sympy import Polygon, pi + >>> p1, p2, p3, p4, p5 = [(0, 0), (1, 0), (5, 1), (0, 1), (3, 0)] + >>> Polygon(p1, p2, p3, p4) + Polygon(Point2D(0, 0), Point2D(1, 0), Point2D(5, 1), Point2D(0, 1)) + >>> Polygon(p1, p2) + Segment2D(Point2D(0, 0), Point2D(1, 0)) + >>> Polygon(p1, p2, p5) + Segment2D(Point2D(0, 0), Point2D(3, 0)) + + The area of a polygon is calculated as positive when vertices are + traversed in a ccw direction. When the sides of a polygon cross the + area will have positive and negative contributions. The following + defines a Z shape where the bottom right connects back to the top + left. + + >>> Polygon((0, 2), (2, 2), (0, 0), (2, 0)).area + 0 + + When the keyword `n` is used to define the number of sides of the + Polygon then a RegularPolygon is created and the other arguments are + interpreted as center, radius and rotation. The unrotated RegularPolygon + will always have a vertex at Point(r, 0) where `r` is the radius of the + circle that circumscribes the RegularPolygon. Its method `spin` can be + used to increment that angle. + + >>> p = Polygon((0,0), 1, n=3) + >>> p + RegularPolygon(Point2D(0, 0), 1, 3, 0) + >>> p.vertices[0] + Point2D(1, 0) + >>> p.args[0] + Point2D(0, 0) + >>> p.spin(pi/2) + >>> p.vertices[0] + Point2D(0, 1) + + """ + + __slots__ = () + + def __new__(cls, *args, n = 0, **kwargs): + if n: + args = list(args) + # return a virtual polygon with n sides + if len(args) == 2: # center, radius + args.append(n) + elif len(args) == 3: # center, radius, rotation + args.insert(2, n) + return RegularPolygon(*args, **kwargs) + + vertices = [Point(a, dim=2, **kwargs) for a in args] + + # remove consecutive duplicates + nodup = [] + for p in vertices: + if nodup and p == nodup[-1]: + continue + nodup.append(p) + if len(nodup) > 1 and nodup[-1] == nodup[0]: + nodup.pop() # last point was same as first + + # remove collinear points + i = -3 + while i < len(nodup) - 3 and len(nodup) > 2: + a, b, c = nodup[i], nodup[i + 1], nodup[i + 2] + if Point.is_collinear(a, b, c): + nodup.pop(i + 1) + if a == c: + nodup.pop(i) + else: + i += 1 + + vertices = list(nodup) + + if len(vertices) > 3: + return GeometryEntity.__new__(cls, *vertices, **kwargs) + elif len(vertices) == 3: + return Triangle(*vertices, **kwargs) + elif len(vertices) == 2: + return Segment(*vertices, **kwargs) + else: + return Point(*vertices, **kwargs) + + @property + def area(self): + """ + The area of the polygon. + + Notes + ===== + + The area calculation can be positive or negative based on the + orientation of the points. If any side of the polygon crosses + any other side, there will be areas having opposite signs. + + See Also + ======== + + sympy.geometry.ellipse.Ellipse.area + + Examples + ======== + + >>> from sympy import Point, Polygon + >>> p1, p2, p3, p4 = map(Point, [(0, 0), (1, 0), (5, 1), (0, 1)]) + >>> poly = Polygon(p1, p2, p3, p4) + >>> poly.area + 3 + + In the Z shaped polygon (with the lower right connecting back + to the upper left) the areas cancel out: + + >>> Z = Polygon((0, 1), (1, 1), (0, 0), (1, 0)) + >>> Z.area + 0 + + In the M shaped polygon, areas do not cancel because no side + crosses any other (though there is a point of contact). + + >>> M = Polygon((0, 0), (0, 1), (2, 0), (3, 1), (3, 0)) + >>> M.area + -3/2 + + """ + area = 0 + args = self.args + for i in range(len(args)): + x1, y1 = args[i - 1].args + x2, y2 = args[i].args + area += x1*y2 - x2*y1 + return simplify(area) / 2 + + @staticmethod + def _is_clockwise(a, b, c): + """Return True/False for cw/ccw orientation. + + Examples + ======== + + >>> from sympy import Point, Polygon + >>> a, b, c = [Point(i) for i in [(0, 0), (1, 1), (1, 0)]] + >>> Polygon._is_clockwise(a, b, c) + True + >>> Polygon._is_clockwise(a, c, b) + False + """ + ba = b - a + ca = c - a + t_area = simplify(ba.x*ca.y - ca.x*ba.y) + res = t_area.is_nonpositive + if res is None: + raise ValueError("Can't determine orientation") + return res + + @property + def angles(self): + """The internal angle at each vertex. + + Returns + ======= + + angles : dict + A dictionary where each key is a vertex and each value is the + internal angle at that vertex. The vertices are represented as + Points. + + See Also + ======== + + sympy.geometry.point.Point, sympy.geometry.line.LinearEntity.angle_between + + Examples + ======== + + >>> from sympy import Point, Polygon + >>> p1, p2, p3, p4 = map(Point, [(0, 0), (1, 0), (5, 1), (0, 1)]) + >>> poly = Polygon(p1, p2, p3, p4) + >>> poly.angles[p1] + pi/2 + >>> poly.angles[p2] + acos(-4*sqrt(17)/17) + + """ + + args = self.vertices + n = len(args) + ret = {} + for i in range(n): + a, b, c = args[i - 2], args[i - 1], args[i] + reflex_ang = Ray(b, a).angle_between(Ray(b, c)) + if self._is_clockwise(a, b, c): + ret[b] = 2*S.Pi - reflex_ang + else: + ret[b] = reflex_ang + + # internal sum should be pi*(n - 2), not pi*(n+2) + # so if ratio is (n+2)/(n-2) > 1 it is wrong + wrong = ((sum(ret.values())/S.Pi-1)/(n - 2) - 1).is_positive + if wrong: + two_pi = 2*S.Pi + for b in ret: + ret[b] = two_pi - ret[b] + elif wrong is None: + raise ValueError("could not determine Polygon orientation.") + return ret + + @property + def ambient_dimension(self): + return self.vertices[0].ambient_dimension + + @property + def perimeter(self): + """The perimeter of the polygon. + + Returns + ======= + + perimeter : number or Basic instance + + See Also + ======== + + sympy.geometry.line.Segment.length + + Examples + ======== + + >>> from sympy import Point, Polygon + >>> p1, p2, p3, p4 = map(Point, [(0, 0), (1, 0), (5, 1), (0, 1)]) + >>> poly = Polygon(p1, p2, p3, p4) + >>> poly.perimeter + sqrt(17) + 7 + """ + p = 0 + args = self.vertices + for i in range(len(args)): + p += args[i - 1].distance(args[i]) + return simplify(p) + + @property + def vertices(self): + """The vertices of the polygon. + + Returns + ======= + + vertices : list of Points + + Notes + ===== + + When iterating over the vertices, it is more efficient to index self + rather than to request the vertices and index them. Only use the + vertices when you want to process all of them at once. This is even + more important with RegularPolygons that calculate each vertex. + + See Also + ======== + + sympy.geometry.point.Point + + Examples + ======== + + >>> from sympy import Point, Polygon + >>> p1, p2, p3, p4 = map(Point, [(0, 0), (1, 0), (5, 1), (0, 1)]) + >>> poly = Polygon(p1, p2, p3, p4) + >>> poly.vertices + [Point2D(0, 0), Point2D(1, 0), Point2D(5, 1), Point2D(0, 1)] + >>> poly.vertices[0] + Point2D(0, 0) + + """ + return list(self.args) + + @property + def centroid(self): + """The centroid of the polygon. + + Returns + ======= + + centroid : Point + + See Also + ======== + + sympy.geometry.point.Point, sympy.geometry.util.centroid + + Examples + ======== + + >>> from sympy import Point, Polygon + >>> p1, p2, p3, p4 = map(Point, [(0, 0), (1, 0), (5, 1), (0, 1)]) + >>> poly = Polygon(p1, p2, p3, p4) + >>> poly.centroid + Point2D(31/18, 11/18) + + """ + A = 1/(6*self.area) + cx, cy = 0, 0 + args = self.args + for i in range(len(args)): + x1, y1 = args[i - 1].args + x2, y2 = args[i].args + v = x1*y2 - x2*y1 + cx += v*(x1 + x2) + cy += v*(y1 + y2) + return Point(simplify(A*cx), simplify(A*cy)) + + + def second_moment_of_area(self, point=None): + """Returns the second moment and product moment of area of a two dimensional polygon. + + Parameters + ========== + + point : Point, two-tuple of sympifyable objects, or None(default=None) + point is the point about which second moment of area is to be found. + If "point=None" it will be calculated about the axis passing through the + centroid of the polygon. + + Returns + ======= + + I_xx, I_yy, I_xy : number or SymPy expression + I_xx, I_yy are second moment of area of a two dimensional polygon. + I_xy is product moment of area of a two dimensional polygon. + + Examples + ======== + + >>> from sympy import Polygon, symbols + >>> a, b = symbols('a, b') + >>> p1, p2, p3, p4, p5 = [(0, 0), (a, 0), (a, b), (0, b), (a/3, b/3)] + >>> rectangle = Polygon(p1, p2, p3, p4) + >>> rectangle.second_moment_of_area() + (a*b**3/12, a**3*b/12, 0) + >>> rectangle.second_moment_of_area(p5) + (a*b**3/9, a**3*b/9, a**2*b**2/36) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Second_moment_of_area + + """ + + I_xx, I_yy, I_xy = 0, 0, 0 + args = self.vertices + for i in range(len(args)): + x1, y1 = args[i-1].args + x2, y2 = args[i].args + v = x1*y2 - x2*y1 + I_xx += (y1**2 + y1*y2 + y2**2)*v + I_yy += (x1**2 + x1*x2 + x2**2)*v + I_xy += (x1*y2 + 2*x1*y1 + 2*x2*y2 + x2*y1)*v + A = self.area + c_x = self.centroid[0] + c_y = self.centroid[1] + # parallel axis theorem + I_xx_c = (I_xx/12) - (A*(c_y**2)) + I_yy_c = (I_yy/12) - (A*(c_x**2)) + I_xy_c = (I_xy/24) - (A*(c_x*c_y)) + if point is None: + return I_xx_c, I_yy_c, I_xy_c + + I_xx = (I_xx_c + A*((point[1]-c_y)**2)) + I_yy = (I_yy_c + A*((point[0]-c_x)**2)) + I_xy = (I_xy_c + A*((point[0]-c_x)*(point[1]-c_y))) + + return I_xx, I_yy, I_xy + + + def first_moment_of_area(self, point=None): + """ + Returns the first moment of area of a two-dimensional polygon with + respect to a certain point of interest. + + First moment of area is a measure of the distribution of the area + of a polygon in relation to an axis. The first moment of area of + the entire polygon about its own centroid is always zero. Therefore, + here it is calculated for an area, above or below a certain point + of interest, that makes up a smaller portion of the polygon. This + area is bounded by the point of interest and the extreme end + (top or bottom) of the polygon. The first moment for this area is + is then determined about the centroidal axis of the initial polygon. + + References + ========== + + .. [1] https://skyciv.com/docs/tutorials/section-tutorials/calculating-the-statical-or-first-moment-of-area-of-beam-sections/?cc=BMD + .. [2] https://mechanicalc.com/reference/cross-sections + + Parameters + ========== + + point: Point, two-tuple of sympifyable objects, or None (default=None) + point is the point above or below which the area of interest lies + If ``point=None`` then the centroid acts as the point of interest. + + Returns + ======= + + Q_x, Q_y: number or SymPy expressions + Q_x is the first moment of area about the x-axis + Q_y is the first moment of area about the y-axis + A negative sign indicates that the section modulus is + determined for a section below (or left of) the centroidal axis + + Examples + ======== + + >>> from sympy import Point, Polygon + >>> a, b = 50, 10 + >>> p1, p2, p3, p4 = [(0, b), (0, 0), (a, 0), (a, b)] + >>> p = Polygon(p1, p2, p3, p4) + >>> p.first_moment_of_area() + (625, 3125) + >>> p.first_moment_of_area(point=Point(30, 7)) + (525, 3000) + """ + if point: + xc, yc = self.centroid + else: + point = self.centroid + xc, yc = point + + h_line = Line(point, slope=0) + v_line = Line(point, slope=S.Infinity) + + h_poly = self.cut_section(h_line) + v_poly = self.cut_section(v_line) + + poly_1 = h_poly[0] if h_poly[0].area <= h_poly[1].area else h_poly[1] + poly_2 = v_poly[0] if v_poly[0].area <= v_poly[1].area else v_poly[1] + + Q_x = (poly_1.centroid.y - yc)*poly_1.area + Q_y = (poly_2.centroid.x - xc)*poly_2.area + + return Q_x, Q_y + + + def polar_second_moment_of_area(self): + """Returns the polar modulus of a two-dimensional polygon + + It is a constituent of the second moment of area, linked through + the perpendicular axis theorem. While the planar second moment of + area describes an object's resistance to deflection (bending) when + subjected to a force applied to a plane parallel to the central + axis, the polar second moment of area describes an object's + resistance to deflection when subjected to a moment applied in a + plane perpendicular to the object's central axis (i.e. parallel to + the cross-section) + + Examples + ======== + + >>> from sympy import Polygon, symbols + >>> a, b = symbols('a, b') + >>> rectangle = Polygon((0, 0), (a, 0), (a, b), (0, b)) + >>> rectangle.polar_second_moment_of_area() + a**3*b/12 + a*b**3/12 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Polar_moment_of_inertia + + """ + second_moment = self.second_moment_of_area() + return second_moment[0] + second_moment[1] + + + def section_modulus(self, point=None): + """Returns a tuple with the section modulus of a two-dimensional + polygon. + + Section modulus is a geometric property of a polygon defined as the + ratio of second moment of area to the distance of the extreme end of + the polygon from the centroidal axis. + + Parameters + ========== + + point : Point, two-tuple of sympifyable objects, or None(default=None) + point is the point at which section modulus is to be found. + If "point=None" it will be calculated for the point farthest from the + centroidal axis of the polygon. + + Returns + ======= + + S_x, S_y: numbers or SymPy expressions + S_x is the section modulus with respect to the x-axis + S_y is the section modulus with respect to the y-axis + A negative sign indicates that the section modulus is + determined for a point below the centroidal axis + + Examples + ======== + + >>> from sympy import symbols, Polygon, Point + >>> a, b = symbols('a, b', positive=True) + >>> rectangle = Polygon((0, 0), (a, 0), (a, b), (0, b)) + >>> rectangle.section_modulus() + (a*b**2/6, a**2*b/6) + >>> rectangle.section_modulus(Point(a/4, b/4)) + (-a*b**2/3, -a**2*b/3) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Section_modulus + + """ + x_c, y_c = self.centroid + if point is None: + # taking x and y as maximum distances from centroid + x_min, y_min, x_max, y_max = self.bounds + y = max(y_c - y_min, y_max - y_c) + x = max(x_c - x_min, x_max - x_c) + else: + # taking x and y as distances of the given point from the centroid + y = point.y - y_c + x = point.x - x_c + + second_moment= self.second_moment_of_area() + S_x = second_moment[0]/y + S_y = second_moment[1]/x + + return S_x, S_y + + + @property + def sides(self): + """The directed line segments that form the sides of the polygon. + + Returns + ======= + + sides : list of sides + Each side is a directed Segment. + + See Also + ======== + + sympy.geometry.point.Point, sympy.geometry.line.Segment + + Examples + ======== + + >>> from sympy import Point, Polygon + >>> p1, p2, p3, p4 = map(Point, [(0, 0), (1, 0), (5, 1), (0, 1)]) + >>> poly = Polygon(p1, p2, p3, p4) + >>> poly.sides + [Segment2D(Point2D(0, 0), Point2D(1, 0)), + Segment2D(Point2D(1, 0), Point2D(5, 1)), + Segment2D(Point2D(5, 1), Point2D(0, 1)), Segment2D(Point2D(0, 1), Point2D(0, 0))] + + """ + res = [] + args = self.vertices + for i in range(-len(args), 0): + res.append(Segment(args[i], args[i + 1])) + return res + + @property + def bounds(self): + """Return a tuple (xmin, ymin, xmax, ymax) representing the bounding + rectangle for the geometric figure. + + """ + + verts = self.vertices + xs = [p.x for p in verts] + ys = [p.y for p in verts] + return (min(xs), min(ys), max(xs), max(ys)) + + def is_convex(self): + """Is the polygon convex? + + A polygon is convex if all its interior angles are less than 180 + degrees and there are no intersections between sides. + + Returns + ======= + + is_convex : boolean + True if this polygon is convex, False otherwise. + + See Also + ======== + + sympy.geometry.util.convex_hull + + Examples + ======== + + >>> from sympy import Point, Polygon + >>> p1, p2, p3, p4 = map(Point, [(0, 0), (1, 0), (5, 1), (0, 1)]) + >>> poly = Polygon(p1, p2, p3, p4) + >>> poly.is_convex() + True + + """ + # Determine orientation of points + args = self.vertices + cw = self._is_clockwise(args[-2], args[-1], args[0]) + for i in range(1, len(args)): + if cw ^ self._is_clockwise(args[i - 2], args[i - 1], args[i]): + return False + # check for intersecting sides + sides = self.sides + for i, si in enumerate(sides): + pts = si.args + # exclude the sides connected to si + for j in range(1 if i == len(sides) - 1 else 0, i - 1): + sj = sides[j] + if sj.p1 not in pts and sj.p2 not in pts: + hit = si.intersection(sj) + if hit: + return False + return True + + def encloses_point(self, p): + """ + Return True if p is enclosed by (is inside of) self. + + Notes + ===== + + Being on the border of self is considered False. + + Parameters + ========== + + p : Point + + Returns + ======= + + encloses_point : True, False or None + + See Also + ======== + + sympy.geometry.point.Point, sympy.geometry.ellipse.Ellipse.encloses_point + + Examples + ======== + + >>> from sympy import Polygon, Point + >>> p = Polygon((0, 0), (4, 0), (4, 4)) + >>> p.encloses_point(Point(2, 1)) + True + >>> p.encloses_point(Point(2, 2)) + False + >>> p.encloses_point(Point(5, 5)) + False + + References + ========== + + .. [1] https://paulbourke.net/geometry/polygonmesh/#insidepoly + + """ + p = Point(p, dim=2) + if p in self.vertices or any(p in s for s in self.sides): + return False + + # move to p, checking that the result is numeric + lit = [] + for v in self.vertices: + lit.append(v - p) # the difference is simplified + if lit[-1].free_symbols: + return None + + poly = Polygon(*lit) + + # polygon closure is assumed in the following test but Polygon removes duplicate pts so + # the last point has to be added so all sides are computed. Using Polygon.sides is + # not good since Segments are unordered. + args = poly.args + indices = list(range(-len(args), 1)) + + if poly.is_convex(): + orientation = None + for i in indices: + a = args[i] + b = args[i + 1] + test = ((-a.y)*(b.x - a.x) - (-a.x)*(b.y - a.y)).is_negative + if orientation is None: + orientation = test + elif test is not orientation: + return False + return True + + hit_odd = False + p1x, p1y = args[0].args + for i in indices[1:]: + p2x, p2y = args[i].args + if 0 > min(p1y, p2y): + if 0 <= max(p1y, p2y): + if 0 <= max(p1x, p2x): + if p1y != p2y: + xinters = (-p1y)*(p2x - p1x)/(p2y - p1y) + p1x + if p1x == p2x or 0 <= xinters: + hit_odd = not hit_odd + p1x, p1y = p2x, p2y + return hit_odd + + def arbitrary_point(self, parameter='t'): + """A parameterized point on the polygon. + + The parameter, varying from 0 to 1, assigns points to the position on + the perimeter that is that fraction of the total perimeter. So the + point evaluated at t=1/2 would return the point from the first vertex + that is 1/2 way around the polygon. + + Parameters + ========== + + parameter : str, optional + Default value is 't'. + + Returns + ======= + + arbitrary_point : Point + + Raises + ====== + + ValueError + When `parameter` already appears in the Polygon's definition. + + See Also + ======== + + sympy.geometry.point.Point + + Examples + ======== + + >>> from sympy import Polygon, Symbol + >>> t = Symbol('t', real=True) + >>> tri = Polygon((0, 0), (1, 0), (1, 1)) + >>> p = tri.arbitrary_point('t') + >>> perimeter = tri.perimeter + >>> s1, s2 = [s.length for s in tri.sides[:2]] + >>> p.subs(t, (s1 + s2/2)/perimeter) + Point2D(1, 1/2) + + """ + t = _symbol(parameter, real=True) + if t.name in (f.name for f in self.free_symbols): + raise ValueError('Symbol %s already appears in object and cannot be used as a parameter.' % t.name) + sides = [] + perimeter = self.perimeter + perim_fraction_start = 0 + for s in self.sides: + side_perim_fraction = s.length/perimeter + perim_fraction_end = perim_fraction_start + side_perim_fraction + pt = s.arbitrary_point(parameter).subs( + t, (t - perim_fraction_start)/side_perim_fraction) + sides.append( + (pt, (And(perim_fraction_start <= t, t < perim_fraction_end)))) + perim_fraction_start = perim_fraction_end + return Piecewise(*sides) + + def parameter_value(self, other, t): + if not isinstance(other,GeometryEntity): + other = Point(other, dim=self.ambient_dimension) + if not isinstance(other,Point): + raise ValueError("other must be a point") + if other.free_symbols: + raise NotImplementedError('non-numeric coordinates') + unknown = False + p = self.arbitrary_point(T) + for pt, cond in p.args: + sol = solve(pt - other, T, dict=True) + if not sol: + continue + value = sol[0][T] + if simplify(cond.subs(T, value)) == True: + return {t: value} + unknown = True + if unknown: + raise ValueError("Given point may not be on %s" % func_name(self)) + raise ValueError("Given point is not on %s" % func_name(self)) + + def plot_interval(self, parameter='t'): + """The plot interval for the default geometric plot of the polygon. + + Parameters + ========== + + parameter : str, optional + Default value is 't'. + + Returns + ======= + + plot_interval : list (plot interval) + [parameter, lower_bound, upper_bound] + + Examples + ======== + + >>> from sympy import Polygon + >>> p = Polygon((0, 0), (1, 0), (1, 1)) + >>> p.plot_interval() + [t, 0, 1] + + """ + t = Symbol(parameter, real=True) + return [t, 0, 1] + + def intersection(self, o): + """The intersection of polygon and geometry entity. + + The intersection may be empty and can contain individual Points and + complete Line Segments. + + Parameters + ========== + + other: GeometryEntity + + Returns + ======= + + intersection : list + The list of Segments and Points + + See Also + ======== + + sympy.geometry.point.Point, sympy.geometry.line.Segment + + Examples + ======== + + >>> from sympy import Point, Polygon, Line + >>> p1, p2, p3, p4 = map(Point, [(0, 0), (1, 0), (5, 1), (0, 1)]) + >>> poly1 = Polygon(p1, p2, p3, p4) + >>> p5, p6, p7 = map(Point, [(3, 2), (1, -1), (0, 2)]) + >>> poly2 = Polygon(p5, p6, p7) + >>> poly1.intersection(poly2) + [Point2D(1/3, 1), Point2D(2/3, 0), Point2D(9/5, 1/5), Point2D(7/3, 1)] + >>> poly1.intersection(Line(p1, p2)) + [Segment2D(Point2D(0, 0), Point2D(1, 0))] + >>> poly1.intersection(p1) + [Point2D(0, 0)] + """ + intersection_result = [] + k = o.sides if isinstance(o, Polygon) else [o] + for side in self.sides: + for side1 in k: + intersection_result.extend(side.intersection(side1)) + + intersection_result = list(uniq(intersection_result)) + points = [entity for entity in intersection_result if isinstance(entity, Point)] + segments = [entity for entity in intersection_result if isinstance(entity, Segment)] + + if points and segments: + points_in_segments = list(uniq([point for point in points for segment in segments if point in segment])) + if points_in_segments: + for i in points_in_segments: + points.remove(i) + return list(ordered(segments + points)) + else: + return list(ordered(intersection_result)) + + + def cut_section(self, line): + """ + Returns a tuple of two polygon segments that lie above and below + the intersecting line respectively. + + Parameters + ========== + + line: Line object of geometry module + line which cuts the Polygon. The part of the Polygon that lies + above and below this line is returned. + + Returns + ======= + + upper_polygon, lower_polygon: Polygon objects or None + upper_polygon is the polygon that lies above the given line. + lower_polygon is the polygon that lies below the given line. + upper_polygon and lower polygon are ``None`` when no polygon + exists above the line or below the line. + + Raises + ====== + + ValueError: When the line does not intersect the polygon + + Examples + ======== + + >>> from sympy import Polygon, Line + >>> a, b = 20, 10 + >>> p1, p2, p3, p4 = [(0, b), (0, 0), (a, 0), (a, b)] + >>> rectangle = Polygon(p1, p2, p3, p4) + >>> t = rectangle.cut_section(Line((0, 5), slope=0)) + >>> t + (Polygon(Point2D(0, 10), Point2D(0, 5), Point2D(20, 5), Point2D(20, 10)), + Polygon(Point2D(0, 5), Point2D(0, 0), Point2D(20, 0), Point2D(20, 5))) + >>> upper_segment, lower_segment = t + >>> upper_segment.area + 100 + >>> upper_segment.centroid + Point2D(10, 15/2) + >>> lower_segment.centroid + Point2D(10, 5/2) + + References + ========== + + .. [1] https://github.com/sympy/sympy/wiki/A-method-to-return-a-cut-section-of-any-polygon-geometry + + """ + intersection_points = self.intersection(line) + if not intersection_points: + raise ValueError("This line does not intersect the polygon") + + points = list(self.vertices) + points.append(points[0]) + + eq = line.equation(x, y) + + # considering equation of line to be `ax +by + c` + a = eq.coeff(x) + b = eq.coeff(y) + + upper_vertices = [] + lower_vertices = [] + # prev is true when previous point is above the line + prev = True + prev_point = None + for point in points: + # when coefficient of y is 0, right side of the line is + # considered + compare = eq.subs({x: point.x, y: point.y})/b if b \ + else eq.subs(x, point.x)/a + + # if point lies above line + if compare > 0: + if not prev: + # if previous point lies below the line, the intersection + # point of the polygon edge and the line has to be included + edge = Line(point, prev_point) + new_point = edge.intersection(line) + upper_vertices.append(new_point[0]) + lower_vertices.append(new_point[0]) + + upper_vertices.append(point) + prev = True + else: + if prev and prev_point: + edge = Line(point, prev_point) + new_point = edge.intersection(line) + upper_vertices.append(new_point[0]) + lower_vertices.append(new_point[0]) + lower_vertices.append(point) + prev = False + prev_point = point + + upper_polygon, lower_polygon = None, None + if upper_vertices and isinstance(Polygon(*upper_vertices), Polygon): + upper_polygon = Polygon(*upper_vertices) + if lower_vertices and isinstance(Polygon(*lower_vertices), Polygon): + lower_polygon = Polygon(*lower_vertices) + + return upper_polygon, lower_polygon + + + def distance(self, o): + """ + Returns the shortest distance between self and o. + + If o is a point, then self does not need to be convex. + If o is another polygon self and o must be convex. + + Examples + ======== + + >>> from sympy import Point, Polygon, RegularPolygon + >>> p1, p2 = map(Point, [(0, 0), (7, 5)]) + >>> poly = Polygon(*RegularPolygon(p1, 1, 3).vertices) + >>> poly.distance(p2) + sqrt(61) + """ + if isinstance(o, Point): + dist = oo + for side in self.sides: + current = side.distance(o) + if current == 0: + return S.Zero + elif current < dist: + dist = current + return dist + elif isinstance(o, Polygon) and self.is_convex() and o.is_convex(): + return self._do_poly_distance(o) + raise NotImplementedError() + + def _do_poly_distance(self, e2): + """ + Calculates the least distance between the exteriors of two + convex polygons e1 and e2. Does not check for the convexity + of the polygons as this is checked by Polygon.distance. + + Notes + ===== + + - Prints a warning if the two polygons possibly intersect as the return + value will not be valid in such a case. For a more through test of + intersection use intersection(). + + See Also + ======== + + sympy.geometry.point.Point.distance + + Examples + ======== + + >>> from sympy import Point, Polygon + >>> square = Polygon(Point(0, 0), Point(0, 1), Point(1, 1), Point(1, 0)) + >>> triangle = Polygon(Point(1, 2), Point(2, 2), Point(2, 1)) + >>> square._do_poly_distance(triangle) + sqrt(2)/2 + + Description of method used + ========================== + + Method: + [1] https://web.archive.org/web/20150509035744/http://cgm.cs.mcgill.ca/~orm/mind2p.html + Uses rotating calipers: + [2] https://en.wikipedia.org/wiki/Rotating_calipers + and antipodal points: + [3] https://en.wikipedia.org/wiki/Antipodal_point + """ + e1 = self + + '''Tests for a possible intersection between the polygons and outputs a warning''' + e1_center = e1.centroid + e2_center = e2.centroid + e1_max_radius = S.Zero + e2_max_radius = S.Zero + for vertex in e1.vertices: + r = Point.distance(e1_center, vertex) + if e1_max_radius < r: + e1_max_radius = r + for vertex in e2.vertices: + r = Point.distance(e2_center, vertex) + if e2_max_radius < r: + e2_max_radius = r + center_dist = Point.distance(e1_center, e2_center) + if center_dist <= e1_max_radius + e2_max_radius: + warnings.warn("Polygons may intersect producing erroneous output", + stacklevel=3) + + ''' + Find the upper rightmost vertex of e1 and the lowest leftmost vertex of e2 + ''' + e1_ymax = Point(0, -oo) + e2_ymin = Point(0, oo) + + for vertex in e1.vertices: + if vertex.y > e1_ymax.y or (vertex.y == e1_ymax.y and vertex.x > e1_ymax.x): + e1_ymax = vertex + for vertex in e2.vertices: + if vertex.y < e2_ymin.y or (vertex.y == e2_ymin.y and vertex.x < e2_ymin.x): + e2_ymin = vertex + min_dist = Point.distance(e1_ymax, e2_ymin) + + ''' + Produce a dictionary with vertices of e1 as the keys and, for each vertex, the points + to which the vertex is connected as its value. The same is then done for e2. + ''' + e1_connections = {} + e2_connections = {} + + for side in e1.sides: + if side.p1 in e1_connections: + e1_connections[side.p1].append(side.p2) + else: + e1_connections[side.p1] = [side.p2] + + if side.p2 in e1_connections: + e1_connections[side.p2].append(side.p1) + else: + e1_connections[side.p2] = [side.p1] + + for side in e2.sides: + if side.p1 in e2_connections: + e2_connections[side.p1].append(side.p2) + else: + e2_connections[side.p1] = [side.p2] + + if side.p2 in e2_connections: + e2_connections[side.p2].append(side.p1) + else: + e2_connections[side.p2] = [side.p1] + + e1_current = e1_ymax + e2_current = e2_ymin + support_line = Line(Point(S.Zero, S.Zero), Point(S.One, S.Zero)) + + ''' + Determine which point in e1 and e2 will be selected after e2_ymin and e1_ymax, + this information combined with the above produced dictionaries determines the + path that will be taken around the polygons + ''' + point1 = e1_connections[e1_ymax][0] + point2 = e1_connections[e1_ymax][1] + angle1 = support_line.angle_between(Line(e1_ymax, point1)) + angle2 = support_line.angle_between(Line(e1_ymax, point2)) + if angle1 < angle2: + e1_next = point1 + elif angle2 < angle1: + e1_next = point2 + elif Point.distance(e1_ymax, point1) > Point.distance(e1_ymax, point2): + e1_next = point2 + else: + e1_next = point1 + + point1 = e2_connections[e2_ymin][0] + point2 = e2_connections[e2_ymin][1] + angle1 = support_line.angle_between(Line(e2_ymin, point1)) + angle2 = support_line.angle_between(Line(e2_ymin, point2)) + if angle1 > angle2: + e2_next = point1 + elif angle2 > angle1: + e2_next = point2 + elif Point.distance(e2_ymin, point1) > Point.distance(e2_ymin, point2): + e2_next = point2 + else: + e2_next = point1 + + ''' + Loop which determines the distance between anti-podal pairs and updates the + minimum distance accordingly. It repeats until it reaches the starting position. + ''' + while True: + e1_angle = support_line.angle_between(Line(e1_current, e1_next)) + e2_angle = pi - support_line.angle_between(Line( + e2_current, e2_next)) + + if (e1_angle < e2_angle) is True: + support_line = Line(e1_current, e1_next) + e1_segment = Segment(e1_current, e1_next) + min_dist_current = e1_segment.distance(e2_current) + + if min_dist_current.evalf() < min_dist.evalf(): + min_dist = min_dist_current + + if e1_connections[e1_next][0] != e1_current: + e1_current = e1_next + e1_next = e1_connections[e1_next][0] + else: + e1_current = e1_next + e1_next = e1_connections[e1_next][1] + elif (e1_angle > e2_angle) is True: + support_line = Line(e2_next, e2_current) + e2_segment = Segment(e2_current, e2_next) + min_dist_current = e2_segment.distance(e1_current) + + if min_dist_current.evalf() < min_dist.evalf(): + min_dist = min_dist_current + + if e2_connections[e2_next][0] != e2_current: + e2_current = e2_next + e2_next = e2_connections[e2_next][0] + else: + e2_current = e2_next + e2_next = e2_connections[e2_next][1] + else: + support_line = Line(e1_current, e1_next) + e1_segment = Segment(e1_current, e1_next) + e2_segment = Segment(e2_current, e2_next) + min1 = e1_segment.distance(e2_next) + min2 = e2_segment.distance(e1_next) + + min_dist_current = min(min1, min2) + if min_dist_current.evalf() < min_dist.evalf(): + min_dist = min_dist_current + + if e1_connections[e1_next][0] != e1_current: + e1_current = e1_next + e1_next = e1_connections[e1_next][0] + else: + e1_current = e1_next + e1_next = e1_connections[e1_next][1] + + if e2_connections[e2_next][0] != e2_current: + e2_current = e2_next + e2_next = e2_connections[e2_next][0] + else: + e2_current = e2_next + e2_next = e2_connections[e2_next][1] + if e1_current == e1_ymax and e2_current == e2_ymin: + break + return min_dist + + def _svg(self, scale_factor=1., fill_color="#66cc99"): + """Returns SVG path element for the Polygon. + + Parameters + ========== + + scale_factor : float + Multiplication factor for the SVG stroke-width. Default is 1. + fill_color : str, optional + Hex string for fill color. Default is "#66cc99". + """ + verts = map(N, self.vertices) + coords = ["{},{}".format(p.x, p.y) for p in verts] + path = "M {} L {} z".format(coords[0], " L ".join(coords[1:])) + return ( + '' + ).format(2. * scale_factor, path, fill_color) + + def _hashable_content(self): + + D = {} + def ref_list(point_list): + kee = {} + for i, p in enumerate(ordered(set(point_list))): + kee[p] = i + D[i] = p + return [kee[p] for p in point_list] + + S1 = ref_list(self.args) + r_nor = rotate_left(S1, least_rotation(S1)) + S2 = ref_list(list(reversed(self.args))) + r_rev = rotate_left(S2, least_rotation(S2)) + if r_nor < r_rev: + r = r_nor + else: + r = r_rev + canonical_args = [ D[order] for order in r ] + return tuple(canonical_args) + + def __contains__(self, o): + """ + Return True if o is contained within the boundary lines of self.altitudes + + Parameters + ========== + + other : GeometryEntity + + Returns + ======= + + contained in : bool + The points (and sides, if applicable) are contained in self. + + See Also + ======== + + sympy.geometry.entity.GeometryEntity.encloses + + Examples + ======== + + >>> from sympy import Line, Segment, Point + >>> p = Point(0, 0) + >>> q = Point(1, 1) + >>> s = Segment(p, q*2) + >>> l = Line(p, q) + >>> p in q + False + >>> p in s + True + >>> q*3 in s + False + >>> s in l + True + + """ + + if isinstance(o, Polygon): + return self == o + elif isinstance(o, Segment): + return any(o in s for s in self.sides) + elif isinstance(o, Point): + if o in self.vertices: + return True + for side in self.sides: + if o in side: + return True + + return False + + def bisectors(p, prec=None): + """Returns angle bisectors of a polygon. If prec is given + then approximate the point defining the ray to that precision. + + The distance between the points defining the bisector ray is 1. + + Examples + ======== + + >>> from sympy import Polygon, Point + >>> p = Polygon(Point(0, 0), Point(2, 0), Point(1, 1), Point(0, 3)) + >>> p.bisectors(2) + {Point2D(0, 0): Ray2D(Point2D(0, 0), Point2D(0.71, 0.71)), + Point2D(0, 3): Ray2D(Point2D(0, 3), Point2D(0.23, 2.0)), + Point2D(1, 1): Ray2D(Point2D(1, 1), Point2D(0.19, 0.42)), + Point2D(2, 0): Ray2D(Point2D(2, 0), Point2D(1.1, 0.38))} + """ + b = {} + pts = list(p.args) + pts.append(pts[0]) # close it + cw = Polygon._is_clockwise(*pts[:3]) + if cw: + pts = list(reversed(pts)) + for v, a in p.angles.items(): + i = pts.index(v) + p1, p2 = Point._normalize_dimension(pts[i], pts[i + 1]) + ray = Ray(p1, p2).rotate(a/2, v) + dir = ray.direction + ray = Ray(ray.p1, ray.p1 + dir/dir.distance((0, 0))) + if prec is not None: + ray = Ray(ray.p1, ray.p2.n(prec)) + b[v] = ray + return b + + +class RegularPolygon(Polygon): + """ + A regular polygon. + + Such a polygon has all internal angles equal and all sides the same length. + + Parameters + ========== + + center : Point + radius : number or Basic instance + The distance from the center to a vertex + n : int + The number of sides + + Attributes + ========== + + vertices + center + radius + rotation + apothem + interior_angle + exterior_angle + circumcircle + incircle + angles + + Raises + ====== + + GeometryError + If the `center` is not a Point, or the `radius` is not a number or Basic + instance, or the number of sides, `n`, is less than three. + + Notes + ===== + + A RegularPolygon can be instantiated with Polygon with the kwarg n. + + Regular polygons are instantiated with a center, radius, number of sides + and a rotation angle. Whereas the arguments of a Polygon are vertices, the + vertices of the RegularPolygon must be obtained with the vertices method. + + See Also + ======== + + sympy.geometry.point.Point, Polygon + + Examples + ======== + + >>> from sympy import RegularPolygon, Point + >>> r = RegularPolygon(Point(0, 0), 5, 3) + >>> r + RegularPolygon(Point2D(0, 0), 5, 3, 0) + >>> r.vertices[0] + Point2D(5, 0) + + """ + + __slots__ = ('_n', '_center', '_radius', '_rot') + + def __new__(self, c, r, n, rot=0, **kwargs): + r, n, rot = map(sympify, (r, n, rot)) + c = Point(c, dim=2, **kwargs) + if not isinstance(r, Expr): + raise GeometryError("r must be an Expr object, not %s" % r) + if n.is_Number: + as_int(n) # let an error raise if necessary + if n < 3: + raise GeometryError("n must be a >= 3, not %s" % n) + + obj = GeometryEntity.__new__(self, c, r, n, **kwargs) + obj._n = n + obj._center = c + obj._radius = r + obj._rot = rot % (2*S.Pi/n) if rot.is_number else rot + return obj + + def _eval_evalf(self, prec=15, **options): + c, r, n, a = self.args + dps = prec_to_dps(prec) + c, r, a = [i.evalf(n=dps, **options) for i in (c, r, a)] + return self.func(c, r, n, a) + + @property + def args(self): + """ + Returns the center point, the radius, + the number of sides, and the orientation angle. + + Examples + ======== + + >>> from sympy import RegularPolygon, Point + >>> r = RegularPolygon(Point(0, 0), 5, 3) + >>> r.args + (Point2D(0, 0), 5, 3, 0) + """ + return self._center, self._radius, self._n, self._rot + + def __str__(self): + return 'RegularPolygon(%s, %s, %s, %s)' % tuple(self.args) + + def __repr__(self): + return 'RegularPolygon(%s, %s, %s, %s)' % tuple(self.args) + + @property + def area(self): + """Returns the area. + + Examples + ======== + + >>> from sympy import RegularPolygon + >>> square = RegularPolygon((0, 0), 1, 4) + >>> square.area + 2 + >>> _ == square.length**2 + True + """ + c, r, n, rot = self.args + return sign(r)*n*self.length**2/(4*tan(pi/n)) + + @property + def length(self): + """Returns the length of the sides. + + The half-length of the side and the apothem form two legs + of a right triangle whose hypotenuse is the radius of the + regular polygon. + + Examples + ======== + + >>> from sympy import RegularPolygon + >>> from sympy import sqrt + >>> s = square_in_unit_circle = RegularPolygon((0, 0), 1, 4) + >>> s.length + sqrt(2) + >>> sqrt((_/2)**2 + s.apothem**2) == s.radius + True + + """ + return self.radius*2*sin(pi/self._n) + + @property + def center(self): + """The center of the RegularPolygon + + This is also the center of the circumscribing circle. + + Returns + ======= + + center : Point + + See Also + ======== + + sympy.geometry.point.Point, sympy.geometry.ellipse.Ellipse.center + + Examples + ======== + + >>> from sympy import RegularPolygon, Point + >>> rp = RegularPolygon(Point(0, 0), 5, 4) + >>> rp.center + Point2D(0, 0) + """ + return self._center + + centroid = center + + @property + def circumcenter(self): + """ + Alias for center. + + Examples + ======== + + >>> from sympy import RegularPolygon, Point + >>> rp = RegularPolygon(Point(0, 0), 5, 4) + >>> rp.circumcenter + Point2D(0, 0) + """ + return self.center + + @property + def radius(self): + """Radius of the RegularPolygon + + This is also the radius of the circumscribing circle. + + Returns + ======= + + radius : number or instance of Basic + + See Also + ======== + + sympy.geometry.line.Segment.length, sympy.geometry.ellipse.Circle.radius + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy import RegularPolygon, Point + >>> radius = Symbol('r') + >>> rp = RegularPolygon(Point(0, 0), radius, 4) + >>> rp.radius + r + + """ + return self._radius + + @property + def circumradius(self): + """ + Alias for radius. + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy import RegularPolygon, Point + >>> radius = Symbol('r') + >>> rp = RegularPolygon(Point(0, 0), radius, 4) + >>> rp.circumradius + r + """ + return self.radius + + @property + def rotation(self): + """CCW angle by which the RegularPolygon is rotated + + Returns + ======= + + rotation : number or instance of Basic + + Examples + ======== + + >>> from sympy import pi + >>> from sympy.abc import a + >>> from sympy import RegularPolygon, Point + >>> RegularPolygon(Point(0, 0), 3, 4, pi/4).rotation + pi/4 + + Numerical rotation angles are made canonical: + + >>> RegularPolygon(Point(0, 0), 3, 4, a).rotation + a + >>> RegularPolygon(Point(0, 0), 3, 4, pi).rotation + 0 + + """ + return self._rot + + @property + def apothem(self): + """The inradius of the RegularPolygon. + + The apothem/inradius is the radius of the inscribed circle. + + Returns + ======= + + apothem : number or instance of Basic + + See Also + ======== + + sympy.geometry.line.Segment.length, sympy.geometry.ellipse.Circle.radius + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy import RegularPolygon, Point + >>> radius = Symbol('r') + >>> rp = RegularPolygon(Point(0, 0), radius, 4) + >>> rp.apothem + sqrt(2)*r/2 + + """ + return self.radius * cos(S.Pi/self._n) + + @property + def inradius(self): + """ + Alias for apothem. + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy import RegularPolygon, Point + >>> radius = Symbol('r') + >>> rp = RegularPolygon(Point(0, 0), radius, 4) + >>> rp.inradius + sqrt(2)*r/2 + """ + return self.apothem + + @property + def interior_angle(self): + """Measure of the interior angles. + + Returns + ======= + + interior_angle : number + + See Also + ======== + + sympy.geometry.line.LinearEntity.angle_between + + Examples + ======== + + >>> from sympy import RegularPolygon, Point + >>> rp = RegularPolygon(Point(0, 0), 4, 8) + >>> rp.interior_angle + 3*pi/4 + + """ + return (self._n - 2)*S.Pi/self._n + + @property + def exterior_angle(self): + """Measure of the exterior angles. + + Returns + ======= + + exterior_angle : number + + See Also + ======== + + sympy.geometry.line.LinearEntity.angle_between + + Examples + ======== + + >>> from sympy import RegularPolygon, Point + >>> rp = RegularPolygon(Point(0, 0), 4, 8) + >>> rp.exterior_angle + pi/4 + + """ + return 2*S.Pi/self._n + + @property + def circumcircle(self): + """The circumcircle of the RegularPolygon. + + Returns + ======= + + circumcircle : Circle + + See Also + ======== + + circumcenter, sympy.geometry.ellipse.Circle + + Examples + ======== + + >>> from sympy import RegularPolygon, Point + >>> rp = RegularPolygon(Point(0, 0), 4, 8) + >>> rp.circumcircle + Circle(Point2D(0, 0), 4) + + """ + return Circle(self.center, self.radius) + + @property + def incircle(self): + """The incircle of the RegularPolygon. + + Returns + ======= + + incircle : Circle + + See Also + ======== + + inradius, sympy.geometry.ellipse.Circle + + Examples + ======== + + >>> from sympy import RegularPolygon, Point + >>> rp = RegularPolygon(Point(0, 0), 4, 7) + >>> rp.incircle + Circle(Point2D(0, 0), 4*cos(pi/7)) + + """ + return Circle(self.center, self.apothem) + + @property + def angles(self): + """ + Returns a dictionary with keys, the vertices of the Polygon, + and values, the interior angle at each vertex. + + Examples + ======== + + >>> from sympy import RegularPolygon, Point + >>> r = RegularPolygon(Point(0, 0), 5, 3) + >>> r.angles + {Point2D(-5/2, -5*sqrt(3)/2): pi/3, + Point2D(-5/2, 5*sqrt(3)/2): pi/3, + Point2D(5, 0): pi/3} + """ + ret = {} + ang = self.interior_angle + for v in self.vertices: + ret[v] = ang + return ret + + def encloses_point(self, p): + """ + Return True if p is enclosed by (is inside of) self. + + Notes + ===== + + Being on the border of self is considered False. + + The general Polygon.encloses_point method is called only if + a point is not within or beyond the incircle or circumcircle, + respectively. + + Parameters + ========== + + p : Point + + Returns + ======= + + encloses_point : True, False or None + + See Also + ======== + + sympy.geometry.ellipse.Ellipse.encloses_point + + Examples + ======== + + >>> from sympy import RegularPolygon, S, Point, Symbol + >>> p = RegularPolygon((0, 0), 3, 4) + >>> p.encloses_point(Point(0, 0)) + True + >>> r, R = p.inradius, p.circumradius + >>> p.encloses_point(Point((r + R)/2, 0)) + True + >>> p.encloses_point(Point(R/2, R/2 + (R - r)/10)) + False + >>> t = Symbol('t', real=True) + >>> p.encloses_point(p.arbitrary_point().subs(t, S.Half)) + False + >>> p.encloses_point(Point(5, 5)) + False + + """ + + c = self.center + d = Segment(c, p).length + if d >= self.radius: + return False + elif d < self.inradius: + return True + else: + # now enumerate the RegularPolygon like a general polygon. + return Polygon.encloses_point(self, p) + + def spin(self, angle): + """Increment *in place* the virtual Polygon's rotation by ccw angle. + + See also: rotate method which moves the center. + + >>> from sympy import Polygon, Point, pi + >>> r = Polygon(Point(0,0), 1, n=3) + >>> r.vertices[0] + Point2D(1, 0) + >>> r.spin(pi/6) + >>> r.vertices[0] + Point2D(sqrt(3)/2, 1/2) + + See Also + ======== + + rotation + rotate : Creates a copy of the RegularPolygon rotated about a Point + + """ + self._rot += angle + + def rotate(self, angle, pt=None): + """Override GeometryEntity.rotate to first rotate the RegularPolygon + about its center. + + >>> from sympy import Point, RegularPolygon, pi + >>> t = RegularPolygon(Point(1, 0), 1, 3) + >>> t.vertices[0] # vertex on x-axis + Point2D(2, 0) + >>> t.rotate(pi/2).vertices[0] # vertex on y axis now + Point2D(0, 2) + + See Also + ======== + + rotation + spin : Rotates a RegularPolygon in place + + """ + + r = type(self)(*self.args) # need a copy or else changes are in-place + r._rot += angle + return GeometryEntity.rotate(r, angle, pt) + + def scale(self, x=1, y=1, pt=None): + """Override GeometryEntity.scale since it is the radius that must be + scaled (if x == y) or else a new Polygon must be returned. + + >>> from sympy import RegularPolygon + + Symmetric scaling returns a RegularPolygon: + + >>> RegularPolygon((0, 0), 1, 4).scale(2, 2) + RegularPolygon(Point2D(0, 0), 2, 4, 0) + + Asymmetric scaling returns a kite as a Polygon: + + >>> RegularPolygon((0, 0), 1, 4).scale(2, 1) + Polygon(Point2D(2, 0), Point2D(0, 1), Point2D(-2, 0), Point2D(0, -1)) + + """ + if pt: + pt = Point(pt, dim=2) + return self.translate(*(-pt).args).scale(x, y).translate(*pt.args) + if x != y: + return Polygon(*self.vertices).scale(x, y) + c, r, n, rot = self.args + r *= x + return self.func(c, r, n, rot) + + def reflect(self, line): + """Override GeometryEntity.reflect since this is not made of only + points. + + Examples + ======== + + >>> from sympy import RegularPolygon, Line + + >>> RegularPolygon((0, 0), 1, 4).reflect(Line((0, 1), slope=-2)) + RegularPolygon(Point2D(4/5, 2/5), -1, 4, atan(4/3)) + + """ + c, r, n, rot = self.args + v = self.vertices[0] + d = v - c + cc = c.reflect(line) + vv = v.reflect(line) + dd = vv - cc + # calculate rotation about the new center + # which will align the vertices + l1 = Ray((0, 0), dd) + l2 = Ray((0, 0), d) + ang = l1.closing_angle(l2) + rot += ang + # change sign of radius as point traversal is reversed + return self.func(cc, -r, n, rot) + + @property + def vertices(self): + """The vertices of the RegularPolygon. + + Returns + ======= + + vertices : list + Each vertex is a Point. + + See Also + ======== + + sympy.geometry.point.Point + + Examples + ======== + + >>> from sympy import RegularPolygon, Point + >>> rp = RegularPolygon(Point(0, 0), 5, 4) + >>> rp.vertices + [Point2D(5, 0), Point2D(0, 5), Point2D(-5, 0), Point2D(0, -5)] + + """ + c = self._center + r = abs(self._radius) + rot = self._rot + v = 2*S.Pi/self._n + + return [Point(c.x + r*cos(k*v + rot), c.y + r*sin(k*v + rot)) + for k in range(self._n)] + + def __eq__(self, o): + if not isinstance(o, Polygon): + return False + elif not isinstance(o, RegularPolygon): + return Polygon.__eq__(o, self) + return self.args == o.args + + def __hash__(self): + return super().__hash__() + + +class Triangle(Polygon): + """ + A polygon with three vertices and three sides. + + Parameters + ========== + + points : sequence of Points + keyword: asa, sas, or sss to specify sides/angles of the triangle + + Attributes + ========== + + vertices + altitudes + orthocenter + circumcenter + circumradius + circumcircle + inradius + incircle + exradii + medians + medial + nine_point_circle + + Raises + ====== + + GeometryError + If the number of vertices is not equal to three, or one of the vertices + is not a Point, or a valid keyword is not given. + + See Also + ======== + + sympy.geometry.point.Point, Polygon + + Examples + ======== + + >>> from sympy import Triangle, Point + >>> Triangle(Point(0, 0), Point(4, 0), Point(4, 3)) + Triangle(Point2D(0, 0), Point2D(4, 0), Point2D(4, 3)) + + Keywords sss, sas, or asa can be used to give the desired + side lengths (in order) and interior angles (in degrees) that + define the triangle: + + >>> Triangle(sss=(3, 4, 5)) + Triangle(Point2D(0, 0), Point2D(3, 0), Point2D(3, 4)) + >>> Triangle(asa=(30, 1, 30)) + Triangle(Point2D(0, 0), Point2D(1, 0), Point2D(1/2, sqrt(3)/6)) + >>> Triangle(sas=(1, 45, 2)) + Triangle(Point2D(0, 0), Point2D(2, 0), Point2D(sqrt(2)/2, sqrt(2)/2)) + + """ + + def __new__(cls, *args, **kwargs): + if len(args) != 3: + if 'sss' in kwargs: + return _sss(*[simplify(a) for a in kwargs['sss']]) + if 'asa' in kwargs: + return _asa(*[simplify(a) for a in kwargs['asa']]) + if 'sas' in kwargs: + return _sas(*[simplify(a) for a in kwargs['sas']]) + msg = "Triangle instantiates with three points or a valid keyword." + raise GeometryError(msg) + + vertices = [Point(a, dim=2, **kwargs) for a in args] + + # remove consecutive duplicates + nodup = [] + for p in vertices: + if nodup and p == nodup[-1]: + continue + nodup.append(p) + if len(nodup) > 1 and nodup[-1] == nodup[0]: + nodup.pop() # last point was same as first + + # remove collinear points + i = -3 + while i < len(nodup) - 3 and len(nodup) > 2: + a, b, c = sorted( + [nodup[i], nodup[i + 1], nodup[i + 2]], key=default_sort_key) + if Point.is_collinear(a, b, c): + nodup[i] = a + nodup[i + 1] = None + nodup.pop(i + 1) + i += 1 + + vertices = list(filter(lambda x: x is not None, nodup)) + + if len(vertices) == 3: + return GeometryEntity.__new__(cls, *vertices, **kwargs) + elif len(vertices) == 2: + return Segment(*vertices, **kwargs) + else: + return Point(*vertices, **kwargs) + + @property + def vertices(self): + """The triangle's vertices + + Returns + ======= + + vertices : tuple + Each element in the tuple is a Point + + See Also + ======== + + sympy.geometry.point.Point + + Examples + ======== + + >>> from sympy import Triangle, Point + >>> t = Triangle(Point(0, 0), Point(4, 0), Point(4, 3)) + >>> t.vertices + (Point2D(0, 0), Point2D(4, 0), Point2D(4, 3)) + + """ + return self.args + + def is_similar(t1, t2): + """Is another triangle similar to this one. + + Two triangles are similar if one can be uniformly scaled to the other. + + Parameters + ========== + + other: Triangle + + Returns + ======= + + is_similar : boolean + + See Also + ======== + + sympy.geometry.entity.GeometryEntity.is_similar + + Examples + ======== + + >>> from sympy import Triangle, Point + >>> t1 = Triangle(Point(0, 0), Point(4, 0), Point(4, 3)) + >>> t2 = Triangle(Point(0, 0), Point(-4, 0), Point(-4, -3)) + >>> t1.is_similar(t2) + True + + >>> t2 = Triangle(Point(0, 0), Point(-4, 0), Point(-4, -4)) + >>> t1.is_similar(t2) + False + + """ + if not isinstance(t2, Polygon): + return False + + s1_1, s1_2, s1_3 = [side.length for side in t1.sides] + s2 = [side.length for side in t2.sides] + + def _are_similar(u1, u2, u3, v1, v2, v3): + e1 = simplify(u1/v1) + e2 = simplify(u2/v2) + e3 = simplify(u3/v3) + return bool(e1 == e2) and bool(e2 == e3) + + # There's only 6 permutations, so write them out + return _are_similar(s1_1, s1_2, s1_3, *s2) or \ + _are_similar(s1_1, s1_3, s1_2, *s2) or \ + _are_similar(s1_2, s1_1, s1_3, *s2) or \ + _are_similar(s1_2, s1_3, s1_1, *s2) or \ + _are_similar(s1_3, s1_1, s1_2, *s2) or \ + _are_similar(s1_3, s1_2, s1_1, *s2) + + def is_equilateral(self): + """Are all the sides the same length? + + Returns + ======= + + is_equilateral : boolean + + See Also + ======== + + sympy.geometry.entity.GeometryEntity.is_similar, RegularPolygon + is_isosceles, is_right, is_scalene + + Examples + ======== + + >>> from sympy import Triangle, Point + >>> t1 = Triangle(Point(0, 0), Point(4, 0), Point(4, 3)) + >>> t1.is_equilateral() + False + + >>> from sympy import sqrt + >>> t2 = Triangle(Point(0, 0), Point(10, 0), Point(5, 5*sqrt(3))) + >>> t2.is_equilateral() + True + + """ + return not has_variety(s.length for s in self.sides) + + def is_isosceles(self): + """Are two or more of the sides the same length? + + Returns + ======= + + is_isosceles : boolean + + See Also + ======== + + is_equilateral, is_right, is_scalene + + Examples + ======== + + >>> from sympy import Triangle, Point + >>> t1 = Triangle(Point(0, 0), Point(4, 0), Point(2, 4)) + >>> t1.is_isosceles() + True + + """ + return has_dups(s.length for s in self.sides) + + def is_scalene(self): + """Are all the sides of the triangle of different lengths? + + Returns + ======= + + is_scalene : boolean + + See Also + ======== + + is_equilateral, is_isosceles, is_right + + Examples + ======== + + >>> from sympy import Triangle, Point + >>> t1 = Triangle(Point(0, 0), Point(4, 0), Point(1, 4)) + >>> t1.is_scalene() + True + + """ + return not has_dups(s.length for s in self.sides) + + def is_right(self): + """Is the triangle right-angled. + + Returns + ======= + + is_right : boolean + + See Also + ======== + + sympy.geometry.line.LinearEntity.is_perpendicular + is_equilateral, is_isosceles, is_scalene + + Examples + ======== + + >>> from sympy import Triangle, Point + >>> t1 = Triangle(Point(0, 0), Point(4, 0), Point(4, 3)) + >>> t1.is_right() + True + + """ + s = self.sides + return Segment.is_perpendicular(s[0], s[1]) or \ + Segment.is_perpendicular(s[1], s[2]) or \ + Segment.is_perpendicular(s[0], s[2]) + + @property + def altitudes(self): + """The altitudes of the triangle. + + An altitude of a triangle is a segment through a vertex, + perpendicular to the opposite side, with length being the + height of the vertex measured from the line containing the side. + + Returns + ======= + + altitudes : dict + The dictionary consists of keys which are vertices and values + which are Segments. + + See Also + ======== + + sympy.geometry.point.Point, sympy.geometry.line.Segment.length + + Examples + ======== + + >>> from sympy import Point, Triangle + >>> p1, p2, p3 = Point(0, 0), Point(1, 0), Point(0, 1) + >>> t = Triangle(p1, p2, p3) + >>> t.altitudes[p1] + Segment2D(Point2D(0, 0), Point2D(1/2, 1/2)) + + """ + s = self.sides + v = self.vertices + return {v[0]: s[1].perpendicular_segment(v[0]), + v[1]: s[2].perpendicular_segment(v[1]), + v[2]: s[0].perpendicular_segment(v[2])} + + @property + def orthocenter(self): + """The orthocenter of the triangle. + + The orthocenter is the intersection of the altitudes of a triangle. + It may lie inside, outside or on the triangle. + + Returns + ======= + + orthocenter : Point + + See Also + ======== + + sympy.geometry.point.Point + + Examples + ======== + + >>> from sympy import Point, Triangle + >>> p1, p2, p3 = Point(0, 0), Point(1, 0), Point(0, 1) + >>> t = Triangle(p1, p2, p3) + >>> t.orthocenter + Point2D(0, 0) + + """ + a = self.altitudes + v = self.vertices + return Line(a[v[0]]).intersection(Line(a[v[1]]))[0] + + @property + def circumcenter(self): + """The circumcenter of the triangle + + The circumcenter is the center of the circumcircle. + + Returns + ======= + + circumcenter : Point + + See Also + ======== + + sympy.geometry.point.Point + + Examples + ======== + + >>> from sympy import Point, Triangle + >>> p1, p2, p3 = Point(0, 0), Point(1, 0), Point(0, 1) + >>> t = Triangle(p1, p2, p3) + >>> t.circumcenter + Point2D(1/2, 1/2) + """ + a, b, c = [x.perpendicular_bisector() for x in self.sides] + return a.intersection(b)[0] + + @property + def circumradius(self): + """The radius of the circumcircle of the triangle. + + Returns + ======= + + circumradius : number of Basic instance + + See Also + ======== + + sympy.geometry.ellipse.Circle.radius + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy import Point, Triangle + >>> a = Symbol('a') + >>> p1, p2, p3 = Point(0, 0), Point(1, 0), Point(0, a) + >>> t = Triangle(p1, p2, p3) + >>> t.circumradius + sqrt(a**2/4 + 1/4) + """ + return Point.distance(self.circumcenter, self.vertices[0]) + + @property + def circumcircle(self): + """The circle which passes through the three vertices of the triangle. + + Returns + ======= + + circumcircle : Circle + + See Also + ======== + + sympy.geometry.ellipse.Circle + + Examples + ======== + + >>> from sympy import Point, Triangle + >>> p1, p2, p3 = Point(0, 0), Point(1, 0), Point(0, 1) + >>> t = Triangle(p1, p2, p3) + >>> t.circumcircle + Circle(Point2D(1/2, 1/2), sqrt(2)/2) + + """ + return Circle(self.circumcenter, self.circumradius) + + def bisectors(self): + """The angle bisectors of the triangle. + + An angle bisector of a triangle is a straight line through a vertex + which cuts the corresponding angle in half. + + Returns + ======= + + bisectors : dict + Each key is a vertex (Point) and each value is the corresponding + bisector (Segment). + + See Also + ======== + + sympy.geometry.point.Point, sympy.geometry.line.Segment + + Examples + ======== + + >>> from sympy import Point, Triangle, Segment + >>> p1, p2, p3 = Point(0, 0), Point(1, 0), Point(0, 1) + >>> t = Triangle(p1, p2, p3) + >>> from sympy import sqrt + >>> t.bisectors()[p2] == Segment(Point(1, 0), Point(0, sqrt(2) - 1)) + True + + """ + # use lines containing sides so containment check during + # intersection calculation can be avoided, thus reducing + # the processing time for calculating the bisectors + s = [Line(l) for l in self.sides] + v = self.vertices + c = self.incenter + l1 = Segment(v[0], Line(v[0], c).intersection(s[1])[0]) + l2 = Segment(v[1], Line(v[1], c).intersection(s[2])[0]) + l3 = Segment(v[2], Line(v[2], c).intersection(s[0])[0]) + return {v[0]: l1, v[1]: l2, v[2]: l3} + + @property + def incenter(self): + """The center of the incircle. + + The incircle is the circle which lies inside the triangle and touches + all three sides. + + Returns + ======= + + incenter : Point + + See Also + ======== + + incircle, sympy.geometry.point.Point + + Examples + ======== + + >>> from sympy import Point, Triangle + >>> p1, p2, p3 = Point(0, 0), Point(1, 0), Point(0, 1) + >>> t = Triangle(p1, p2, p3) + >>> t.incenter + Point2D(1 - sqrt(2)/2, 1 - sqrt(2)/2) + + """ + s = self.sides + l = Matrix([s[i].length for i in [1, 2, 0]]) + p = sum(l) + v = self.vertices + x = simplify(l.dot(Matrix([vi.x for vi in v]))/p) + y = simplify(l.dot(Matrix([vi.y for vi in v]))/p) + return Point(x, y) + + @property + def inradius(self): + """The radius of the incircle. + + Returns + ======= + + inradius : number of Basic instance + + See Also + ======== + + incircle, sympy.geometry.ellipse.Circle.radius + + Examples + ======== + + >>> from sympy import Point, Triangle + >>> p1, p2, p3 = Point(0, 0), Point(4, 0), Point(0, 3) + >>> t = Triangle(p1, p2, p3) + >>> t.inradius + 1 + + """ + return simplify(2 * self.area / self.perimeter) + + @property + def incircle(self): + """The incircle of the triangle. + + The incircle is the circle which lies inside the triangle and touches + all three sides. + + Returns + ======= + + incircle : Circle + + See Also + ======== + + sympy.geometry.ellipse.Circle + + Examples + ======== + + >>> from sympy import Point, Triangle + >>> p1, p2, p3 = Point(0, 0), Point(2, 0), Point(0, 2) + >>> t = Triangle(p1, p2, p3) + >>> t.incircle + Circle(Point2D(2 - sqrt(2), 2 - sqrt(2)), 2 - sqrt(2)) + + """ + return Circle(self.incenter, self.inradius) + + @property + def exradii(self): + """The radius of excircles of a triangle. + + An excircle of the triangle is a circle lying outside the triangle, + tangent to one of its sides and tangent to the extensions of the + other two. + + Returns + ======= + + exradii : dict + + See Also + ======== + + sympy.geometry.polygon.Triangle.inradius + + Examples + ======== + + The exradius touches the side of the triangle to which it is keyed, e.g. + the exradius touching side 2 is: + + >>> from sympy import Point, Triangle + >>> p1, p2, p3 = Point(0, 0), Point(6, 0), Point(0, 2) + >>> t = Triangle(p1, p2, p3) + >>> t.exradii[t.sides[2]] + -2 + sqrt(10) + + References + ========== + + .. [1] https://mathworld.wolfram.com/Exradius.html + .. [2] https://mathworld.wolfram.com/Excircles.html + + """ + + side = self.sides + a = side[0].length + b = side[1].length + c = side[2].length + s = (a+b+c)/2 + area = self.area + exradii = {self.sides[0]: simplify(area/(s-a)), + self.sides[1]: simplify(area/(s-b)), + self.sides[2]: simplify(area/(s-c))} + + return exradii + + @property + def excenters(self): + """Excenters of the triangle. + + An excenter is the center of a circle that is tangent to a side of the + triangle and the extensions of the other two sides. + + Returns + ======= + + excenters : dict + + + Examples + ======== + + The excenters are keyed to the side of the triangle to which their corresponding + excircle is tangent: The center is keyed, e.g. the excenter of a circle touching + side 0 is: + + >>> from sympy import Point, Triangle + >>> p1, p2, p3 = Point(0, 0), Point(6, 0), Point(0, 2) + >>> t = Triangle(p1, p2, p3) + >>> t.excenters[t.sides[0]] + Point2D(12*sqrt(10), 2/3 + sqrt(10)/3) + + See Also + ======== + + sympy.geometry.polygon.Triangle.exradii + + References + ========== + + .. [1] https://mathworld.wolfram.com/Excircles.html + + """ + + s = self.sides + v = self.vertices + a = s[0].length + b = s[1].length + c = s[2].length + x = [v[0].x, v[1].x, v[2].x] + y = [v[0].y, v[1].y, v[2].y] + + exc_coords = { + "x1": simplify(-a*x[0]+b*x[1]+c*x[2]/(-a+b+c)), + "x2": simplify(a*x[0]-b*x[1]+c*x[2]/(a-b+c)), + "x3": simplify(a*x[0]+b*x[1]-c*x[2]/(a+b-c)), + "y1": simplify(-a*y[0]+b*y[1]+c*y[2]/(-a+b+c)), + "y2": simplify(a*y[0]-b*y[1]+c*y[2]/(a-b+c)), + "y3": simplify(a*y[0]+b*y[1]-c*y[2]/(a+b-c)) + } + + excenters = { + s[0]: Point(exc_coords["x1"], exc_coords["y1"]), + s[1]: Point(exc_coords["x2"], exc_coords["y2"]), + s[2]: Point(exc_coords["x3"], exc_coords["y3"]) + } + + return excenters + + @property + def medians(self): + """The medians of the triangle. + + A median of a triangle is a straight line through a vertex and the + midpoint of the opposite side, and divides the triangle into two + equal areas. + + Returns + ======= + + medians : dict + Each key is a vertex (Point) and each value is the median (Segment) + at that point. + + See Also + ======== + + sympy.geometry.point.Point.midpoint, sympy.geometry.line.Segment.midpoint + + Examples + ======== + + >>> from sympy import Point, Triangle + >>> p1, p2, p3 = Point(0, 0), Point(1, 0), Point(0, 1) + >>> t = Triangle(p1, p2, p3) + >>> t.medians[p1] + Segment2D(Point2D(0, 0), Point2D(1/2, 1/2)) + + """ + s = self.sides + v = self.vertices + return {v[0]: Segment(v[0], s[1].midpoint), + v[1]: Segment(v[1], s[2].midpoint), + v[2]: Segment(v[2], s[0].midpoint)} + + @property + def medial(self): + """The medial triangle of the triangle. + + The triangle which is formed from the midpoints of the three sides. + + Returns + ======= + + medial : Triangle + + See Also + ======== + + sympy.geometry.line.Segment.midpoint + + Examples + ======== + + >>> from sympy import Point, Triangle + >>> p1, p2, p3 = Point(0, 0), Point(1, 0), Point(0, 1) + >>> t = Triangle(p1, p2, p3) + >>> t.medial + Triangle(Point2D(1/2, 0), Point2D(1/2, 1/2), Point2D(0, 1/2)) + + """ + s = self.sides + return Triangle(s[0].midpoint, s[1].midpoint, s[2].midpoint) + + @property + def nine_point_circle(self): + """The nine-point circle of the triangle. + + Nine-point circle is the circumcircle of the medial triangle, which + passes through the feet of altitudes and the middle points of segments + connecting the vertices and the orthocenter. + + Returns + ======= + + nine_point_circle : Circle + + See also + ======== + + sympy.geometry.line.Segment.midpoint + sympy.geometry.polygon.Triangle.medial + sympy.geometry.polygon.Triangle.orthocenter + + Examples + ======== + + >>> from sympy import Point, Triangle + >>> p1, p2, p3 = Point(0, 0), Point(1, 0), Point(0, 1) + >>> t = Triangle(p1, p2, p3) + >>> t.nine_point_circle + Circle(Point2D(1/4, 1/4), sqrt(2)/4) + + """ + return Circle(*self.medial.vertices) + + @property + def eulerline(self): + """The Euler line of the triangle. + + The line which passes through circumcenter, centroid and orthocenter. + + Returns + ======= + + eulerline : Line (or Point for equilateral triangles in which case all + centers coincide) + + Examples + ======== + + >>> from sympy import Point, Triangle + >>> p1, p2, p3 = Point(0, 0), Point(1, 0), Point(0, 1) + >>> t = Triangle(p1, p2, p3) + >>> t.eulerline + Line2D(Point2D(0, 0), Point2D(1/2, 1/2)) + + """ + if self.is_equilateral(): + return self.orthocenter + return Line(self.orthocenter, self.circumcenter) + +def rad(d): + """Return the radian value for the given degrees (pi = 180 degrees).""" + return d*pi/180 + + +def deg(r): + """Return the degree value for the given radians (pi = 180 degrees).""" + return r/pi*180 + + +def _slope(d): + rv = tan(rad(d)) + return rv + + +def _asa(d1, l, d2): + """Return triangle having side with length l on the x-axis.""" + xy = Line((0, 0), slope=_slope(d1)).intersection( + Line((l, 0), slope=_slope(180 - d2)))[0] + return Triangle((0, 0), (l, 0), xy) + + +def _sss(l1, l2, l3): + """Return triangle having side of length l1 on the x-axis.""" + c1 = Circle((0, 0), l3) + c2 = Circle((l1, 0), l2) + inter = [a for a in c1.intersection(c2) if a.y.is_nonnegative] + if not inter: + return None + pt = inter[0] + return Triangle((0, 0), (l1, 0), pt) + + +def _sas(l1, d, l2): + """Return triangle having side with length l2 on the x-axis.""" + p1 = Point(0, 0) + p2 = Point(l2, 0) + p3 = Point(cos(rad(d))*l1, sin(rad(d))*l1) + return Triangle(p1, p2, p3) diff --git a/.venv/lib/python3.13/site-packages/sympy/geometry/util.py b/.venv/lib/python3.13/site-packages/sympy/geometry/util.py new file mode 100644 index 0000000000000000000000000000000000000000..1d8fb77550f2faea8185ff0c373b5f1680e623ec --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/geometry/util.py @@ -0,0 +1,731 @@ +"""Utility functions for geometrical entities. + +Contains +======== +intersection +convex_hull +closest_points +farthest_points +are_coplanar +are_similar + +""" + +from collections import deque +from math import sqrt as _sqrt + +from sympy import nsimplify +from .entity import GeometryEntity +from .exceptions import GeometryError +from .point import Point, Point2D, Point3D +from sympy.core.containers import OrderedSet +from sympy.core.exprtools import factor_terms +from sympy.core.function import Function, expand_mul +from sympy.core.numbers import Float +from sympy.core.sorting import ordered +from sympy.core.symbol import Symbol +from sympy.core.singleton import S +from sympy.polys.polytools import cancel +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.utilities.iterables import is_sequence + +from mpmath.libmp.libmpf import prec_to_dps + + +def find(x, equation): + """ + Checks whether a Symbol matching ``x`` is present in ``equation`` + or not. If present, the matching symbol is returned, else a + ValueError is raised. If ``x`` is a string the matching symbol + will have the same name; if ``x`` is a Symbol then it will be + returned if found. + + Examples + ======== + + >>> from sympy.geometry.util import find + >>> from sympy import Dummy + >>> from sympy.abc import x + >>> find('x', x) + x + >>> find('x', Dummy('x')) + _x + + The dummy symbol is returned since it has a matching name: + + >>> _.name == 'x' + True + >>> find(x, Dummy('x')) + Traceback (most recent call last): + ... + ValueError: could not find x + """ + + free = equation.free_symbols + xs = [i for i in free if (i.name if isinstance(x, str) else i) == x] + if not xs: + raise ValueError('could not find %s' % x) + if len(xs) != 1: + raise ValueError('ambiguous %s' % x) + return xs[0] + + +def _ordered_points(p): + """Return the tuple of points sorted numerically according to args""" + return tuple(sorted(p, key=lambda x: x.args)) + + +def are_coplanar(*e): + """ Returns True if the given entities are coplanar otherwise False + + Parameters + ========== + + e: entities to be checked for being coplanar + + Returns + ======= + + Boolean + + Examples + ======== + + >>> from sympy import Point3D, Line3D + >>> from sympy.geometry.util import are_coplanar + >>> a = Line3D(Point3D(5, 0, 0), Point3D(1, -1, 1)) + >>> b = Line3D(Point3D(0, -2, 0), Point3D(3, 1, 1)) + >>> c = Line3D(Point3D(0, -1, 0), Point3D(5, -1, 9)) + >>> are_coplanar(a, b, c) + False + + """ + from .line import LinearEntity3D + from .plane import Plane + # XXX update tests for coverage + + e = set(e) + # first work with a Plane if present + for i in list(e): + if isinstance(i, Plane): + e.remove(i) + return all(p.is_coplanar(i) for p in e) + + if all(isinstance(i, Point3D) for i in e): + if len(e) < 3: + return False + + # remove pts that are collinear with 2 pts + a, b = e.pop(), e.pop() + for i in list(e): + if Point3D.are_collinear(a, b, i): + e.remove(i) + + if not e: + return False + else: + # define a plane + p = Plane(a, b, e.pop()) + for i in e: + if i not in p: + return False + return True + else: + pt3d = [] + for i in e: + if isinstance(i, Point3D): + pt3d.append(i) + elif isinstance(i, LinearEntity3D): + pt3d.extend(i.args) + elif isinstance(i, GeometryEntity): # XXX we should have a GeometryEntity3D class so we can tell the difference between 2D and 3D -- here we just want to deal with 2D objects; if new 3D objects are encountered that we didn't handle above, an error should be raised + # all 2D objects have some Point that defines them; so convert those points to 3D pts by making z=0 + for p in i.args: + if isinstance(p, Point): + pt3d.append(Point3D(*(p.args + (0,)))) + return are_coplanar(*pt3d) + + +def are_similar(e1, e2): + """Are two geometrical entities similar. + + Can one geometrical entity be uniformly scaled to the other? + + Parameters + ========== + + e1 : GeometryEntity + e2 : GeometryEntity + + Returns + ======= + + are_similar : boolean + + Raises + ====== + + GeometryError + When `e1` and `e2` cannot be compared. + + Notes + ===== + + If the two objects are equal then they are similar. + + See Also + ======== + + sympy.geometry.entity.GeometryEntity.is_similar + + Examples + ======== + + >>> from sympy import Point, Circle, Triangle, are_similar + >>> c1, c2 = Circle(Point(0, 0), 4), Circle(Point(1, 4), 3) + >>> t1 = Triangle(Point(0, 0), Point(1, 0), Point(0, 1)) + >>> t2 = Triangle(Point(0, 0), Point(2, 0), Point(0, 2)) + >>> t3 = Triangle(Point(0, 0), Point(3, 0), Point(0, 1)) + >>> are_similar(t1, t2) + True + >>> are_similar(t1, t3) + False + + """ + if e1 == e2: + return True + is_similar1 = getattr(e1, 'is_similar', None) + if is_similar1: + return is_similar1(e2) + is_similar2 = getattr(e2, 'is_similar', None) + if is_similar2: + return is_similar2(e1) + n1 = e1.__class__.__name__ + n2 = e2.__class__.__name__ + raise GeometryError( + "Cannot test similarity between %s and %s" % (n1, n2)) + + +def centroid(*args): + """Find the centroid (center of mass) of the collection containing only Points, + Segments or Polygons. The centroid is the weighted average of the individual centroid + where the weights are the lengths (of segments) or areas (of polygons). + Overlapping regions will add to the weight of that region. + + If there are no objects (or a mixture of objects) then None is returned. + + See Also + ======== + + sympy.geometry.point.Point, sympy.geometry.line.Segment, + sympy.geometry.polygon.Polygon + + Examples + ======== + + >>> from sympy import Point, Segment, Polygon + >>> from sympy.geometry.util import centroid + >>> p = Polygon((0, 0), (10, 0), (10, 10)) + >>> q = p.translate(0, 20) + >>> p.centroid, q.centroid + (Point2D(20/3, 10/3), Point2D(20/3, 70/3)) + >>> centroid(p, q) + Point2D(20/3, 40/3) + >>> p, q = Segment((0, 0), (2, 0)), Segment((0, 0), (2, 2)) + >>> centroid(p, q) + Point2D(1, 2 - sqrt(2)) + >>> centroid(Point(0, 0), Point(2, 0)) + Point2D(1, 0) + + Stacking 3 polygons on top of each other effectively triples the + weight of that polygon: + + >>> p = Polygon((0, 0), (1, 0), (1, 1), (0, 1)) + >>> q = Polygon((1, 0), (3, 0), (3, 1), (1, 1)) + >>> centroid(p, q) + Point2D(3/2, 1/2) + >>> centroid(p, p, p, q) # centroid x-coord shifts left + Point2D(11/10, 1/2) + + Stacking the squares vertically above and below p has the same + effect: + + >>> centroid(p, p.translate(0, 1), p.translate(0, -1), q) + Point2D(11/10, 1/2) + + """ + from .line import Segment + from .polygon import Polygon + if args: + if all(isinstance(g, Point) for g in args): + c = Point(0, 0) + for g in args: + c += g + den = len(args) + elif all(isinstance(g, Segment) for g in args): + c = Point(0, 0) + L = 0 + for g in args: + l = g.length + c += g.midpoint*l + L += l + den = L + elif all(isinstance(g, Polygon) for g in args): + c = Point(0, 0) + A = 0 + for g in args: + a = g.area + c += g.centroid*a + A += a + den = A + c /= den + return c.func(*[i.simplify() for i in c.args]) + + +def closest_points(*args): + """Return the subset of points from a set of points that were + the closest to each other in the 2D plane. + + Parameters + ========== + + args + A collection of Points on 2D plane. + + Notes + ===== + + This can only be performed on a set of points whose coordinates can + be ordered on the number line. If there are no ties then a single + pair of Points will be in the set. + + Examples + ======== + + >>> from sympy import closest_points, Triangle + >>> Triangle(sss=(3, 4, 5)).args + (Point2D(0, 0), Point2D(3, 0), Point2D(3, 4)) + >>> closest_points(*_) + {(Point2D(0, 0), Point2D(3, 0))} + + References + ========== + + .. [1] https://www.cs.mcgill.ca/~cs251/ClosestPair/ClosestPairPS.html + + .. [2] Sweep line algorithm + https://en.wikipedia.org/wiki/Sweep_line_algorithm + + """ + p = [Point2D(i) for i in set(args)] + if len(p) < 2: + raise ValueError('At least 2 distinct points must be given.') + + try: + p.sort(key=lambda x: x.args) + except TypeError: + raise ValueError("The points could not be sorted.") + + if not all(i.is_Rational for j in p for i in j.args): + def hypot(x, y): + arg = x*x + y*y + if arg.is_Rational: + return _sqrt(arg) + return sqrt(arg) + else: + from math import hypot + + rv = [(0, 1)] + best_dist = hypot(p[1].x - p[0].x, p[1].y - p[0].y) + left = 0 + box = deque([0, 1]) + for i in range(2, len(p)): + while left < i and p[i][0] - p[left][0] > best_dist: + box.popleft() + left += 1 + + for j in box: + d = hypot(p[i].x - p[j].x, p[i].y - p[j].y) + if d < best_dist: + rv = [(j, i)] + elif d == best_dist: + rv.append((j, i)) + else: + continue + best_dist = d + box.append(i) + + return {tuple([p[i] for i in pair]) for pair in rv} + + +def convex_hull(*args, polygon=True): + """The convex hull surrounding the Points contained in the list of entities. + + Parameters + ========== + + args : a collection of Points, Segments and/or Polygons + + Optional parameters + =================== + + polygon : Boolean. If True, returns a Polygon, if false a tuple, see below. + Default is True. + + Returns + ======= + + convex_hull : Polygon if ``polygon`` is True else as a tuple `(U, L)` where + ``L`` and ``U`` are the lower and upper hulls, respectively. + + Notes + ===== + + This can only be performed on a set of points whose coordinates can + be ordered on the number line. + + See Also + ======== + + sympy.geometry.point.Point, sympy.geometry.polygon.Polygon + + Examples + ======== + + >>> from sympy import convex_hull + >>> points = [(1, 1), (1, 2), (3, 1), (-5, 2), (15, 4)] + >>> convex_hull(*points) + Polygon(Point2D(-5, 2), Point2D(1, 1), Point2D(3, 1), Point2D(15, 4)) + >>> convex_hull(*points, **dict(polygon=False)) + ([Point2D(-5, 2), Point2D(15, 4)], + [Point2D(-5, 2), Point2D(1, 1), Point2D(3, 1), Point2D(15, 4)]) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Graham_scan + + .. [2] Andrew's Monotone Chain Algorithm + (A.M. Andrew, + "Another Efficient Algorithm for Convex Hulls in Two Dimensions", 1979) + https://web.archive.org/web/20210511015444/http://geomalgorithms.com/a10-_hull-1.html + + """ + from .line import Segment + from .polygon import Polygon + p = OrderedSet() + for e in args: + if not isinstance(e, GeometryEntity): + try: + e = Point(e) + except NotImplementedError: + raise ValueError('%s is not a GeometryEntity and cannot be made into Point' % str(e)) + if isinstance(e, Point): + p.add(e) + elif isinstance(e, Segment): + p.update(e.points) + elif isinstance(e, Polygon): + p.update(e.vertices) + else: + raise NotImplementedError( + 'Convex hull for %s not implemented.' % type(e)) + + # make sure all our points are of the same dimension + if any(len(x) != 2 for x in p): + raise ValueError('Can only compute the convex hull in two dimensions') + + p = list(p) + if len(p) == 1: + return p[0] if polygon else (p[0], None) + elif len(p) == 2: + s = Segment(p[0], p[1]) + return s if polygon else (s, None) + + def _orientation(p, q, r): + '''Return positive if p-q-r are clockwise, neg if ccw, zero if + collinear.''' + return (q.y - p.y)*(r.x - p.x) - (q.x - p.x)*(r.y - p.y) + + # scan to find upper and lower convex hulls of a set of 2d points. + U = [] + L = [] + try: + p.sort(key=lambda x: x.args) + except TypeError: + raise ValueError("The points could not be sorted.") + for p_i in p: + while len(U) > 1 and _orientation(U[-2], U[-1], p_i) <= 0: + U.pop() + while len(L) > 1 and _orientation(L[-2], L[-1], p_i) >= 0: + L.pop() + U.append(p_i) + L.append(p_i) + U.reverse() + convexHull = tuple(L + U[1:-1]) + + if len(convexHull) == 2: + s = Segment(convexHull[0], convexHull[1]) + return s if polygon else (s, None) + if polygon: + return Polygon(*convexHull) + else: + U.reverse() + return (U, L) + +def farthest_points(*args): + """Return the subset of points from a set of points that were + the furthest apart from each other in the 2D plane. + + Parameters + ========== + + args + A collection of Points on 2D plane. + + Notes + ===== + + This can only be performed on a set of points whose coordinates can + be ordered on the number line. If there are no ties then a single + pair of Points will be in the set. + + Examples + ======== + + >>> from sympy.geometry import farthest_points, Triangle + >>> Triangle(sss=(3, 4, 5)).args + (Point2D(0, 0), Point2D(3, 0), Point2D(3, 4)) + >>> farthest_points(*_) + {(Point2D(0, 0), Point2D(3, 4))} + + References + ========== + + .. [1] https://code.activestate.com/recipes/117225-convex-hull-and-diameter-of-2d-point-sets/ + + .. [2] Rotating Callipers Technique + https://en.wikipedia.org/wiki/Rotating_calipers + + """ + + def rotatingCalipers(Points): + U, L = convex_hull(*Points, **{"polygon": False}) + + if L is None: + if isinstance(U, Point): + raise ValueError('At least two distinct points must be given.') + yield U.args + else: + i = 0 + j = len(L) - 1 + while i < len(U) - 1 or j > 0: + yield U[i], L[j] + # if all the way through one side of hull, advance the other side + if i == len(U) - 1: + j -= 1 + elif j == 0: + i += 1 + # still points left on both lists, compare slopes of next hull edges + # being careful to avoid divide-by-zero in slope calculation + elif (U[i+1].y - U[i].y) * (L[j].x - L[j-1].x) > \ + (L[j].y - L[j-1].y) * (U[i+1].x - U[i].x): + i += 1 + else: + j -= 1 + + p = [Point2D(i) for i in set(args)] + + if not all(i.is_Rational for j in p for i in j.args): + def hypot(x, y): + arg = x*x + y*y + if arg.is_Rational: + return _sqrt(arg) + return sqrt(arg) + else: + from math import hypot + + rv = [] + diam = 0 + for pair in rotatingCalipers(args): + h, q = _ordered_points(pair) + d = hypot(h.x - q.x, h.y - q.y) + if d > diam: + rv = [(h, q)] + elif d == diam: + rv.append((h, q)) + else: + continue + diam = d + + return set(rv) + + +def idiff(eq, y, x, n=1): + """Return ``dy/dx`` assuming that ``eq == 0``. + + Parameters + ========== + + y : the dependent variable or a list of dependent variables (with y first) + x : the variable that the derivative is being taken with respect to + n : the order of the derivative (default is 1) + + Examples + ======== + + >>> from sympy.abc import x, y, a + >>> from sympy.geometry.util import idiff + + >>> circ = x**2 + y**2 - 4 + >>> idiff(circ, y, x) + -x/y + >>> idiff(circ, y, x, 2).simplify() + (-x**2 - y**2)/y**3 + + Here, ``a`` is assumed to be independent of ``x``: + + >>> idiff(x + a + y, y, x) + -1 + + Now the x-dependence of ``a`` is made explicit by listing ``a`` after + ``y`` in a list. + + >>> idiff(x + a + y, [y, a], x) + -Derivative(a, x) - 1 + + See Also + ======== + + sympy.core.function.Derivative: represents unevaluated derivatives + sympy.core.function.diff: explicitly differentiates wrt symbols + + """ + if is_sequence(y): + dep = set(y) + y = y[0] + elif isinstance(y, Symbol): + dep = {y} + elif isinstance(y, Function): + pass + else: + raise ValueError("expecting x-dependent symbol(s) or function(s) but got: %s" % y) + + f = {s: Function(s.name)(x) for s in eq.free_symbols + if s != x and s in dep} + + if isinstance(y, Symbol): + dydx = Function(y.name)(x).diff(x) + else: + dydx = y.diff(x) + + eq = eq.subs(f) + derivs = {} + for i in range(n): + # equation will be linear in dydx, a*dydx + b, so dydx = -b/a + deq = eq.diff(x) + b = deq.xreplace({dydx: S.Zero}) + a = (deq - b).xreplace({dydx: S.One}) + yp = factor_terms(expand_mul(cancel((-b/a).subs(derivs)), deep=False)) + if i == n - 1: + return yp.subs([(v, k) for k, v in f.items()]) + derivs[dydx] = yp + eq = dydx - yp + dydx = dydx.diff(x) + + +def intersection(*entities, pairwise=False, **kwargs): + """The intersection of a collection of GeometryEntity instances. + + Parameters + ========== + entities : sequence of GeometryEntity + pairwise (keyword argument) : Can be either True or False + + Returns + ======= + intersection : list of GeometryEntity + + Raises + ====== + NotImplementedError + When unable to calculate intersection. + + Notes + ===== + The intersection of any geometrical entity with itself should return + a list with one item: the entity in question. + An intersection requires two or more entities. If only a single + entity is given then the function will return an empty list. + It is possible for `intersection` to miss intersections that one + knows exists because the required quantities were not fully + simplified internally. + Reals should be converted to Rationals, e.g. Rational(str(real_num)) + or else failures due to floating point issues may result. + + Case 1: When the keyword argument 'pairwise' is False (default value): + In this case, the function returns a list of intersections common to + all entities. + + Case 2: When the keyword argument 'pairwise' is True: + In this case, the functions returns a list intersections that occur + between any pair of entities. + + See Also + ======== + + sympy.geometry.entity.GeometryEntity.intersection + + Examples + ======== + + >>> from sympy import Ray, Circle, intersection + >>> c = Circle((0, 1), 1) + >>> intersection(c, c.center) + [] + >>> right = Ray((0, 0), (1, 0)) + >>> up = Ray((0, 0), (0, 1)) + >>> intersection(c, right, up) + [Point2D(0, 0)] + >>> intersection(c, right, up, pairwise=True) + [Point2D(0, 0), Point2D(0, 2)] + >>> left = Ray((1, 0), (0, 0)) + >>> intersection(right, left) + [Segment2D(Point2D(0, 0), Point2D(1, 0))] + + """ + if len(entities) <= 1: + return [] + + entities = list(entities) + prec = None + for i, e in enumerate(entities): + if not isinstance(e, GeometryEntity): + # entities may be an immutable tuple + e = Point(e) + # convert to exact Rationals + d = {} + for f in e.atoms(Float): + prec = f._prec if prec is None else min(f._prec, prec) + d.setdefault(f, nsimplify(f, rational=True)) + entities[i] = e.xreplace(d) + + if not pairwise: + # find the intersection common to all objects + res = entities[0].intersection(entities[1]) + for entity in entities[2:]: + newres = [] + for x in res: + newres.extend(x.intersection(entity)) + res = newres + else: + # find all pairwise intersections + ans = [] + for j in range(len(entities)): + for k in range(j + 1, len(entities)): + ans.extend(intersection(entities[j], entities[k])) + res = list(ordered(set(ans))) + + # convert back to Floats + if prec is not None: + p = prec_to_dps(prec) + res = [i.n(p) for i in res] + return res diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/__init__.py b/.venv/lib/python3.13/site-packages/sympy/matrices/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..37b558f3f03f149dae6af20254e9b88192f7f1ed --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/__init__.py @@ -0,0 +1,72 @@ +"""A module that handles matrices. + +Includes functions for fast creating matrices like zero, one/eye, random +matrix, etc. +""" +from .exceptions import ShapeError, NonSquareMatrixError +from .kind import MatrixKind +from .dense import ( + GramSchmidt, casoratian, diag, eye, hessian, jordan_cell, + list2numpy, matrix2numpy, matrix_multiply_elementwise, ones, + randMatrix, rot_axis1, rot_axis2, rot_axis3, rot_ccw_axis1, + rot_ccw_axis2, rot_ccw_axis3, rot_givens, + symarray, wronskian, zeros) +from .dense import MutableDenseMatrix +from .matrixbase import DeferredVector, MatrixBase + +MutableMatrix = MutableDenseMatrix +Matrix = MutableMatrix + +from .sparse import MutableSparseMatrix +from .sparsetools import banded +from .immutable import ImmutableDenseMatrix, ImmutableSparseMatrix + +ImmutableMatrix = ImmutableDenseMatrix +SparseMatrix = MutableSparseMatrix + +from .expressions import ( + MatrixSlice, BlockDiagMatrix, BlockMatrix, FunctionMatrix, Identity, + Inverse, MatAdd, MatMul, MatPow, MatrixExpr, MatrixSymbol, Trace, + Transpose, ZeroMatrix, OneMatrix, blockcut, block_collapse, matrix_symbols, Adjoint, + hadamard_product, HadamardProduct, HadamardPower, Determinant, det, + diagonalize_vector, DiagMatrix, DiagonalMatrix, DiagonalOf, trace, + DotProduct, kronecker_product, KroneckerProduct, + PermutationMatrix, MatrixPermute, MatrixSet, Permanent, per) + +from .utilities import dotprodsimp + +__all__ = [ + 'ShapeError', 'NonSquareMatrixError', 'MatrixKind', + + 'GramSchmidt', 'casoratian', 'diag', 'eye', 'hessian', 'jordan_cell', + 'list2numpy', 'matrix2numpy', 'matrix_multiply_elementwise', 'ones', + 'randMatrix', 'rot_axis1', 'rot_axis2', 'rot_axis3', 'symarray', + 'wronskian', 'zeros', 'rot_ccw_axis1', 'rot_ccw_axis2', 'rot_ccw_axis3', + 'rot_givens', + + 'MutableDenseMatrix', + + 'DeferredVector', 'MatrixBase', + + 'Matrix', 'MutableMatrix', + + 'MutableSparseMatrix', + + 'banded', + + 'ImmutableDenseMatrix', 'ImmutableSparseMatrix', + + 'ImmutableMatrix', 'SparseMatrix', + + 'MatrixSlice', 'BlockDiagMatrix', 'BlockMatrix', 'FunctionMatrix', + 'Identity', 'Inverse', 'MatAdd', 'MatMul', 'MatPow', 'MatrixExpr', + 'MatrixSymbol', 'Trace', 'Transpose', 'ZeroMatrix', 'OneMatrix', + 'blockcut', 'block_collapse', 'matrix_symbols', 'Adjoint', + 'hadamard_product', 'HadamardProduct', 'HadamardPower', 'Determinant', + 'det', 'diagonalize_vector', 'DiagMatrix', 'DiagonalMatrix', + 'DiagonalOf', 'trace', 'DotProduct', 'kronecker_product', + 'KroneckerProduct', 'PermutationMatrix', 'MatrixPermute', 'MatrixSet', + 'Permanent', 'per', + + 'dotprodsimp', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/benchmarks/__init__.py b/.venv/lib/python3.13/site-packages/sympy/matrices/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/benchmarks/bench_matrix.py b/.venv/lib/python3.13/site-packages/sympy/matrices/benchmarks/bench_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..4fb845600533c4c6fef196fe5a45b98890f4ad78 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/benchmarks/bench_matrix.py @@ -0,0 +1,21 @@ +from sympy.core.numbers import Integer +from sympy.matrices.dense import (eye, zeros) + +i3 = Integer(3) +M = eye(100) + + +def timeit_Matrix__getitem_ii(): + M[3, 3] + + +def timeit_Matrix__getitem_II(): + M[i3, i3] + + +def timeit_Matrix__getslice(): + M[:, :] + + +def timeit_Matrix_zeronm(): + zeros(100, 100) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/common.py b/.venv/lib/python3.13/site-packages/sympy/matrices/common.py new file mode 100644 index 0000000000000000000000000000000000000000..bcb54726fe1a0c36658d8bf63b974db5a3ce8bad --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/common.py @@ -0,0 +1,3258 @@ +""" +A module containing deprecated matrix mixin classes. + +The classes in this module are deprecated and will be removed in a future +release. They are kept here for backwards compatibility in case downstream +code was subclassing them. + +Importing anything else from this module is deprecated so anything here +should either not be used or should be imported from somewhere else. +""" +from __future__ import annotations +from collections import defaultdict +from collections.abc import Iterable +from inspect import isfunction +from functools import reduce + +from sympy.assumptions.refine import refine +from sympy.core import SympifyError, Add +from sympy.core.basic import Atom +from sympy.core.decorators import call_highest_priority +from sympy.core.logic import fuzzy_and, FuzzyBool +from sympy.core.numbers import Integer +from sympy.core.mod import Mod +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.functions.elementary.complexes import Abs, re, im +from sympy.utilities.exceptions import sympy_deprecation_warning +from .utilities import _dotprodsimp, _simplify +from sympy.polys.polytools import Poly +from sympy.utilities.iterables import flatten, is_sequence +from sympy.utilities.misc import as_int, filldedent +from sympy.tensor.array import NDimArray + +from .utilities import _get_intermediate_simp_bool + + +# These exception types were previously defined in this module but were moved +# to exceptions.py. We reimport them here for backwards compatibility in case +# downstream code was importing them from here. +from .exceptions import ( # noqa: F401 + MatrixError, ShapeError, NonSquareMatrixError, NonInvertibleMatrixError, + NonPositiveDefiniteMatrixError +) + + +_DEPRECATED_MIXINS = ( + 'MatrixShaping', + 'MatrixSpecial', + 'MatrixProperties', + 'MatrixOperations', + 'MatrixArithmetic', + 'MatrixCommon', + 'MatrixDeterminant', + 'MatrixReductions', + 'MatrixSubspaces', + 'MatrixEigen', + 'MatrixCalculus', + 'MatrixDeprecated', +) + + +class _MatrixDeprecatedMeta(type): + + # + # Override the default __instancecheck__ implementation to ensure that + # e.g. isinstance(M, MatrixCommon) still works when M is one of the + # matrix classes. Matrix no longer inherits from MatrixCommon so + # isinstance(M, MatrixCommon) would now return False by default. + # + # There were lots of places in the codebase where this was being done + # so it seems likely that downstream code may be doing it too. All use + # of these mixins is deprecated though so we give a deprecation warning + # unconditionally if they are being used with isinstance. + # + # Any code seeing this deprecation warning should be changed to use + # isinstance(M, MatrixBase) instead which also works in previous versions + # of SymPy. + # + + def __instancecheck__(cls, instance): + + sympy_deprecation_warning( + f""" + Checking whether an object is an instance of {cls.__name__} is + deprecated. + + Use `isinstance(obj, Matrix)` instead of `isinstance(obj, {cls.__name__})`. + """, + deprecated_since_version="1.13", + active_deprecations_target="deprecated-matrix-mixins", + stacklevel=3, + ) + + from sympy.matrices.matrixbase import MatrixBase + from sympy.matrices.matrices import ( + MatrixDeterminant, + MatrixReductions, + MatrixSubspaces, + MatrixEigen, + MatrixCalculus, + MatrixDeprecated + ) + + all_mixins = ( + MatrixRequired, + MatrixShaping, + MatrixSpecial, + MatrixProperties, + MatrixOperations, + MatrixArithmetic, + MatrixCommon, + MatrixDeterminant, + MatrixReductions, + MatrixSubspaces, + MatrixEigen, + MatrixCalculus, + MatrixDeprecated + ) + + if cls in all_mixins and isinstance(instance, MatrixBase): + return True + else: + return super().__instancecheck__(instance) + + +class MatrixRequired(metaclass=_MatrixDeprecatedMeta): + """Deprecated mixin class for making matrix classes.""" + + rows: int + cols: int + _simplify = None + + def __init_subclass__(cls, **kwargs): + + # Warn if any downstream code is subclassing this class or any of the + # deprecated mixin classes that are all ultimately subclasses of this + # class. + # + # We don't want to warn about the deprecated mixins themselves being + # created, but only about them being used as mixins by downstream code. + # Otherwise just importing this module would trigger a warning. + # Ultimately the whole module should be deprecated and removed but for + # SymPy 1.13 it is premature to do that given that this module was the + # main way to import matrix exception types in all previous versions. + + if cls.__name__ not in _DEPRECATED_MIXINS: + sympy_deprecation_warning( + f""" + Inheriting from the Matrix mixin classes is deprecated. + + The class {cls.__name__} is subclassing a deprecated mixin. + """, + deprecated_since_version="1.13", + active_deprecations_target="deprecated-matrix-mixins", + stacklevel=3, + ) + + super().__init_subclass__(**kwargs) + + @classmethod + def _new(cls, *args, **kwargs): + """`_new` must, at minimum, be callable as + `_new(rows, cols, mat) where mat is a flat list of the + elements of the matrix.""" + raise NotImplementedError("Subclasses must implement this.") + + def __eq__(self, other): + raise NotImplementedError("Subclasses must implement this.") + + def __getitem__(self, key): + """Implementations of __getitem__ should accept ints, in which + case the matrix is indexed as a flat list, tuples (i,j) in which + case the (i,j) entry is returned, slices, or mixed tuples (a,b) + where a and b are any combination of slices and integers.""" + raise NotImplementedError("Subclasses must implement this.") + + def __len__(self): + """The total number of entries in the matrix.""" + raise NotImplementedError("Subclasses must implement this.") + + @property + def shape(self): + raise NotImplementedError("Subclasses must implement this.") + + +class MatrixShaping(MatrixRequired): + """Provides basic matrix shaping and extracting of submatrices""" + + def _eval_col_del(self, col): + def entry(i, j): + return self[i, j] if j < col else self[i, j + 1] + return self._new(self.rows, self.cols - 1, entry) + + def _eval_col_insert(self, pos, other): + + def entry(i, j): + if j < pos: + return self[i, j] + elif pos <= j < pos + other.cols: + return other[i, j - pos] + return self[i, j - other.cols] + + return self._new(self.rows, self.cols + other.cols, entry) + + def _eval_col_join(self, other): + rows = self.rows + + def entry(i, j): + if i < rows: + return self[i, j] + return other[i - rows, j] + + return classof(self, other)._new(self.rows + other.rows, self.cols, + entry) + + def _eval_extract(self, rowsList, colsList): + mat = list(self) + cols = self.cols + indices = (i * cols + j for i in rowsList for j in colsList) + return self._new(len(rowsList), len(colsList), + [mat[i] for i in indices]) + + def _eval_get_diag_blocks(self): + sub_blocks = [] + + def recurse_sub_blocks(M): + for i in range(1, M.shape[0] + 1): + if i == 1: + to_the_right = M[0, i:] + to_the_bottom = M[i:, 0] + else: + to_the_right = M[:i, i:] + to_the_bottom = M[i:, :i] + if any(to_the_right) or any(to_the_bottom): + continue + sub_blocks.append(M[:i, :i]) + if M.shape != M[:i, :i].shape: + recurse_sub_blocks(M[i:, i:]) + return + + recurse_sub_blocks(self) + return sub_blocks + + def _eval_row_del(self, row): + def entry(i, j): + return self[i, j] if i < row else self[i + 1, j] + return self._new(self.rows - 1, self.cols, entry) + + def _eval_row_insert(self, pos, other): + entries = list(self) + insert_pos = pos * self.cols + entries[insert_pos:insert_pos] = list(other) + return self._new(self.rows + other.rows, self.cols, entries) + + def _eval_row_join(self, other): + cols = self.cols + + def entry(i, j): + if j < cols: + return self[i, j] + return other[i, j - cols] + + return classof(self, other)._new(self.rows, self.cols + other.cols, + entry) + + def _eval_tolist(self): + return [list(self[i,:]) for i in range(self.rows)] + + def _eval_todok(self): + dok = {} + rows, cols = self.shape + for i in range(rows): + for j in range(cols): + val = self[i, j] + if val != self.zero: + dok[i, j] = val + return dok + + def _eval_vec(self): + rows = self.rows + + def entry(n, _): + # we want to read off the columns first + j = n // rows + i = n - j * rows + return self[i, j] + + return self._new(len(self), 1, entry) + + def _eval_vech(self, diagonal): + c = self.cols + v = [] + if diagonal: + for j in range(c): + for i in range(j, c): + v.append(self[i, j]) + else: + for j in range(c): + for i in range(j + 1, c): + v.append(self[i, j]) + return self._new(len(v), 1, v) + + def col_del(self, col): + """Delete the specified column.""" + if col < 0: + col += self.cols + if not 0 <= col < self.cols: + raise IndexError("Column {} is out of range.".format(col)) + return self._eval_col_del(col) + + def col_insert(self, pos, other): + """Insert one or more columns at the given column position. + + Examples + ======== + + >>> from sympy import zeros, ones + >>> M = zeros(3) + >>> V = ones(3, 1) + >>> M.col_insert(1, V) + Matrix([ + [0, 1, 0, 0], + [0, 1, 0, 0], + [0, 1, 0, 0]]) + + See Also + ======== + + col + row_insert + """ + # Allows you to build a matrix even if it is null matrix + if not self: + return type(self)(other) + + pos = as_int(pos) + + if pos < 0: + pos = self.cols + pos + if pos < 0: + pos = 0 + elif pos > self.cols: + pos = self.cols + + if self.rows != other.rows: + raise ShapeError( + "The matrices have incompatible number of rows ({} and {})" + .format(self.rows, other.rows)) + + return self._eval_col_insert(pos, other) + + def col_join(self, other): + """Concatenates two matrices along self's last and other's first row. + + Examples + ======== + + >>> from sympy import zeros, ones + >>> M = zeros(3) + >>> V = ones(1, 3) + >>> M.col_join(V) + Matrix([ + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [1, 1, 1]]) + + See Also + ======== + + col + row_join + """ + # A null matrix can always be stacked (see #10770) + if self.rows == 0 and self.cols != other.cols: + return self._new(0, other.cols, []).col_join(other) + + if self.cols != other.cols: + raise ShapeError( + "The matrices have incompatible number of columns ({} and {})" + .format(self.cols, other.cols)) + return self._eval_col_join(other) + + def col(self, j): + """Elementary column selector. + + Examples + ======== + + >>> from sympy import eye + >>> eye(2).col(0) + Matrix([ + [1], + [0]]) + + See Also + ======== + + row + col_del + col_join + col_insert + """ + return self[:, j] + + def extract(self, rowsList, colsList): + r"""Return a submatrix by specifying a list of rows and columns. + Negative indices can be given. All indices must be in the range + $-n \le i < n$ where $n$ is the number of rows or columns. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(4, 3, range(12)) + >>> m + Matrix([ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + [9, 10, 11]]) + >>> m.extract([0, 1, 3], [0, 1]) + Matrix([ + [0, 1], + [3, 4], + [9, 10]]) + + Rows or columns can be repeated: + + >>> m.extract([0, 0, 1], [-1]) + Matrix([ + [2], + [2], + [5]]) + + Every other row can be taken by using range to provide the indices: + + >>> m.extract(range(0, m.rows, 2), [-1]) + Matrix([ + [2], + [8]]) + + RowsList or colsList can also be a list of booleans, in which case + the rows or columns corresponding to the True values will be selected: + + >>> m.extract([0, 1, 2, 3], [True, False, True]) + Matrix([ + [0, 2], + [3, 5], + [6, 8], + [9, 11]]) + """ + + if not is_sequence(rowsList) or not is_sequence(colsList): + raise TypeError("rowsList and colsList must be iterable") + # ensure rowsList and colsList are lists of integers + if rowsList and all(isinstance(i, bool) for i in rowsList): + rowsList = [index for index, item in enumerate(rowsList) if item] + if colsList and all(isinstance(i, bool) for i in colsList): + colsList = [index for index, item in enumerate(colsList) if item] + + # ensure everything is in range + rowsList = [a2idx(k, self.rows) for k in rowsList] + colsList = [a2idx(k, self.cols) for k in colsList] + + return self._eval_extract(rowsList, colsList) + + def get_diag_blocks(self): + """Obtains the square sub-matrices on the main diagonal of a square matrix. + + Useful for inverting symbolic matrices or solving systems of + linear equations which may be decoupled by having a block diagonal + structure. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x, y, z + >>> A = Matrix([[1, 3, 0, 0], [y, z*z, 0, 0], [0, 0, x, 0], [0, 0, 0, 0]]) + >>> a1, a2, a3 = A.get_diag_blocks() + >>> a1 + Matrix([ + [1, 3], + [y, z**2]]) + >>> a2 + Matrix([[x]]) + >>> a3 + Matrix([[0]]) + + """ + return self._eval_get_diag_blocks() + + @classmethod + def hstack(cls, *args): + """Return a matrix formed by joining args horizontally (i.e. + by repeated application of row_join). + + Examples + ======== + + >>> from sympy import Matrix, eye + >>> Matrix.hstack(eye(2), 2*eye(2)) + Matrix([ + [1, 0, 2, 0], + [0, 1, 0, 2]]) + """ + if len(args) == 0: + return cls._new() + + kls = type(args[0]) + return reduce(kls.row_join, args) + + def reshape(self, rows, cols): + """Reshape the matrix. Total number of elements must remain the same. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(2, 3, lambda i, j: 1) + >>> m + Matrix([ + [1, 1, 1], + [1, 1, 1]]) + >>> m.reshape(1, 6) + Matrix([[1, 1, 1, 1, 1, 1]]) + >>> m.reshape(3, 2) + Matrix([ + [1, 1], + [1, 1], + [1, 1]]) + + """ + if self.rows * self.cols != rows * cols: + raise ValueError("Invalid reshape parameters %d %d" % (rows, cols)) + return self._new(rows, cols, lambda i, j: self[i * cols + j]) + + def row_del(self, row): + """Delete the specified row.""" + if row < 0: + row += self.rows + if not 0 <= row < self.rows: + raise IndexError("Row {} is out of range.".format(row)) + + return self._eval_row_del(row) + + def row_insert(self, pos, other): + """Insert one or more rows at the given row position. + + Examples + ======== + + >>> from sympy import zeros, ones + >>> M = zeros(3) + >>> V = ones(1, 3) + >>> M.row_insert(1, V) + Matrix([ + [0, 0, 0], + [1, 1, 1], + [0, 0, 0], + [0, 0, 0]]) + + See Also + ======== + + row + col_insert + """ + # Allows you to build a matrix even if it is null matrix + if not self: + return self._new(other) + + pos = as_int(pos) + + if pos < 0: + pos = self.rows + pos + if pos < 0: + pos = 0 + elif pos > self.rows: + pos = self.rows + + if self.cols != other.cols: + raise ShapeError( + "The matrices have incompatible number of columns ({} and {})" + .format(self.cols, other.cols)) + + return self._eval_row_insert(pos, other) + + def row_join(self, other): + """Concatenates two matrices along self's last and rhs's first column + + Examples + ======== + + >>> from sympy import zeros, ones + >>> M = zeros(3) + >>> V = ones(3, 1) + >>> M.row_join(V) + Matrix([ + [0, 0, 0, 1], + [0, 0, 0, 1], + [0, 0, 0, 1]]) + + See Also + ======== + + row + col_join + """ + # A null matrix can always be stacked (see #10770) + if self.cols == 0 and self.rows != other.rows: + return self._new(other.rows, 0, []).row_join(other) + + if self.rows != other.rows: + raise ShapeError( + "The matrices have incompatible number of rows ({} and {})" + .format(self.rows, other.rows)) + return self._eval_row_join(other) + + def diagonal(self, k=0): + """Returns the kth diagonal of self. The main diagonal + corresponds to `k=0`; diagonals above and below correspond to + `k > 0` and `k < 0`, respectively. The values of `self[i, j]` + for which `j - i = k`, are returned in order of increasing + `i + j`, starting with `i + j = |k|`. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(3, 3, lambda i, j: j - i); m + Matrix([ + [ 0, 1, 2], + [-1, 0, 1], + [-2, -1, 0]]) + >>> _.diagonal() + Matrix([[0, 0, 0]]) + >>> m.diagonal(1) + Matrix([[1, 1]]) + >>> m.diagonal(-2) + Matrix([[-2]]) + + Even though the diagonal is returned as a Matrix, the element + retrieval can be done with a single index: + + >>> Matrix.diag(1, 2, 3).diagonal()[1] # instead of [0, 1] + 2 + + See Also + ======== + + diag + """ + rv = [] + k = as_int(k) + r = 0 if k > 0 else -k + c = 0 if r else k + while True: + if r == self.rows or c == self.cols: + break + rv.append(self[r, c]) + r += 1 + c += 1 + if not rv: + raise ValueError(filldedent(''' + The %s diagonal is out of range [%s, %s]''' % ( + k, 1 - self.rows, self.cols - 1))) + return self._new(1, len(rv), rv) + + def row(self, i): + """Elementary row selector. + + Examples + ======== + + >>> from sympy import eye + >>> eye(2).row(0) + Matrix([[1, 0]]) + + See Also + ======== + + col + row_del + row_join + row_insert + """ + return self[i, :] + + @property + def shape(self): + """The shape (dimensions) of the matrix as the 2-tuple (rows, cols). + + Examples + ======== + + >>> from sympy import zeros + >>> M = zeros(2, 3) + >>> M.shape + (2, 3) + >>> M.rows + 2 + >>> M.cols + 3 + """ + return (self.rows, self.cols) + + def todok(self): + """Return the matrix as dictionary of keys. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix.eye(3) + >>> M.todok() + {(0, 0): 1, (1, 1): 1, (2, 2): 1} + """ + return self._eval_todok() + + def tolist(self): + """Return the Matrix as a nested Python list. + + Examples + ======== + + >>> from sympy import Matrix, ones + >>> m = Matrix(3, 3, range(9)) + >>> m + Matrix([ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> m.tolist() + [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + >>> ones(3, 0).tolist() + [[], [], []] + + When there are no rows then it will not be possible to tell how + many columns were in the original matrix: + + >>> ones(0, 3).tolist() + [] + + """ + if not self.rows: + return [] + if not self.cols: + return [[] for i in range(self.rows)] + return self._eval_tolist() + + def todod(M): + """Returns matrix as dict of dicts containing non-zero elements of the Matrix + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[0, 1],[0, 3]]) + >>> A + Matrix([ + [0, 1], + [0, 3]]) + >>> A.todod() + {0: {1: 1}, 1: {1: 3}} + + + """ + rowsdict = {} + Mlol = M.tolist() + for i, Mi in enumerate(Mlol): + row = {j: Mij for j, Mij in enumerate(Mi) if Mij} + if row: + rowsdict[i] = row + return rowsdict + + def vec(self): + """Return the Matrix converted into a one column matrix by stacking columns + + Examples + ======== + + >>> from sympy import Matrix + >>> m=Matrix([[1, 3], [2, 4]]) + >>> m + Matrix([ + [1, 3], + [2, 4]]) + >>> m.vec() + Matrix([ + [1], + [2], + [3], + [4]]) + + See Also + ======== + + vech + """ + return self._eval_vec() + + def vech(self, diagonal=True, check_symmetry=True): + """Reshapes the matrix into a column vector by stacking the + elements in the lower triangle. + + Parameters + ========== + + diagonal : bool, optional + If ``True``, it includes the diagonal elements. + + check_symmetry : bool, optional + If ``True``, it checks whether the matrix is symmetric. + + Examples + ======== + + >>> from sympy import Matrix + >>> m=Matrix([[1, 2], [2, 3]]) + >>> m + Matrix([ + [1, 2], + [2, 3]]) + >>> m.vech() + Matrix([ + [1], + [2], + [3]]) + >>> m.vech(diagonal=False) + Matrix([[2]]) + + Notes + ===== + + This should work for symmetric matrices and ``vech`` can + represent symmetric matrices in vector form with less size than + ``vec``. + + See Also + ======== + + vec + """ + if not self.is_square: + raise NonSquareMatrixError + + if check_symmetry and not self.is_symmetric(): + raise ValueError("The matrix is not symmetric.") + + return self._eval_vech(diagonal) + + @classmethod + def vstack(cls, *args): + """Return a matrix formed by joining args vertically (i.e. + by repeated application of col_join). + + Examples + ======== + + >>> from sympy import Matrix, eye + >>> Matrix.vstack(eye(2), 2*eye(2)) + Matrix([ + [1, 0], + [0, 1], + [2, 0], + [0, 2]]) + """ + if len(args) == 0: + return cls._new() + + kls = type(args[0]) + return reduce(kls.col_join, args) + + +class MatrixSpecial(MatrixRequired): + """Construction of special matrices""" + + @classmethod + def _eval_diag(cls, rows, cols, diag_dict): + """diag_dict is a defaultdict containing + all the entries of the diagonal matrix.""" + def entry(i, j): + return diag_dict[(i, j)] + return cls._new(rows, cols, entry) + + @classmethod + def _eval_eye(cls, rows, cols): + vals = [cls.zero]*(rows*cols) + vals[::cols+1] = [cls.one]*min(rows, cols) + return cls._new(rows, cols, vals, copy=False) + + @classmethod + def _eval_jordan_block(cls, size: int, eigenvalue, band='upper'): + if band == 'lower': + def entry(i, j): + if i == j: + return eigenvalue + elif j + 1 == i: + return cls.one + return cls.zero + else: + def entry(i, j): + if i == j: + return eigenvalue + elif i + 1 == j: + return cls.one + return cls.zero + return cls._new(size, size, entry) + + @classmethod + def _eval_ones(cls, rows, cols): + def entry(i, j): + return cls.one + return cls._new(rows, cols, entry) + + @classmethod + def _eval_zeros(cls, rows, cols): + return cls._new(rows, cols, [cls.zero]*(rows*cols), copy=False) + + @classmethod + def _eval_wilkinson(cls, n): + def entry(i, j): + return cls.one if i + 1 == j else cls.zero + + D = cls._new(2*n + 1, 2*n + 1, entry) + + wminus = cls.diag(list(range(-n, n + 1)), unpack=True) + D + D.T + wplus = abs(cls.diag(list(range(-n, n + 1)), unpack=True)) + D + D.T + + return wminus, wplus + + @classmethod + def diag(kls, *args, strict=False, unpack=True, rows=None, cols=None, **kwargs): + """Returns a matrix with the specified diagonal. + If matrices are passed, a block-diagonal matrix + is created (i.e. the "direct sum" of the matrices). + + kwargs + ====== + + rows : rows of the resulting matrix; computed if + not given. + + cols : columns of the resulting matrix; computed if + not given. + + cls : class for the resulting matrix + + unpack : bool which, when True (default), unpacks a single + sequence rather than interpreting it as a Matrix. + + strict : bool which, when False (default), allows Matrices to + have variable-length rows. + + Examples + ======== + + >>> from sympy import Matrix + >>> Matrix.diag(1, 2, 3) + Matrix([ + [1, 0, 0], + [0, 2, 0], + [0, 0, 3]]) + + The current default is to unpack a single sequence. If this is + not desired, set `unpack=False` and it will be interpreted as + a matrix. + + >>> Matrix.diag([1, 2, 3]) == Matrix.diag(1, 2, 3) + True + + When more than one element is passed, each is interpreted as + something to put on the diagonal. Lists are converted to + matrices. Filling of the diagonal always continues from + the bottom right hand corner of the previous item: this + will create a block-diagonal matrix whether the matrices + are square or not. + + >>> col = [1, 2, 3] + >>> row = [[4, 5]] + >>> Matrix.diag(col, row) + Matrix([ + [1, 0, 0], + [2, 0, 0], + [3, 0, 0], + [0, 4, 5]]) + + When `unpack` is False, elements within a list need not all be + of the same length. Setting `strict` to True would raise a + ValueError for the following: + + >>> Matrix.diag([[1, 2, 3], [4, 5], [6]], unpack=False) + Matrix([ + [1, 2, 3], + [4, 5, 0], + [6, 0, 0]]) + + The type of the returned matrix can be set with the ``cls`` + keyword. + + >>> from sympy import ImmutableMatrix + >>> from sympy.utilities.misc import func_name + >>> func_name(Matrix.diag(1, cls=ImmutableMatrix)) + 'ImmutableDenseMatrix' + + A zero dimension matrix can be used to position the start of + the filling at the start of an arbitrary row or column: + + >>> from sympy import ones + >>> r2 = ones(0, 2) + >>> Matrix.diag(r2, 1, 2) + Matrix([ + [0, 0, 1, 0], + [0, 0, 0, 2]]) + + See Also + ======== + eye + diagonal + .dense.diag + .expressions.blockmatrix.BlockMatrix + .sparsetools.banded + """ + from sympy.matrices.matrixbase import MatrixBase + from sympy.matrices.dense import Matrix + from sympy.matrices import SparseMatrix + klass = kwargs.get('cls', kls) + if unpack and len(args) == 1 and is_sequence(args[0]) and \ + not isinstance(args[0], MatrixBase): + args = args[0] + + # fill a default dict with the diagonal entries + diag_entries = defaultdict(int) + rmax = cmax = 0 # keep track of the biggest index seen + for m in args: + if isinstance(m, list): + if strict: + # if malformed, Matrix will raise an error + _ = Matrix(m) + r, c = _.shape + m = _.tolist() + else: + r, c, smat = SparseMatrix._handle_creation_inputs(m) + for (i, j), _ in smat.items(): + diag_entries[(i + rmax, j + cmax)] = _ + m = [] # to skip process below + elif hasattr(m, 'shape'): # a Matrix + # convert to list of lists + r, c = m.shape + m = m.tolist() + else: # in this case, we're a single value + diag_entries[(rmax, cmax)] = m + rmax += 1 + cmax += 1 + continue + # process list of lists + for i, mi in enumerate(m): + for j, _ in enumerate(mi): + diag_entries[(i + rmax, j + cmax)] = _ + rmax += r + cmax += c + if rows is None: + rows, cols = cols, rows + if rows is None: + rows, cols = rmax, cmax + else: + cols = rows if cols is None else cols + if rows < rmax or cols < cmax: + raise ValueError(filldedent(''' + The constructed matrix is {} x {} but a size of {} x {} + was specified.'''.format(rmax, cmax, rows, cols))) + return klass._eval_diag(rows, cols, diag_entries) + + @classmethod + def eye(kls, rows, cols=None, **kwargs): + """Returns an identity matrix. + + Parameters + ========== + + rows : rows of the matrix + cols : cols of the matrix (if None, cols=rows) + + kwargs + ====== + cls : class of the returned matrix + """ + if cols is None: + cols = rows + if rows < 0 or cols < 0: + raise ValueError("Cannot create a {} x {} matrix. " + "Both dimensions must be positive".format(rows, cols)) + klass = kwargs.get('cls', kls) + rows, cols = as_int(rows), as_int(cols) + + return klass._eval_eye(rows, cols) + + @classmethod + def jordan_block(kls, size=None, eigenvalue=None, *, band='upper', **kwargs): + """Returns a Jordan block + + Parameters + ========== + + size : Integer, optional + Specifies the shape of the Jordan block matrix. + + eigenvalue : Number or Symbol + Specifies the value for the main diagonal of the matrix. + + .. note:: + The keyword ``eigenval`` is also specified as an alias + of this keyword, but it is not recommended to use. + + We may deprecate the alias in later release. + + band : 'upper' or 'lower', optional + Specifies the position of the off-diagonal to put `1` s on. + + cls : Matrix, optional + Specifies the matrix class of the output form. + + If it is not specified, the class type where the method is + being executed on will be returned. + + Returns + ======= + + Matrix + A Jordan block matrix. + + Raises + ====== + + ValueError + If insufficient arguments are given for matrix size + specification, or no eigenvalue is given. + + Examples + ======== + + Creating a default Jordan block: + + >>> from sympy import Matrix + >>> from sympy.abc import x + >>> Matrix.jordan_block(4, x) + Matrix([ + [x, 1, 0, 0], + [0, x, 1, 0], + [0, 0, x, 1], + [0, 0, 0, x]]) + + Creating an alternative Jordan block matrix where `1` is on + lower off-diagonal: + + >>> Matrix.jordan_block(4, x, band='lower') + Matrix([ + [x, 0, 0, 0], + [1, x, 0, 0], + [0, 1, x, 0], + [0, 0, 1, x]]) + + Creating a Jordan block with keyword arguments + + >>> Matrix.jordan_block(size=4, eigenvalue=x) + Matrix([ + [x, 1, 0, 0], + [0, x, 1, 0], + [0, 0, x, 1], + [0, 0, 0, x]]) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Jordan_matrix + """ + klass = kwargs.pop('cls', kls) + + eigenval = kwargs.get('eigenval', None) + if eigenvalue is None and eigenval is None: + raise ValueError("Must supply an eigenvalue") + elif eigenvalue != eigenval and None not in (eigenval, eigenvalue): + raise ValueError( + "Inconsistent values are given: 'eigenval'={}, " + "'eigenvalue'={}".format(eigenval, eigenvalue)) + else: + if eigenval is not None: + eigenvalue = eigenval + + if size is None: + raise ValueError("Must supply a matrix size") + + size = as_int(size) + return klass._eval_jordan_block(size, eigenvalue, band) + + @classmethod + def ones(kls, rows, cols=None, **kwargs): + """Returns a matrix of ones. + + Parameters + ========== + + rows : rows of the matrix + cols : cols of the matrix (if None, cols=rows) + + kwargs + ====== + cls : class of the returned matrix + """ + if cols is None: + cols = rows + klass = kwargs.get('cls', kls) + rows, cols = as_int(rows), as_int(cols) + + return klass._eval_ones(rows, cols) + + @classmethod + def zeros(kls, rows, cols=None, **kwargs): + """Returns a matrix of zeros. + + Parameters + ========== + + rows : rows of the matrix + cols : cols of the matrix (if None, cols=rows) + + kwargs + ====== + cls : class of the returned matrix + """ + if cols is None: + cols = rows + if rows < 0 or cols < 0: + raise ValueError("Cannot create a {} x {} matrix. " + "Both dimensions must be positive".format(rows, cols)) + klass = kwargs.get('cls', kls) + rows, cols = as_int(rows), as_int(cols) + + return klass._eval_zeros(rows, cols) + + @classmethod + def companion(kls, poly): + """Returns a companion matrix of a polynomial. + + Examples + ======== + + >>> from sympy import Matrix, Poly, Symbol, symbols + >>> x = Symbol('x') + >>> c0, c1, c2, c3, c4 = symbols('c0:5') + >>> p = Poly(c0 + c1*x + c2*x**2 + c3*x**3 + c4*x**4 + x**5, x) + >>> Matrix.companion(p) + Matrix([ + [0, 0, 0, 0, -c0], + [1, 0, 0, 0, -c1], + [0, 1, 0, 0, -c2], + [0, 0, 1, 0, -c3], + [0, 0, 0, 1, -c4]]) + """ + poly = kls._sympify(poly) + if not isinstance(poly, Poly): + raise ValueError("{} must be a Poly instance.".format(poly)) + if not poly.is_monic: + raise ValueError("{} must be a monic polynomial.".format(poly)) + if not poly.is_univariate: + raise ValueError( + "{} must be a univariate polynomial.".format(poly)) + + size = poly.degree() + if not size >= 1: + raise ValueError( + "{} must have degree not less than 1.".format(poly)) + + coeffs = poly.all_coeffs() + def entry(i, j): + if j == size - 1: + return -coeffs[-1 - i] + elif i == j + 1: + return kls.one + return kls.zero + return kls._new(size, size, entry) + + + @classmethod + def wilkinson(kls, n, **kwargs): + """Returns two square Wilkinson Matrix of size 2*n + 1 + $W_{2n + 1}^-, W_{2n + 1}^+ =$ Wilkinson(n) + + Examples + ======== + + >>> from sympy import Matrix + >>> wminus, wplus = Matrix.wilkinson(3) + >>> wminus + Matrix([ + [-3, 1, 0, 0, 0, 0, 0], + [ 1, -2, 1, 0, 0, 0, 0], + [ 0, 1, -1, 1, 0, 0, 0], + [ 0, 0, 1, 0, 1, 0, 0], + [ 0, 0, 0, 1, 1, 1, 0], + [ 0, 0, 0, 0, 1, 2, 1], + [ 0, 0, 0, 0, 0, 1, 3]]) + >>> wplus + Matrix([ + [3, 1, 0, 0, 0, 0, 0], + [1, 2, 1, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0], + [0, 0, 1, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 1, 2, 1], + [0, 0, 0, 0, 0, 1, 3]]) + + References + ========== + + .. [1] https://blogs.mathworks.com/cleve/2013/04/15/wilkinsons-matrices-2/ + .. [2] J. H. Wilkinson, The Algebraic Eigenvalue Problem, Claredon Press, Oxford, 1965, 662 pp. + + """ + klass = kwargs.get('cls', kls) + n = as_int(n) + return klass._eval_wilkinson(n) + +class MatrixProperties(MatrixRequired): + """Provides basic properties of a matrix.""" + + def _eval_atoms(self, *types): + result = set() + for i in self: + result.update(i.atoms(*types)) + return result + + def _eval_free_symbols(self): + return set().union(*(i.free_symbols for i in self if i)) + + def _eval_has(self, *patterns): + return any(a.has(*patterns) for a in self) + + def _eval_is_anti_symmetric(self, simpfunc): + if not all(simpfunc(self[i, j] + self[j, i]).is_zero for i in range(self.rows) for j in range(self.cols)): + return False + return True + + def _eval_is_diagonal(self): + for i in range(self.rows): + for j in range(self.cols): + if i != j and self[i, j]: + return False + return True + + # _eval_is_hermitian is called by some general SymPy + # routines and has a different *args signature. Make + # sure the names don't clash by adding `_matrix_` in name. + def _eval_is_matrix_hermitian(self, simpfunc): + mat = self._new(self.rows, self.cols, lambda i, j: simpfunc(self[i, j] - self[j, i].conjugate())) + return mat.is_zero_matrix + + def _eval_is_Identity(self) -> FuzzyBool: + def dirac(i, j): + if i == j: + return 1 + return 0 + + return all(self[i, j] == dirac(i, j) + for i in range(self.rows) + for j in range(self.cols)) + + def _eval_is_lower_hessenberg(self): + return all(self[i, j].is_zero + for i in range(self.rows) + for j in range(i + 2, self.cols)) + + def _eval_is_lower(self): + return all(self[i, j].is_zero + for i in range(self.rows) + for j in range(i + 1, self.cols)) + + def _eval_is_symbolic(self): + return self.has(Symbol) + + def _eval_is_symmetric(self, simpfunc): + mat = self._new(self.rows, self.cols, lambda i, j: simpfunc(self[i, j] - self[j, i])) + return mat.is_zero_matrix + + def _eval_is_zero_matrix(self): + if any(i.is_zero == False for i in self): + return False + if any(i.is_zero is None for i in self): + return None + return True + + def _eval_is_upper_hessenberg(self): + return all(self[i, j].is_zero + for i in range(2, self.rows) + for j in range(min(self.cols, (i - 1)))) + + def _eval_values(self): + return [i for i in self if not i.is_zero] + + def _has_positive_diagonals(self): + diagonal_entries = (self[i, i] for i in range(self.rows)) + return fuzzy_and(x.is_positive for x in diagonal_entries) + + def _has_nonnegative_diagonals(self): + diagonal_entries = (self[i, i] for i in range(self.rows)) + return fuzzy_and(x.is_nonnegative for x in diagonal_entries) + + def atoms(self, *types): + """Returns the atoms that form the current object. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import Matrix + >>> Matrix([[x]]) + Matrix([[x]]) + >>> _.atoms() + {x} + >>> Matrix([[x, y], [y, x]]) + Matrix([ + [x, y], + [y, x]]) + >>> _.atoms() + {x, y} + """ + + types = tuple(t if isinstance(t, type) else type(t) for t in types) + if not types: + types = (Atom,) + return self._eval_atoms(*types) + + @property + def free_symbols(self): + """Returns the free symbols within the matrix. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import Matrix + >>> Matrix([[x], [1]]).free_symbols + {x} + """ + return self._eval_free_symbols() + + def has(self, *patterns): + """Test whether any subexpression matches any of the patterns. + + Examples + ======== + + >>> from sympy import Matrix, SparseMatrix, Float + >>> from sympy.abc import x, y + >>> A = Matrix(((1, x), (0.2, 3))) + >>> B = SparseMatrix(((1, x), (0.2, 3))) + >>> A.has(x) + True + >>> A.has(y) + False + >>> A.has(Float) + True + >>> B.has(x) + True + >>> B.has(y) + False + >>> B.has(Float) + True + """ + return self._eval_has(*patterns) + + def is_anti_symmetric(self, simplify=True): + """Check if matrix M is an antisymmetric matrix, + that is, M is a square matrix with all M[i, j] == -M[j, i]. + + When ``simplify=True`` (default), the sum M[i, j] + M[j, i] is + simplified before testing to see if it is zero. By default, + the SymPy simplify function is used. To use a custom function + set simplify to a function that accepts a single argument which + returns a simplified expression. To skip simplification, set + simplify to False but note that although this will be faster, + it may induce false negatives. + + Examples + ======== + + >>> from sympy import Matrix, symbols + >>> m = Matrix(2, 2, [0, 1, -1, 0]) + >>> m + Matrix([ + [ 0, 1], + [-1, 0]]) + >>> m.is_anti_symmetric() + True + >>> x, y = symbols('x y') + >>> m = Matrix(2, 3, [0, 0, x, -y, 0, 0]) + >>> m + Matrix([ + [ 0, 0, x], + [-y, 0, 0]]) + >>> m.is_anti_symmetric() + False + + >>> from sympy.abc import x, y + >>> m = Matrix(3, 3, [0, x**2 + 2*x + 1, y, + ... -(x + 1)**2, 0, x*y, + ... -y, -x*y, 0]) + + Simplification of matrix elements is done by default so even + though two elements which should be equal and opposite would not + pass an equality test, the matrix is still reported as + anti-symmetric: + + >>> m[0, 1] == -m[1, 0] + False + >>> m.is_anti_symmetric() + True + + If ``simplify=False`` is used for the case when a Matrix is already + simplified, this will speed things up. Here, we see that without + simplification the matrix does not appear anti-symmetric: + + >>> print(m.is_anti_symmetric(simplify=False)) + None + + But if the matrix were already expanded, then it would appear + anti-symmetric and simplification in the is_anti_symmetric routine + is not needed: + + >>> m = m.expand() + >>> m.is_anti_symmetric(simplify=False) + True + """ + # accept custom simplification + simpfunc = simplify + if not isfunction(simplify): + simpfunc = _simplify if simplify else lambda x: x + + if not self.is_square: + return False + return self._eval_is_anti_symmetric(simpfunc) + + def is_diagonal(self): + """Check if matrix is diagonal, + that is matrix in which the entries outside the main diagonal are all zero. + + Examples + ======== + + >>> from sympy import Matrix, diag + >>> m = Matrix(2, 2, [1, 0, 0, 2]) + >>> m + Matrix([ + [1, 0], + [0, 2]]) + >>> m.is_diagonal() + True + + >>> m = Matrix(2, 2, [1, 1, 0, 2]) + >>> m + Matrix([ + [1, 1], + [0, 2]]) + >>> m.is_diagonal() + False + + >>> m = diag(1, 2, 3) + >>> m + Matrix([ + [1, 0, 0], + [0, 2, 0], + [0, 0, 3]]) + >>> m.is_diagonal() + True + + See Also + ======== + + is_lower + is_upper + sympy.matrices.matrixbase.MatrixCommon.is_diagonalizable + diagonalize + """ + return self._eval_is_diagonal() + + @property + def is_weakly_diagonally_dominant(self): + r"""Tests if the matrix is row weakly diagonally dominant. + + Explanation + =========== + + A $n, n$ matrix $A$ is row weakly diagonally dominant if + + .. math:: + \left|A_{i, i}\right| \ge \sum_{j = 0, j \neq i}^{n-1} + \left|A_{i, j}\right| \quad {\text{for all }} + i \in \{ 0, ..., n-1 \} + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[3, -2, 1], [1, -3, 2], [-1, 2, 4]]) + >>> A.is_weakly_diagonally_dominant + True + + >>> A = Matrix([[-2, 2, 1], [1, 3, 2], [1, -2, 0]]) + >>> A.is_weakly_diagonally_dominant + False + + >>> A = Matrix([[-4, 2, 1], [1, 6, 2], [1, -2, 5]]) + >>> A.is_weakly_diagonally_dominant + True + + Notes + ===== + + If you want to test whether a matrix is column diagonally + dominant, you can apply the test after transposing the matrix. + """ + if not self.is_square: + return False + + rows, cols = self.shape + + def test_row(i): + summation = self.zero + for j in range(cols): + if i != j: + summation += Abs(self[i, j]) + return (Abs(self[i, i]) - summation).is_nonnegative + + return fuzzy_and(test_row(i) for i in range(rows)) + + @property + def is_strongly_diagonally_dominant(self): + r"""Tests if the matrix is row strongly diagonally dominant. + + Explanation + =========== + + A $n, n$ matrix $A$ is row strongly diagonally dominant if + + .. math:: + \left|A_{i, i}\right| > \sum_{j = 0, j \neq i}^{n-1} + \left|A_{i, j}\right| \quad {\text{for all }} + i \in \{ 0, ..., n-1 \} + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[3, -2, 1], [1, -3, 2], [-1, 2, 4]]) + >>> A.is_strongly_diagonally_dominant + False + + >>> A = Matrix([[-2, 2, 1], [1, 3, 2], [1, -2, 0]]) + >>> A.is_strongly_diagonally_dominant + False + + >>> A = Matrix([[-4, 2, 1], [1, 6, 2], [1, -2, 5]]) + >>> A.is_strongly_diagonally_dominant + True + + Notes + ===== + + If you want to test whether a matrix is column diagonally + dominant, you can apply the test after transposing the matrix. + """ + if not self.is_square: + return False + + rows, cols = self.shape + + def test_row(i): + summation = self.zero + for j in range(cols): + if i != j: + summation += Abs(self[i, j]) + return (Abs(self[i, i]) - summation).is_positive + + return fuzzy_and(test_row(i) for i in range(rows)) + + @property + def is_hermitian(self): + """Checks if the matrix is Hermitian. + + In a Hermitian matrix element i,j is the complex conjugate of + element j,i. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy import I + >>> from sympy.abc import x + >>> a = Matrix([[1, I], [-I, 1]]) + >>> a + Matrix([ + [ 1, I], + [-I, 1]]) + >>> a.is_hermitian + True + >>> a[0, 0] = 2*I + >>> a.is_hermitian + False + >>> a[0, 0] = x + >>> a.is_hermitian + >>> a[0, 1] = a[1, 0]*I + >>> a.is_hermitian + False + """ + if not self.is_square: + return False + + return self._eval_is_matrix_hermitian(_simplify) + + @property + def is_Identity(self) -> FuzzyBool: + if not self.is_square: + return False + return self._eval_is_Identity() + + @property + def is_lower_hessenberg(self): + r"""Checks if the matrix is in the lower-Hessenberg form. + + The lower hessenberg matrix has zero entries + above the first superdiagonal. + + Examples + ======== + + >>> from sympy import Matrix + >>> a = Matrix([[1, 2, 0, 0], [5, 2, 3, 0], [3, 4, 3, 7], [5, 6, 1, 1]]) + >>> a + Matrix([ + [1, 2, 0, 0], + [5, 2, 3, 0], + [3, 4, 3, 7], + [5, 6, 1, 1]]) + >>> a.is_lower_hessenberg + True + + See Also + ======== + + is_upper_hessenberg + is_lower + """ + return self._eval_is_lower_hessenberg() + + @property + def is_lower(self): + """Check if matrix is a lower triangular matrix. True can be returned + even if the matrix is not square. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(2, 2, [1, 0, 0, 1]) + >>> m + Matrix([ + [1, 0], + [0, 1]]) + >>> m.is_lower + True + + >>> m = Matrix(4, 3, [0, 0, 0, 2, 0, 0, 1, 4, 0, 6, 6, 5]) + >>> m + Matrix([ + [0, 0, 0], + [2, 0, 0], + [1, 4, 0], + [6, 6, 5]]) + >>> m.is_lower + True + + >>> from sympy.abc import x, y + >>> m = Matrix(2, 2, [x**2 + y, y**2 + x, 0, x + y]) + >>> m + Matrix([ + [x**2 + y, x + y**2], + [ 0, x + y]]) + >>> m.is_lower + False + + See Also + ======== + + is_upper + is_diagonal + is_lower_hessenberg + """ + return self._eval_is_lower() + + @property + def is_square(self): + """Checks if a matrix is square. + + A matrix is square if the number of rows equals the number of columns. + The empty matrix is square by definition, since the number of rows and + the number of columns are both zero. + + Examples + ======== + + >>> from sympy import Matrix + >>> a = Matrix([[1, 2, 3], [4, 5, 6]]) + >>> b = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> c = Matrix([]) + >>> a.is_square + False + >>> b.is_square + True + >>> c.is_square + True + """ + return self.rows == self.cols + + def is_symbolic(self): + """Checks if any elements contain Symbols. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x, y + >>> M = Matrix([[x, y], [1, 0]]) + >>> M.is_symbolic() + True + + """ + return self._eval_is_symbolic() + + def is_symmetric(self, simplify=True): + """Check if matrix is symmetric matrix, + that is square matrix and is equal to its transpose. + + By default, simplifications occur before testing symmetry. + They can be skipped using 'simplify=False'; while speeding things a bit, + this may however induce false negatives. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(2, 2, [0, 1, 1, 2]) + >>> m + Matrix([ + [0, 1], + [1, 2]]) + >>> m.is_symmetric() + True + + >>> m = Matrix(2, 2, [0, 1, 2, 0]) + >>> m + Matrix([ + [0, 1], + [2, 0]]) + >>> m.is_symmetric() + False + + >>> m = Matrix(2, 3, [0, 0, 0, 0, 0, 0]) + >>> m + Matrix([ + [0, 0, 0], + [0, 0, 0]]) + >>> m.is_symmetric() + False + + >>> from sympy.abc import x, y + >>> m = Matrix(3, 3, [1, x**2 + 2*x + 1, y, (x + 1)**2, 2, 0, y, 0, 3]) + >>> m + Matrix([ + [ 1, x**2 + 2*x + 1, y], + [(x + 1)**2, 2, 0], + [ y, 0, 3]]) + >>> m.is_symmetric() + True + + If the matrix is already simplified, you may speed-up is_symmetric() + test by using 'simplify=False'. + + >>> bool(m.is_symmetric(simplify=False)) + False + >>> m1 = m.expand() + >>> m1.is_symmetric(simplify=False) + True + """ + simpfunc = simplify + if not isfunction(simplify): + simpfunc = _simplify if simplify else lambda x: x + + if not self.is_square: + return False + + return self._eval_is_symmetric(simpfunc) + + @property + def is_upper_hessenberg(self): + """Checks if the matrix is the upper-Hessenberg form. + + The upper hessenberg matrix has zero entries + below the first subdiagonal. + + Examples + ======== + + >>> from sympy import Matrix + >>> a = Matrix([[1, 4, 2, 3], [3, 4, 1, 7], [0, 2, 3, 4], [0, 0, 1, 3]]) + >>> a + Matrix([ + [1, 4, 2, 3], + [3, 4, 1, 7], + [0, 2, 3, 4], + [0, 0, 1, 3]]) + >>> a.is_upper_hessenberg + True + + See Also + ======== + + is_lower_hessenberg + is_upper + """ + return self._eval_is_upper_hessenberg() + + @property + def is_upper(self): + """Check if matrix is an upper triangular matrix. True can be returned + even if the matrix is not square. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(2, 2, [1, 0, 0, 1]) + >>> m + Matrix([ + [1, 0], + [0, 1]]) + >>> m.is_upper + True + + >>> m = Matrix(4, 3, [5, 1, 9, 0, 4, 6, 0, 0, 5, 0, 0, 0]) + >>> m + Matrix([ + [5, 1, 9], + [0, 4, 6], + [0, 0, 5], + [0, 0, 0]]) + >>> m.is_upper + True + + >>> m = Matrix(2, 3, [4, 2, 5, 6, 1, 1]) + >>> m + Matrix([ + [4, 2, 5], + [6, 1, 1]]) + >>> m.is_upper + False + + See Also + ======== + + is_lower + is_diagonal + is_upper_hessenberg + """ + return all(self[i, j].is_zero + for i in range(1, self.rows) + for j in range(min(i, self.cols))) + + @property + def is_zero_matrix(self): + """Checks if a matrix is a zero matrix. + + A matrix is zero if every element is zero. A matrix need not be square + to be considered zero. The empty matrix is zero by the principle of + vacuous truth. For a matrix that may or may not be zero (e.g. + contains a symbol), this will be None + + Examples + ======== + + >>> from sympy import Matrix, zeros + >>> from sympy.abc import x + >>> a = Matrix([[0, 0], [0, 0]]) + >>> b = zeros(3, 4) + >>> c = Matrix([[0, 1], [0, 0]]) + >>> d = Matrix([]) + >>> e = Matrix([[x, 0], [0, 0]]) + >>> a.is_zero_matrix + True + >>> b.is_zero_matrix + True + >>> c.is_zero_matrix + False + >>> d.is_zero_matrix + True + >>> e.is_zero_matrix + """ + return self._eval_is_zero_matrix() + + def values(self): + """Return non-zero values of self.""" + return self._eval_values() + + +class MatrixOperations(MatrixRequired): + """Provides basic matrix shape and elementwise + operations. Should not be instantiated directly.""" + + def _eval_adjoint(self): + return self.transpose().conjugate() + + def _eval_applyfunc(self, f): + out = self._new(self.rows, self.cols, [f(x) for x in self]) + return out + + def _eval_as_real_imag(self): # type: ignore + return (self.applyfunc(re), self.applyfunc(im)) + + def _eval_conjugate(self): + return self.applyfunc(lambda x: x.conjugate()) + + def _eval_permute_cols(self, perm): + # apply the permutation to a list + mapping = list(perm) + + def entry(i, j): + return self[i, mapping[j]] + + return self._new(self.rows, self.cols, entry) + + def _eval_permute_rows(self, perm): + # apply the permutation to a list + mapping = list(perm) + + def entry(i, j): + return self[mapping[i], j] + + return self._new(self.rows, self.cols, entry) + + def _eval_trace(self): + return sum(self[i, i] for i in range(self.rows)) + + def _eval_transpose(self): + return self._new(self.cols, self.rows, lambda i, j: self[j, i]) + + def adjoint(self): + """Conjugate transpose or Hermitian conjugation.""" + return self._eval_adjoint() + + def applyfunc(self, f): + """Apply a function to each element of the matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(2, 2, lambda i, j: i*2+j) + >>> m + Matrix([ + [0, 1], + [2, 3]]) + >>> m.applyfunc(lambda i: 2*i) + Matrix([ + [0, 2], + [4, 6]]) + + """ + if not callable(f): + raise TypeError("`f` must be callable.") + + return self._eval_applyfunc(f) + + def as_real_imag(self, deep=True, **hints): + """Returns a tuple containing the (real, imaginary) part of matrix.""" + # XXX: Ignoring deep and hints... + return self._eval_as_real_imag() + + def conjugate(self): + """Return the by-element conjugation. + + Examples + ======== + + >>> from sympy import SparseMatrix, I + >>> a = SparseMatrix(((1, 2 + I), (3, 4), (I, -I))) + >>> a + Matrix([ + [1, 2 + I], + [3, 4], + [I, -I]]) + >>> a.C + Matrix([ + [ 1, 2 - I], + [ 3, 4], + [-I, I]]) + + See Also + ======== + + transpose: Matrix transposition + H: Hermite conjugation + sympy.matrices.matrixbase.MatrixBase.D: Dirac conjugation + """ + return self._eval_conjugate() + + def doit(self, **hints): + return self.applyfunc(lambda x: x.doit(**hints)) + + def evalf(self, n=15, subs=None, maxn=100, chop=False, strict=False, quad=None, verbose=False): + """Apply evalf() to each element of self.""" + options = {'subs':subs, 'maxn':maxn, 'chop':chop, 'strict':strict, + 'quad':quad, 'verbose':verbose} + return self.applyfunc(lambda i: i.evalf(n, **options)) + + def expand(self, deep=True, modulus=None, power_base=True, power_exp=True, + mul=True, log=True, multinomial=True, basic=True, **hints): + """Apply core.function.expand to each entry of the matrix. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import Matrix + >>> Matrix(1, 1, [x*(x+1)]) + Matrix([[x*(x + 1)]]) + >>> _.expand() + Matrix([[x**2 + x]]) + + """ + return self.applyfunc(lambda x: x.expand( + deep, modulus, power_base, power_exp, mul, log, multinomial, basic, + **hints)) + + @property + def H(self): + """Return Hermite conjugate. + + Examples + ======== + + >>> from sympy import Matrix, I + >>> m = Matrix((0, 1 + I, 2, 3)) + >>> m + Matrix([ + [ 0], + [1 + I], + [ 2], + [ 3]]) + >>> m.H + Matrix([[0, 1 - I, 2, 3]]) + + See Also + ======== + + conjugate: By-element conjugation + sympy.matrices.matrixbase.MatrixBase.D: Dirac conjugation + """ + return self.T.C + + def permute(self, perm, orientation='rows', direction='forward'): + r"""Permute the rows or columns of a matrix by the given list of + swaps. + + Parameters + ========== + + perm : Permutation, list, or list of lists + A representation for the permutation. + + If it is ``Permutation``, it is used directly with some + resizing with respect to the matrix size. + + If it is specified as list of lists, + (e.g., ``[[0, 1], [0, 2]]``), then the permutation is formed + from applying the product of cycles. The direction how the + cyclic product is applied is described in below. + + If it is specified as a list, the list should represent + an array form of a permutation. (e.g., ``[1, 2, 0]``) which + would would form the swapping function + `0 \mapsto 1, 1 \mapsto 2, 2\mapsto 0`. + + orientation : 'rows', 'cols' + A flag to control whether to permute the rows or the columns + + direction : 'forward', 'backward' + A flag to control whether to apply the permutations from + the start of the list first, or from the back of the list + first. + + For example, if the permutation specification is + ``[[0, 1], [0, 2]]``, + + If the flag is set to ``'forward'``, the cycle would be + formed as `0 \mapsto 2, 2 \mapsto 1, 1 \mapsto 0`. + + If the flag is set to ``'backward'``, the cycle would be + formed as `0 \mapsto 1, 1 \mapsto 2, 2 \mapsto 0`. + + If the argument ``perm`` is not in a form of list of lists, + this flag takes no effect. + + Examples + ======== + + >>> from sympy import eye + >>> M = eye(3) + >>> M.permute([[0, 1], [0, 2]], orientation='rows', direction='forward') + Matrix([ + [0, 0, 1], + [1, 0, 0], + [0, 1, 0]]) + + >>> from sympy import eye + >>> M = eye(3) + >>> M.permute([[0, 1], [0, 2]], orientation='rows', direction='backward') + Matrix([ + [0, 1, 0], + [0, 0, 1], + [1, 0, 0]]) + + Notes + ===== + + If a bijective function + `\sigma : \mathbb{N}_0 \rightarrow \mathbb{N}_0` denotes the + permutation. + + If the matrix `A` is the matrix to permute, represented as + a horizontal or a vertical stack of vectors: + + .. math:: + A = + \begin{bmatrix} + a_0 \\ a_1 \\ \vdots \\ a_{n-1} + \end{bmatrix} = + \begin{bmatrix} + \alpha_0 & \alpha_1 & \cdots & \alpha_{n-1} + \end{bmatrix} + + If the matrix `B` is the result, the permutation of matrix rows + is defined as: + + .. math:: + B := \begin{bmatrix} + a_{\sigma(0)} \\ a_{\sigma(1)} \\ \vdots \\ a_{\sigma(n-1)} + \end{bmatrix} + + And the permutation of matrix columns is defined as: + + .. math:: + B := \begin{bmatrix} + \alpha_{\sigma(0)} & \alpha_{\sigma(1)} & + \cdots & \alpha_{\sigma(n-1)} + \end{bmatrix} + """ + from sympy.combinatorics import Permutation + + # allow british variants and `columns` + if direction == 'forwards': + direction = 'forward' + if direction == 'backwards': + direction = 'backward' + if orientation == 'columns': + orientation = 'cols' + + if direction not in ('forward', 'backward'): + raise TypeError("direction='{}' is an invalid kwarg. " + "Try 'forward' or 'backward'".format(direction)) + if orientation not in ('rows', 'cols'): + raise TypeError("orientation='{}' is an invalid kwarg. " + "Try 'rows' or 'cols'".format(orientation)) + + if not isinstance(perm, (Permutation, Iterable)): + raise ValueError( + "{} must be a list, a list of lists, " + "or a SymPy permutation object.".format(perm)) + + # ensure all swaps are in range + max_index = self.rows if orientation == 'rows' else self.cols + if not all(0 <= t <= max_index for t in flatten(list(perm))): + raise IndexError("`swap` indices out of range.") + + if perm and not isinstance(perm, Permutation) and \ + isinstance(perm[0], Iterable): + if direction == 'forward': + perm = list(reversed(perm)) + perm = Permutation(perm, size=max_index+1) + else: + perm = Permutation(perm, size=max_index+1) + + if orientation == 'rows': + return self._eval_permute_rows(perm) + if orientation == 'cols': + return self._eval_permute_cols(perm) + + def permute_cols(self, swaps, direction='forward'): + """Alias for + ``self.permute(swaps, orientation='cols', direction=direction)`` + + See Also + ======== + + permute + """ + return self.permute(swaps, orientation='cols', direction=direction) + + def permute_rows(self, swaps, direction='forward'): + """Alias for + ``self.permute(swaps, orientation='rows', direction=direction)`` + + See Also + ======== + + permute + """ + return self.permute(swaps, orientation='rows', direction=direction) + + def refine(self, assumptions=True): + """Apply refine to each element of the matrix. + + Examples + ======== + + >>> from sympy import Symbol, Matrix, Abs, sqrt, Q + >>> x = Symbol('x') + >>> Matrix([[Abs(x)**2, sqrt(x**2)],[sqrt(x**2), Abs(x)**2]]) + Matrix([ + [ Abs(x)**2, sqrt(x**2)], + [sqrt(x**2), Abs(x)**2]]) + >>> _.refine(Q.real(x)) + Matrix([ + [ x**2, Abs(x)], + [Abs(x), x**2]]) + + """ + return self.applyfunc(lambda x: refine(x, assumptions)) + + def replace(self, F, G, map=False, simultaneous=True, exact=None): + """Replaces Function F in Matrix entries with Function G. + + Examples + ======== + + >>> from sympy import symbols, Function, Matrix + >>> F, G = symbols('F, G', cls=Function) + >>> M = Matrix(2, 2, lambda i, j: F(i+j)) ; M + Matrix([ + [F(0), F(1)], + [F(1), F(2)]]) + >>> N = M.replace(F,G) + >>> N + Matrix([ + [G(0), G(1)], + [G(1), G(2)]]) + """ + return self.applyfunc( + lambda x: x.replace(F, G, map=map, simultaneous=simultaneous, exact=exact)) + + def rot90(self, k=1): + """Rotates Matrix by 90 degrees + + Parameters + ========== + + k : int + Specifies how many times the matrix is rotated by 90 degrees + (clockwise when positive, counter-clockwise when negative). + + Examples + ======== + + >>> from sympy import Matrix, symbols + >>> A = Matrix(2, 2, symbols('a:d')) + >>> A + Matrix([ + [a, b], + [c, d]]) + + Rotating the matrix clockwise one time: + + >>> A.rot90(1) + Matrix([ + [c, a], + [d, b]]) + + Rotating the matrix anticlockwise two times: + + >>> A.rot90(-2) + Matrix([ + [d, c], + [b, a]]) + """ + + mod = k%4 + if mod == 0: + return self + if mod == 1: + return self[::-1, ::].T + if mod == 2: + return self[::-1, ::-1] + if mod == 3: + return self[::, ::-1].T + + def simplify(self, **kwargs): + """Apply simplify to each element of the matrix. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import SparseMatrix, sin, cos + >>> SparseMatrix(1, 1, [x*sin(y)**2 + x*cos(y)**2]) + Matrix([[x*sin(y)**2 + x*cos(y)**2]]) + >>> _.simplify() + Matrix([[x]]) + """ + return self.applyfunc(lambda x: x.simplify(**kwargs)) + + def subs(self, *args, **kwargs): # should mirror core.basic.subs + """Return a new matrix with subs applied to each entry. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import SparseMatrix, Matrix + >>> SparseMatrix(1, 1, [x]) + Matrix([[x]]) + >>> _.subs(x, y) + Matrix([[y]]) + >>> Matrix(_).subs(y, x) + Matrix([[x]]) + """ + + if len(args) == 1 and not isinstance(args[0], (dict, set)) and iter(args[0]) and not is_sequence(args[0]): + args = (list(args[0]),) + + return self.applyfunc(lambda x: x.subs(*args, **kwargs)) + + def trace(self): + """ + Returns the trace of a square matrix i.e. the sum of the + diagonal elements. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix(2, 2, [1, 2, 3, 4]) + >>> A.trace() + 5 + + """ + if self.rows != self.cols: + raise NonSquareMatrixError() + return self._eval_trace() + + def transpose(self): + """ + Returns the transpose of the matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix(2, 2, [1, 2, 3, 4]) + >>> A.transpose() + Matrix([ + [1, 3], + [2, 4]]) + + >>> from sympy import Matrix, I + >>> m=Matrix(((1, 2+I), (3, 4))) + >>> m + Matrix([ + [1, 2 + I], + [3, 4]]) + >>> m.transpose() + Matrix([ + [ 1, 3], + [2 + I, 4]]) + >>> m.T == m.transpose() + True + + See Also + ======== + + conjugate: By-element conjugation + + """ + return self._eval_transpose() + + @property + def T(self): + '''Matrix transposition''' + return self.transpose() + + @property + def C(self): + '''By-element conjugation''' + return self.conjugate() + + def n(self, *args, **kwargs): + """Apply evalf() to each element of self.""" + return self.evalf(*args, **kwargs) + + def xreplace(self, rule): # should mirror core.basic.xreplace + """Return a new matrix with xreplace applied to each entry. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import SparseMatrix, Matrix + >>> SparseMatrix(1, 1, [x]) + Matrix([[x]]) + >>> _.xreplace({x: y}) + Matrix([[y]]) + >>> Matrix(_).xreplace({y: x}) + Matrix([[x]]) + """ + return self.applyfunc(lambda x: x.xreplace(rule)) + + def _eval_simplify(self, **kwargs): + # XXX: We can't use self.simplify here as mutable subclasses will + # override simplify and have it return None + return MatrixOperations.simplify(self, **kwargs) + + def _eval_trigsimp(self, **opts): + from sympy.simplify.trigsimp import trigsimp + return self.applyfunc(lambda x: trigsimp(x, **opts)) + + def upper_triangular(self, k=0): + """Return the elements on and above the kth diagonal of a matrix. + If k is not specified then simply returns upper-triangular portion + of a matrix + + Examples + ======== + + >>> from sympy import ones + >>> A = ones(4) + >>> A.upper_triangular() + Matrix([ + [1, 1, 1, 1], + [0, 1, 1, 1], + [0, 0, 1, 1], + [0, 0, 0, 1]]) + + >>> A.upper_triangular(2) + Matrix([ + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0], + [0, 0, 0, 0]]) + + >>> A.upper_triangular(-1) + Matrix([ + [1, 1, 1, 1], + [1, 1, 1, 1], + [0, 1, 1, 1], + [0, 0, 1, 1]]) + + """ + + def entry(i, j): + return self[i, j] if i + k <= j else self.zero + + return self._new(self.rows, self.cols, entry) + + + def lower_triangular(self, k=0): + """Return the elements on and below the kth diagonal of a matrix. + If k is not specified then simply returns lower-triangular portion + of a matrix + + Examples + ======== + + >>> from sympy import ones + >>> A = ones(4) + >>> A.lower_triangular() + Matrix([ + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1]]) + + >>> A.lower_triangular(-2) + Matrix([ + [0, 0, 0, 0], + [0, 0, 0, 0], + [1, 0, 0, 0], + [1, 1, 0, 0]]) + + >>> A.lower_triangular(1) + Matrix([ + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1], + [1, 1, 1, 1]]) + + """ + + def entry(i, j): + return self[i, j] if i + k >= j else self.zero + + return self._new(self.rows, self.cols, entry) + + + +class MatrixArithmetic(MatrixRequired): + """Provides basic matrix arithmetic operations. + Should not be instantiated directly.""" + + _op_priority = 10.01 + + def _eval_Abs(self): + return self._new(self.rows, self.cols, lambda i, j: Abs(self[i, j])) + + def _eval_add(self, other): + return self._new(self.rows, self.cols, + lambda i, j: self[i, j] + other[i, j]) + + def _eval_matrix_mul(self, other): + def entry(i, j): + vec = [self[i,k]*other[k,j] for k in range(self.cols)] + try: + return Add(*vec) + except (TypeError, SympifyError): + # Some matrices don't work with `sum` or `Add` + # They don't work with `sum` because `sum` tries to add `0` + # Fall back to a safe way to multiply if the `Add` fails. + return reduce(lambda a, b: a + b, vec) + + return self._new(self.rows, other.cols, entry) + + def _eval_matrix_mul_elementwise(self, other): + return self._new(self.rows, self.cols, lambda i, j: self[i,j]*other[i,j]) + + def _eval_matrix_rmul(self, other): + def entry(i, j): + return sum(other[i,k]*self[k,j] for k in range(other.cols)) + return self._new(other.rows, self.cols, entry) + + def _eval_pow_by_recursion(self, num): + if num == 1: + return self + + if num % 2 == 1: + a, b = self, self._eval_pow_by_recursion(num - 1) + else: + a = b = self._eval_pow_by_recursion(num // 2) + + return a.multiply(b) + + def _eval_pow_by_cayley(self, exp): + from sympy.discrete.recurrences import linrec_coeffs + row = self.shape[0] + p = self.charpoly() + + coeffs = (-p).all_coeffs()[1:] + coeffs = linrec_coeffs(coeffs, exp) + new_mat = self.eye(row) + ans = self.zeros(row) + + for i in range(row): + ans += coeffs[i]*new_mat + new_mat *= self + + return ans + + def _eval_pow_by_recursion_dotprodsimp(self, num, prevsimp=None): + if prevsimp is None: + prevsimp = [True]*len(self) + + if num == 1: + return self + + if num % 2 == 1: + a, b = self, self._eval_pow_by_recursion_dotprodsimp(num - 1, + prevsimp=prevsimp) + else: + a = b = self._eval_pow_by_recursion_dotprodsimp(num // 2, + prevsimp=prevsimp) + + m = a.multiply(b, dotprodsimp=False) + lenm = len(m) + elems = [None]*lenm + + for i in range(lenm): + if prevsimp[i]: + elems[i], prevsimp[i] = _dotprodsimp(m[i], withsimp=True) + else: + elems[i] = m[i] + + return m._new(m.rows, m.cols, elems) + + def _eval_scalar_mul(self, other): + return self._new(self.rows, self.cols, lambda i, j: self[i,j]*other) + + def _eval_scalar_rmul(self, other): + return self._new(self.rows, self.cols, lambda i, j: other*self[i,j]) + + def _eval_Mod(self, other): + return self._new(self.rows, self.cols, lambda i, j: Mod(self[i, j], other)) + + # Python arithmetic functions + def __abs__(self): + """Returns a new matrix with entry-wise absolute values.""" + return self._eval_Abs() + + @call_highest_priority('__radd__') + def __add__(self, other): + """Return self + other, raising ShapeError if shapes do not match.""" + if isinstance(other, NDimArray): # Matrix and array addition is currently not implemented + return NotImplemented + other = _matrixify(other) + # matrix-like objects can have shapes. This is + # our first sanity check. + if hasattr(other, 'shape'): + if self.shape != other.shape: + raise ShapeError("Matrix size mismatch: %s + %s" % ( + self.shape, other.shape)) + + # honest SymPy matrices defer to their class's routine + if getattr(other, 'is_Matrix', False): + # call the highest-priority class's _eval_add + a, b = self, other + if a.__class__ != classof(a, b): + b, a = a, b + return a._eval_add(b) + # Matrix-like objects can be passed to CommonMatrix routines directly. + if getattr(other, 'is_MatrixLike', False): + return MatrixArithmetic._eval_add(self, other) + + raise TypeError('cannot add %s and %s' % (type(self), type(other))) + + @call_highest_priority('__rtruediv__') + def __truediv__(self, other): + return self * (self.one / other) + + @call_highest_priority('__rmatmul__') + def __matmul__(self, other): + other = _matrixify(other) + if not getattr(other, 'is_Matrix', False) and not getattr(other, 'is_MatrixLike', False): + return NotImplemented + + return self.__mul__(other) + + def __mod__(self, other): + return self.applyfunc(lambda x: x % other) + + @call_highest_priority('__rmul__') + def __mul__(self, other): + """Return self*other where other is either a scalar or a matrix + of compatible dimensions. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[1, 2, 3], [4, 5, 6]]) + >>> 2*A == A*2 == Matrix([[2, 4, 6], [8, 10, 12]]) + True + >>> B = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> A*B + Matrix([ + [30, 36, 42], + [66, 81, 96]]) + >>> B*A + Traceback (most recent call last): + ... + ShapeError: Matrices size mismatch. + >>> + + See Also + ======== + + matrix_multiply_elementwise + """ + + return self.multiply(other) + + def multiply(self, other, dotprodsimp=None): + """Same as __mul__() but with optional simplification. + + Parameters + ========== + + dotprodsimp : bool, optional + Specifies whether intermediate term algebraic simplification is used + during matrix multiplications to control expression blowup and thus + speed up calculation. Default is off. + """ + + isimpbool = _get_intermediate_simp_bool(False, dotprodsimp) + other = _matrixify(other) + # matrix-like objects can have shapes. This is + # our first sanity check. Double check other is not explicitly not a Matrix. + if (hasattr(other, 'shape') and len(other.shape) == 2 and + (getattr(other, 'is_Matrix', True) or + getattr(other, 'is_MatrixLike', True))): + if self.shape[1] != other.shape[0]: + raise ShapeError("Matrix size mismatch: %s * %s." % ( + self.shape, other.shape)) + + # honest SymPy matrices defer to their class's routine + if getattr(other, 'is_Matrix', False): + m = self._eval_matrix_mul(other) + if isimpbool: + return m._new(m.rows, m.cols, [_dotprodsimp(e) for e in m]) + return m + + # Matrix-like objects can be passed to CommonMatrix routines directly. + if getattr(other, 'is_MatrixLike', False): + return MatrixArithmetic._eval_matrix_mul(self, other) + + # if 'other' is not iterable then scalar multiplication. + if not isinstance(other, Iterable): + try: + return self._eval_scalar_mul(other) + except TypeError: + pass + + return NotImplemented + + def multiply_elementwise(self, other): + """Return the Hadamard product (elementwise product) of A and B + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[0, 1, 2], [3, 4, 5]]) + >>> B = Matrix([[1, 10, 100], [100, 10, 1]]) + >>> A.multiply_elementwise(B) + Matrix([ + [ 0, 10, 200], + [300, 40, 5]]) + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.cross + sympy.matrices.matrixbase.MatrixBase.dot + multiply + """ + if self.shape != other.shape: + raise ShapeError("Matrix shapes must agree {} != {}".format(self.shape, other.shape)) + + return self._eval_matrix_mul_elementwise(other) + + def __neg__(self): + return self._eval_scalar_mul(-1) + + @call_highest_priority('__rpow__') + def __pow__(self, exp): + """Return self**exp a scalar or symbol.""" + + return self.pow(exp) + + + def pow(self, exp, method=None): + r"""Return self**exp a scalar or symbol. + + Parameters + ========== + + method : multiply, mulsimp, jordan, cayley + If multiply then it returns exponentiation using recursion. + If jordan then Jordan form exponentiation will be used. + If cayley then the exponentiation is done using Cayley-Hamilton + theorem. + If mulsimp then the exponentiation is done using recursion + with dotprodsimp. This specifies whether intermediate term + algebraic simplification is used during naive matrix power to + control expression blowup and thus speed up calculation. + If None, then it heuristically decides which method to use. + + """ + + if method is not None and method not in ['multiply', 'mulsimp', 'jordan', 'cayley']: + raise TypeError('No such method') + if self.rows != self.cols: + raise NonSquareMatrixError() + a = self + jordan_pow = getattr(a, '_matrix_pow_by_jordan_blocks', None) + exp = sympify(exp) + + if exp.is_zero: + return a._new(a.rows, a.cols, lambda i, j: int(i == j)) + if exp == 1: + return a + + diagonal = getattr(a, 'is_diagonal', None) + if diagonal is not None and diagonal(): + return a._new(a.rows, a.cols, lambda i, j: a[i,j]**exp if i == j else 0) + + if exp.is_Number and exp % 1 == 0: + if a.rows == 1: + return a._new([[a[0]**exp]]) + if exp < 0: + exp = -exp + a = a.inv() + # When certain conditions are met, + # Jordan block algorithm is faster than + # computation by recursion. + if method == 'jordan': + try: + return jordan_pow(exp) + except MatrixError: + if method == 'jordan': + raise + + elif method == 'cayley': + if not exp.is_Number or exp % 1 != 0: + raise ValueError("cayley method is only valid for integer powers") + return a._eval_pow_by_cayley(exp) + + elif method == "mulsimp": + if not exp.is_Number or exp % 1 != 0: + raise ValueError("mulsimp method is only valid for integer powers") + return a._eval_pow_by_recursion_dotprodsimp(exp) + + elif method == "multiply": + if not exp.is_Number or exp % 1 != 0: + raise ValueError("multiply method is only valid for integer powers") + return a._eval_pow_by_recursion(exp) + + elif method is None and exp.is_Number and exp % 1 == 0: + if exp.is_Float: + exp = Integer(exp) + # Decide heuristically which method to apply + if a.rows == 2 and exp > 100000: + return jordan_pow(exp) + elif _get_intermediate_simp_bool(True, None): + return a._eval_pow_by_recursion_dotprodsimp(exp) + elif exp > 10000: + return a._eval_pow_by_cayley(exp) + else: + return a._eval_pow_by_recursion(exp) + + if jordan_pow: + try: + return jordan_pow(exp) + except NonInvertibleMatrixError: + # Raised by jordan_pow on zero determinant matrix unless exp is + # definitely known to be a non-negative integer. + # Here we raise if n is definitely not a non-negative integer + # but otherwise we can leave this as an unevaluated MatPow. + if exp.is_integer is False or exp.is_nonnegative is False: + raise + + from sympy.matrices.expressions import MatPow + return MatPow(a, exp) + + @call_highest_priority('__add__') + def __radd__(self, other): + return self + other + + @call_highest_priority('__matmul__') + def __rmatmul__(self, other): + other = _matrixify(other) + if not getattr(other, 'is_Matrix', False) and not getattr(other, 'is_MatrixLike', False): + return NotImplemented + + return self.__rmul__(other) + + @call_highest_priority('__mul__') + def __rmul__(self, other): + return self.rmultiply(other) + + def rmultiply(self, other, dotprodsimp=None): + """Same as __rmul__() but with optional simplification. + + Parameters + ========== + + dotprodsimp : bool, optional + Specifies whether intermediate term algebraic simplification is used + during matrix multiplications to control expression blowup and thus + speed up calculation. Default is off. + """ + isimpbool = _get_intermediate_simp_bool(False, dotprodsimp) + other = _matrixify(other) + # matrix-like objects can have shapes. This is + # our first sanity check. Double check other is not explicitly not a Matrix. + if (hasattr(other, 'shape') and len(other.shape) == 2 and + (getattr(other, 'is_Matrix', True) or + getattr(other, 'is_MatrixLike', True))): + if self.shape[0] != other.shape[1]: + raise ShapeError("Matrix size mismatch.") + + # honest SymPy matrices defer to their class's routine + if getattr(other, 'is_Matrix', False): + m = self._eval_matrix_rmul(other) + if isimpbool: + return m._new(m.rows, m.cols, [_dotprodsimp(e) for e in m]) + return m + # Matrix-like objects can be passed to CommonMatrix routines directly. + if getattr(other, 'is_MatrixLike', False): + return MatrixArithmetic._eval_matrix_rmul(self, other) + + # if 'other' is not iterable then scalar multiplication. + if not isinstance(other, Iterable): + try: + return self._eval_scalar_rmul(other) + except TypeError: + pass + + return NotImplemented + + @call_highest_priority('__sub__') + def __rsub__(self, a): + return (-self) + a + + @call_highest_priority('__rsub__') + def __sub__(self, a): + return self + (-a) + + +class MatrixCommon(MatrixArithmetic, MatrixOperations, MatrixProperties, + MatrixSpecial, MatrixShaping): + """All common matrix operations including basic arithmetic, shaping, + and special matrices like `zeros`, and `eye`.""" + _diff_wrt: bool = True + + +class _MinimalMatrix: + """Class providing the minimum functionality + for a matrix-like object and implementing every method + required for a `MatrixRequired`. This class does not have everything + needed to become a full-fledged SymPy object, but it will satisfy the + requirements of anything inheriting from `MatrixRequired`. If you wish + to make a specialized matrix type, make sure to implement these + methods and properties with the exception of `__init__` and `__repr__` + which are included for convenience.""" + + is_MatrixLike = True + _sympify = staticmethod(sympify) + _class_priority = 3 + zero = S.Zero + one = S.One + + is_Matrix = True + is_MatrixExpr = False + + @classmethod + def _new(cls, *args, **kwargs): + return cls(*args, **kwargs) + + def __init__(self, rows, cols=None, mat=None, copy=False): + if isfunction(mat): + # if we passed in a function, use that to populate the indices + mat = [mat(i, j) for i in range(rows) for j in range(cols)] + if cols is None and mat is None: + mat = rows + rows, cols = getattr(mat, 'shape', (rows, cols)) + try: + # if we passed in a list of lists, flatten it and set the size + if cols is None and mat is None: + mat = rows + cols = len(mat[0]) + rows = len(mat) + mat = [x for l in mat for x in l] + except (IndexError, TypeError): + pass + self.mat = tuple(self._sympify(x) for x in mat) + self.rows, self.cols = rows, cols + if self.rows is None or self.cols is None: + raise NotImplementedError("Cannot initialize matrix with given parameters") + + def __getitem__(self, key): + def _normalize_slices(row_slice, col_slice): + """Ensure that row_slice and col_slice do not have + `None` in their arguments. Any integers are converted + to slices of length 1""" + if not isinstance(row_slice, slice): + row_slice = slice(row_slice, row_slice + 1, None) + row_slice = slice(*row_slice.indices(self.rows)) + + if not isinstance(col_slice, slice): + col_slice = slice(col_slice, col_slice + 1, None) + col_slice = slice(*col_slice.indices(self.cols)) + + return (row_slice, col_slice) + + def _coord_to_index(i, j): + """Return the index in _mat corresponding + to the (i,j) position in the matrix. """ + return i * self.cols + j + + if isinstance(key, tuple): + i, j = key + if isinstance(i, slice) or isinstance(j, slice): + # if the coordinates are not slices, make them so + # and expand the slices so they don't contain `None` + i, j = _normalize_slices(i, j) + + rowsList, colsList = list(range(self.rows))[i], \ + list(range(self.cols))[j] + indices = (i * self.cols + j for i in rowsList for j in + colsList) + return self._new(len(rowsList), len(colsList), + [self.mat[i] for i in indices]) + + # if the key is a tuple of ints, change + # it to an array index + key = _coord_to_index(i, j) + return self.mat[key] + + def __eq__(self, other): + try: + classof(self, other) + except TypeError: + return False + return ( + self.shape == other.shape and list(self) == list(other)) + + def __len__(self): + return self.rows*self.cols + + def __repr__(self): + return "_MinimalMatrix({}, {}, {})".format(self.rows, self.cols, + self.mat) + + @property + def shape(self): + return (self.rows, self.cols) + + +class _CastableMatrix: # this is needed here ONLY FOR TESTS. + def as_mutable(self): + return self + + def as_immutable(self): + return self + + +class _MatrixWrapper: + """Wrapper class providing the minimum functionality for a matrix-like + object: .rows, .cols, .shape, indexability, and iterability. CommonMatrix + math operations should work on matrix-like objects. This one is intended for + matrix-like objects which use the same indexing format as SymPy with respect + to returning matrix elements instead of rows for non-tuple indexes. + """ + + is_Matrix = False # needs to be here because of __getattr__ + is_MatrixLike = True + + def __init__(self, mat, shape): + self.mat = mat + self.shape = shape + self.rows, self.cols = shape + + def __getitem__(self, key): + if isinstance(key, tuple): + return sympify(self.mat.__getitem__(key)) + + return sympify(self.mat.__getitem__((key // self.rows, key % self.cols))) + + def __iter__(self): # supports numpy.matrix and numpy.array + mat = self.mat + cols = self.cols + + return iter(sympify(mat[r, c]) for r in range(self.rows) for c in range(cols)) + + +def _matrixify(mat): + """If `mat` is a Matrix or is matrix-like, + return a Matrix or MatrixWrapper object. Otherwise + `mat` is passed through without modification.""" + + if getattr(mat, 'is_Matrix', False) or getattr(mat, 'is_MatrixLike', False): + return mat + + if not(getattr(mat, 'is_Matrix', True) or getattr(mat, 'is_MatrixLike', True)): + return mat + + shape = None + + if hasattr(mat, 'shape'): # numpy, scipy.sparse + if len(mat.shape) == 2: + shape = mat.shape + elif hasattr(mat, 'rows') and hasattr(mat, 'cols'): # mpmath + shape = (mat.rows, mat.cols) + + if shape: + return _MatrixWrapper(mat, shape) + + return mat + + +def a2idx(j, n=None): + """Return integer after making positive and validating against n.""" + if not isinstance(j, int): + jindex = getattr(j, '__index__', None) + if jindex is not None: + j = jindex() + else: + raise IndexError("Invalid index a[%r]" % (j,)) + if n is not None: + if j < 0: + j += n + if not (j >= 0 and j < n): + raise IndexError("Index out of range: a[%s]" % (j,)) + return int(j) + + +def classof(A, B): + """ + Get the type of the result when combining matrices of different types. + + Currently the strategy is that immutability is contagious. + + Examples + ======== + + >>> from sympy import Matrix, ImmutableMatrix + >>> from sympy.matrices.matrixbase import classof + >>> M = Matrix([[1, 2], [3, 4]]) # a Mutable Matrix + >>> IM = ImmutableMatrix([[1, 2], [3, 4]]) + >>> classof(M, IM) + + """ + priority_A = getattr(A, '_class_priority', None) + priority_B = getattr(B, '_class_priority', None) + if None not in (priority_A, priority_B): + if A._class_priority > B._class_priority: + return A.__class__ + else: + return B.__class__ + + try: + import numpy + except ImportError: + pass + else: + if isinstance(A, numpy.ndarray): + return B.__class__ + if isinstance(B, numpy.ndarray): + return A.__class__ + + raise TypeError("Incompatible classes %s, %s" % (A.__class__, B.__class__)) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/decompositions.py b/.venv/lib/python3.13/site-packages/sympy/matrices/decompositions.py new file mode 100644 index 0000000000000000000000000000000000000000..a8dd466d84c957b870396a050fd25ec21e7113a3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/decompositions.py @@ -0,0 +1,1621 @@ +import copy + +from sympy.core import S +from sympy.core.function import expand_mul +from sympy.functions.elementary.miscellaneous import Min, sqrt +from sympy.functions.elementary.complexes import sign + +from .exceptions import NonSquareMatrixError, NonPositiveDefiniteMatrixError +from .utilities import _get_intermediate_simp, _iszero +from .determinant import _find_reasonable_pivot_naive + + +def _rank_decomposition(M, iszerofunc=_iszero, simplify=False): + r"""Returns a pair of matrices (`C`, `F`) with matching rank + such that `A = C F`. + + Parameters + ========== + + iszerofunc : Function, optional + A function used for detecting whether an element can + act as a pivot. ``lambda x: x.is_zero`` is used by default. + + simplify : Bool or Function, optional + A function used to simplify elements when looking for a + pivot. By default SymPy's ``simplify`` is used. + + Returns + ======= + + (C, F) : Matrices + `C` and `F` are full-rank matrices with rank as same as `A`, + whose product gives `A`. + + See Notes for additional mathematical details. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([ + ... [1, 3, 1, 4], + ... [2, 7, 3, 9], + ... [1, 5, 3, 1], + ... [1, 2, 0, 8] + ... ]) + >>> C, F = A.rank_decomposition() + >>> C + Matrix([ + [1, 3, 4], + [2, 7, 9], + [1, 5, 1], + [1, 2, 8]]) + >>> F + Matrix([ + [1, 0, -2, 0], + [0, 1, 1, 0], + [0, 0, 0, 1]]) + >>> C * F == A + True + + Notes + ===== + + Obtaining `F`, an RREF of `A`, is equivalent to creating a + product + + .. math:: + E_n E_{n-1} ... E_1 A = F + + where `E_n, E_{n-1}, \dots, E_1` are the elimination matrices or + permutation matrices equivalent to each row-reduction step. + + The inverse of the same product of elimination matrices gives + `C`: + + .. math:: + C = \left(E_n E_{n-1} \dots E_1\right)^{-1} + + It is not necessary, however, to actually compute the inverse: + the columns of `C` are those from the original matrix with the + same column indices as the indices of the pivot columns of `F`. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Rank_factorization + + .. [2] Piziak, R.; Odell, P. L. (1 June 1999). + "Full Rank Factorization of Matrices". + Mathematics Magazine. 72 (3): 193. doi:10.2307/2690882 + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.rref + """ + + F, pivot_cols = M.rref(simplify=simplify, iszerofunc=iszerofunc, + pivots=True) + rank = len(pivot_cols) + + C = M.extract(range(M.rows), pivot_cols) + F = F[:rank, :] + + return C, F + + +def _liupc(M): + """Liu's algorithm, for pre-determination of the Elimination Tree of + the given matrix, used in row-based symbolic Cholesky factorization. + + Examples + ======== + + >>> from sympy import SparseMatrix + >>> S = SparseMatrix([ + ... [1, 0, 3, 2], + ... [0, 0, 1, 0], + ... [4, 0, 0, 5], + ... [0, 6, 7, 0]]) + >>> S.liupc() + ([[0], [], [0], [1, 2]], [4, 3, 4, 4]) + + References + ========== + + .. [1] Symbolic Sparse Cholesky Factorization using Elimination Trees, + Jeroen Van Grondelle (1999) + https://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.39.7582 + """ + # Algorithm 2.4, p 17 of reference + + # get the indices of the elements that are non-zero on or below diag + R = [[] for r in range(M.rows)] + + for r, c, _ in M.row_list(): + if c <= r: + R[r].append(c) + + inf = len(R) # nothing will be this large + parent = [inf]*M.rows + virtual = [inf]*M.rows + + for r in range(M.rows): + for c in R[r][:-1]: + while virtual[c] < r: + t = virtual[c] + virtual[c] = r + c = t + + if virtual[c] == inf: + parent[c] = virtual[c] = r + + return R, parent + +def _row_structure_symbolic_cholesky(M): + """Symbolic cholesky factorization, for pre-determination of the + non-zero structure of the Cholesky factororization. + + Examples + ======== + + >>> from sympy import SparseMatrix + >>> S = SparseMatrix([ + ... [1, 0, 3, 2], + ... [0, 0, 1, 0], + ... [4, 0, 0, 5], + ... [0, 6, 7, 0]]) + >>> S.row_structure_symbolic_cholesky() + [[0], [], [0], [1, 2]] + + References + ========== + + .. [1] Symbolic Sparse Cholesky Factorization using Elimination Trees, + Jeroen Van Grondelle (1999) + https://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.39.7582 + """ + + R, parent = M.liupc() + inf = len(R) # this acts as infinity + Lrow = copy.deepcopy(R) + + for k in range(M.rows): + for j in R[k]: + while j != inf and j != k: + Lrow[k].append(j) + j = parent[j] + + Lrow[k] = sorted(set(Lrow[k])) + + return Lrow + + +def _cholesky(M, hermitian=True): + """Returns the Cholesky-type decomposition L of a matrix A + such that L * L.H == A if hermitian flag is True, + or L * L.T == A if hermitian is False. + + A must be a Hermitian positive-definite matrix if hermitian is True, + or a symmetric matrix if it is False. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11))) + >>> A.cholesky() + Matrix([ + [ 5, 0, 0], + [ 3, 3, 0], + [-1, 1, 3]]) + >>> A.cholesky() * A.cholesky().T + Matrix([ + [25, 15, -5], + [15, 18, 0], + [-5, 0, 11]]) + + The matrix can have complex entries: + + >>> from sympy import I + >>> A = Matrix(((9, 3*I), (-3*I, 5))) + >>> A.cholesky() + Matrix([ + [ 3, 0], + [-I, 2]]) + >>> A.cholesky() * A.cholesky().H + Matrix([ + [ 9, 3*I], + [-3*I, 5]]) + + Non-hermitian Cholesky-type decomposition may be useful when the + matrix is not positive-definite. + + >>> A = Matrix([[1, 2], [2, 1]]) + >>> L = A.cholesky(hermitian=False) + >>> L + Matrix([ + [1, 0], + [2, sqrt(3)*I]]) + >>> L*L.T == A + True + + See Also + ======== + + sympy.matrices.dense.DenseMatrix.LDLdecomposition + sympy.matrices.matrixbase.MatrixBase.LUdecomposition + QRdecomposition + """ + + from .dense import MutableDenseMatrix + + if not M.is_square: + raise NonSquareMatrixError("Matrix must be square.") + if hermitian and not M.is_hermitian: + raise ValueError("Matrix must be Hermitian.") + if not hermitian and not M.is_symmetric(): + raise ValueError("Matrix must be symmetric.") + + L = MutableDenseMatrix.zeros(M.rows, M.rows) + + if hermitian: + for i in range(M.rows): + for j in range(i): + L[i, j] = ((1 / L[j, j])*(M[i, j] - + sum(L[i, k]*L[j, k].conjugate() for k in range(j)))) + + Lii2 = (M[i, i] - + sum(L[i, k]*L[i, k].conjugate() for k in range(i))) + + if Lii2.is_positive is False: + raise NonPositiveDefiniteMatrixError( + "Matrix must be positive-definite") + + L[i, i] = sqrt(Lii2) + + else: + for i in range(M.rows): + for j in range(i): + L[i, j] = ((1 / L[j, j])*(M[i, j] - + sum(L[i, k]*L[j, k] for k in range(j)))) + + L[i, i] = sqrt(M[i, i] - + sum(L[i, k]**2 for k in range(i))) + + return M._new(L) + +def _cholesky_sparse(M, hermitian=True): + """ + Returns the Cholesky decomposition L of a matrix A + such that L * L.T = A + + A must be a square, symmetric, positive-definite + and non-singular matrix + + Examples + ======== + + >>> from sympy import SparseMatrix + >>> A = SparseMatrix(((25,15,-5),(15,18,0),(-5,0,11))) + >>> A.cholesky() + Matrix([ + [ 5, 0, 0], + [ 3, 3, 0], + [-1, 1, 3]]) + >>> A.cholesky() * A.cholesky().T == A + True + + The matrix can have complex entries: + + >>> from sympy import I + >>> A = SparseMatrix(((9, 3*I), (-3*I, 5))) + >>> A.cholesky() + Matrix([ + [ 3, 0], + [-I, 2]]) + >>> A.cholesky() * A.cholesky().H + Matrix([ + [ 9, 3*I], + [-3*I, 5]]) + + Non-hermitian Cholesky-type decomposition may be useful when the + matrix is not positive-definite. + + >>> A = SparseMatrix([[1, 2], [2, 1]]) + >>> L = A.cholesky(hermitian=False) + >>> L + Matrix([ + [1, 0], + [2, sqrt(3)*I]]) + >>> L*L.T == A + True + + See Also + ======== + + sympy.matrices.sparse.SparseMatrix.LDLdecomposition + sympy.matrices.matrixbase.MatrixBase.LUdecomposition + QRdecomposition + """ + + from .dense import MutableDenseMatrix + + if not M.is_square: + raise NonSquareMatrixError("Matrix must be square.") + if hermitian and not M.is_hermitian: + raise ValueError("Matrix must be Hermitian.") + if not hermitian and not M.is_symmetric(): + raise ValueError("Matrix must be symmetric.") + + dps = _get_intermediate_simp(expand_mul, expand_mul) + Crowstruc = M.row_structure_symbolic_cholesky() + C = MutableDenseMatrix.zeros(M.rows) + + for i in range(len(Crowstruc)): + for j in Crowstruc[i]: + if i != j: + C[i, j] = M[i, j] + summ = 0 + + for p1 in Crowstruc[i]: + if p1 < j: + for p2 in Crowstruc[j]: + if p2 < j: + if p1 == p2: + if hermitian: + summ += C[i, p1]*C[j, p1].conjugate() + else: + summ += C[i, p1]*C[j, p1] + else: + break + else: + break + + C[i, j] = dps((C[i, j] - summ) / C[j, j]) + + else: # i == j + C[j, j] = M[j, j] + summ = 0 + + for k in Crowstruc[j]: + if k < j: + if hermitian: + summ += C[j, k]*C[j, k].conjugate() + else: + summ += C[j, k]**2 + else: + break + + Cjj2 = dps(C[j, j] - summ) + + if hermitian and Cjj2.is_positive is False: + raise NonPositiveDefiniteMatrixError( + "Matrix must be positive-definite") + + C[j, j] = sqrt(Cjj2) + + return M._new(C) + + +def _LDLdecomposition(M, hermitian=True): + """Returns the LDL Decomposition (L, D) of matrix A, + such that L * D * L.H == A if hermitian flag is True, or + L * D * L.T == A if hermitian is False. + This method eliminates the use of square root. + Further this ensures that all the diagonal entries of L are 1. + A must be a Hermitian positive-definite matrix if hermitian is True, + or a symmetric matrix otherwise. + + Examples + ======== + + >>> from sympy import Matrix, eye + >>> A = Matrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11))) + >>> L, D = A.LDLdecomposition() + >>> L + Matrix([ + [ 1, 0, 0], + [ 3/5, 1, 0], + [-1/5, 1/3, 1]]) + >>> D + Matrix([ + [25, 0, 0], + [ 0, 9, 0], + [ 0, 0, 9]]) + >>> L * D * L.T * A.inv() == eye(A.rows) + True + + The matrix can have complex entries: + + >>> from sympy import I + >>> A = Matrix(((9, 3*I), (-3*I, 5))) + >>> L, D = A.LDLdecomposition() + >>> L + Matrix([ + [ 1, 0], + [-I/3, 1]]) + >>> D + Matrix([ + [9, 0], + [0, 4]]) + >>> L*D*L.H == A + True + + See Also + ======== + + sympy.matrices.dense.DenseMatrix.cholesky + sympy.matrices.matrixbase.MatrixBase.LUdecomposition + QRdecomposition + """ + + from .dense import MutableDenseMatrix + + if not M.is_square: + raise NonSquareMatrixError("Matrix must be square.") + if hermitian and not M.is_hermitian: + raise ValueError("Matrix must be Hermitian.") + if not hermitian and not M.is_symmetric(): + raise ValueError("Matrix must be symmetric.") + + D = MutableDenseMatrix.zeros(M.rows, M.rows) + L = MutableDenseMatrix.eye(M.rows) + + if hermitian: + for i in range(M.rows): + for j in range(i): + L[i, j] = (1 / D[j, j])*(M[i, j] - sum( + L[i, k]*L[j, k].conjugate()*D[k, k] for k in range(j))) + + D[i, i] = (M[i, i] - + sum(L[i, k]*L[i, k].conjugate()*D[k, k] for k in range(i))) + + if D[i, i].is_positive is False: + raise NonPositiveDefiniteMatrixError( + "Matrix must be positive-definite") + + else: + for i in range(M.rows): + for j in range(i): + L[i, j] = (1 / D[j, j])*(M[i, j] - sum( + L[i, k]*L[j, k]*D[k, k] for k in range(j))) + + D[i, i] = M[i, i] - sum(L[i, k]**2*D[k, k] for k in range(i)) + + return M._new(L), M._new(D) + +def _LDLdecomposition_sparse(M, hermitian=True): + """ + Returns the LDL Decomposition (matrices ``L`` and ``D``) of matrix + ``A``, such that ``L * D * L.T == A``. ``A`` must be a square, + symmetric, positive-definite and non-singular. + + This method eliminates the use of square root and ensures that all + the diagonal entries of L are 1. + + Examples + ======== + + >>> from sympy import SparseMatrix + >>> A = SparseMatrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11))) + >>> L, D = A.LDLdecomposition() + >>> L + Matrix([ + [ 1, 0, 0], + [ 3/5, 1, 0], + [-1/5, 1/3, 1]]) + >>> D + Matrix([ + [25, 0, 0], + [ 0, 9, 0], + [ 0, 0, 9]]) + >>> L * D * L.T == A + True + + """ + + from .dense import MutableDenseMatrix + + if not M.is_square: + raise NonSquareMatrixError("Matrix must be square.") + if hermitian and not M.is_hermitian: + raise ValueError("Matrix must be Hermitian.") + if not hermitian and not M.is_symmetric(): + raise ValueError("Matrix must be symmetric.") + + dps = _get_intermediate_simp(expand_mul, expand_mul) + Lrowstruc = M.row_structure_symbolic_cholesky() + L = MutableDenseMatrix.eye(M.rows) + D = MutableDenseMatrix.zeros(M.rows, M.cols) + + for i in range(len(Lrowstruc)): + for j in Lrowstruc[i]: + if i != j: + L[i, j] = M[i, j] + summ = 0 + + for p1 in Lrowstruc[i]: + if p1 < j: + for p2 in Lrowstruc[j]: + if p2 < j: + if p1 == p2: + if hermitian: + summ += L[i, p1]*L[j, p1].conjugate()*D[p1, p1] + else: + summ += L[i, p1]*L[j, p1]*D[p1, p1] + else: + break + else: + break + + L[i, j] = dps((L[i, j] - summ) / D[j, j]) + + else: # i == j + D[i, i] = M[i, i] + summ = 0 + + for k in Lrowstruc[i]: + if k < i: + if hermitian: + summ += L[i, k]*L[i, k].conjugate()*D[k, k] + else: + summ += L[i, k]**2*D[k, k] + else: + break + + D[i, i] = dps(D[i, i] - summ) + + if hermitian and D[i, i].is_positive is False: + raise NonPositiveDefiniteMatrixError( + "Matrix must be positive-definite") + + return M._new(L), M._new(D) + + +def _LUdecomposition(M, iszerofunc=_iszero, simpfunc=None, rankcheck=False): + """Returns (L, U, perm) where L is a lower triangular matrix with unit + diagonal, U is an upper triangular matrix, and perm is a list of row + swap index pairs. If A is the original matrix, then + ``A = (L*U).permuteBkwd(perm)``, and the row permutation matrix P such + that $P A = L U$ can be computed by ``P = eye(A.rows).permuteFwd(perm)``. + + See documentation for LUCombined for details about the keyword argument + rankcheck, iszerofunc, and simpfunc. + + Parameters + ========== + + rankcheck : bool, optional + Determines if this function should detect the rank + deficiency of the matrixis and should raise a + ``ValueError``. + + iszerofunc : function, optional + A function which determines if a given expression is zero. + + The function should be a callable that takes a single + SymPy expression and returns a 3-valued boolean value + ``True``, ``False``, or ``None``. + + It is internally used by the pivot searching algorithm. + See the notes section for a more information about the + pivot searching algorithm. + + simpfunc : function or None, optional + A function that simplifies the input. + + If this is specified as a function, this function should be + a callable that takes a single SymPy expression and returns + an another SymPy expression that is algebraically + equivalent. + + If ``None``, it indicates that the pivot search algorithm + should not attempt to simplify any candidate pivots. + + It is internally used by the pivot searching algorithm. + See the notes section for a more information about the + pivot searching algorithm. + + Examples + ======== + + >>> from sympy import Matrix + >>> a = Matrix([[4, 3], [6, 3]]) + >>> L, U, _ = a.LUdecomposition() + >>> L + Matrix([ + [ 1, 0], + [3/2, 1]]) + >>> U + Matrix([ + [4, 3], + [0, -3/2]]) + + See Also + ======== + + sympy.matrices.dense.DenseMatrix.cholesky + sympy.matrices.dense.DenseMatrix.LDLdecomposition + QRdecomposition + LUdecomposition_Simple + LUdecompositionFF + LUsolve + """ + + combined, p = M.LUdecomposition_Simple(iszerofunc=iszerofunc, + simpfunc=simpfunc, rankcheck=rankcheck) + + # L is lower triangular ``M.rows x M.rows`` + # U is upper triangular ``M.rows x M.cols`` + # L has unit diagonal. For each column in combined, the subcolumn + # below the diagonal of combined is shared by L. + # If L has more columns than combined, then the remaining subcolumns + # below the diagonal of L are zero. + # The upper triangular portion of L and combined are equal. + def entry_L(i, j): + if i < j: + # Super diagonal entry + return M.zero + elif i == j: + return M.one + elif j < combined.cols: + return combined[i, j] + + # Subdiagonal entry of L with no corresponding + # entry in combined + return M.zero + + def entry_U(i, j): + return M.zero if i > j else combined[i, j] + + L = M._new(combined.rows, combined.rows, entry_L) + U = M._new(combined.rows, combined.cols, entry_U) + + return L, U, p + +def _LUdecomposition_Simple(M, iszerofunc=_iszero, simpfunc=None, + rankcheck=False): + r"""Compute the PLU decomposition of the matrix. + + Parameters + ========== + + rankcheck : bool, optional + Determines if this function should detect the rank + deficiency of the matrixis and should raise a + ``ValueError``. + + iszerofunc : function, optional + A function which determines if a given expression is zero. + + The function should be a callable that takes a single + SymPy expression and returns a 3-valued boolean value + ``True``, ``False``, or ``None``. + + It is internally used by the pivot searching algorithm. + See the notes section for a more information about the + pivot searching algorithm. + + simpfunc : function or None, optional + A function that simplifies the input. + + If this is specified as a function, this function should be + a callable that takes a single SymPy expression and returns + an another SymPy expression that is algebraically + equivalent. + + If ``None``, it indicates that the pivot search algorithm + should not attempt to simplify any candidate pivots. + + It is internally used by the pivot searching algorithm. + See the notes section for a more information about the + pivot searching algorithm. + + Returns + ======= + + (lu, row_swaps) : (Matrix, list) + If the original matrix is a $m, n$ matrix: + + *lu* is a $m, n$ matrix, which contains result of the + decomposition in a compressed form. See the notes section + to see how the matrix is compressed. + + *row_swaps* is a $m$-element list where each element is a + pair of row exchange indices. + + ``A = (L*U).permute_backward(perm)``, and the row + permutation matrix $P$ from the formula $P A = L U$ can be + computed by ``P=eye(A.row).permute_forward(perm)``. + + Raises + ====== + + ValueError + Raised if ``rankcheck=True`` and the matrix is found to + be rank deficient during the computation. + + Notes + ===== + + About the PLU decomposition: + + PLU decomposition is a generalization of a LU decomposition + which can be extended for rank-deficient matrices. + + It can further be generalized for non-square matrices, and this + is the notation that SymPy is using. + + PLU decomposition is a decomposition of a $m, n$ matrix $A$ in + the form of $P A = L U$ where + + * $L$ is a $m, m$ lower triangular matrix with unit diagonal + entries. + * $U$ is a $m, n$ upper triangular matrix. + * $P$ is a $m, m$ permutation matrix. + + So, for a square matrix, the decomposition would look like: + + .. math:: + L = \begin{bmatrix} + 1 & 0 & 0 & \cdots & 0 \\ + L_{1, 0} & 1 & 0 & \cdots & 0 \\ + L_{2, 0} & L_{2, 1} & 1 & \cdots & 0 \\ + \vdots & \vdots & \vdots & \ddots & \vdots \\ + L_{n-1, 0} & L_{n-1, 1} & L_{n-1, 2} & \cdots & 1 + \end{bmatrix} + + .. math:: + U = \begin{bmatrix} + U_{0, 0} & U_{0, 1} & U_{0, 2} & \cdots & U_{0, n-1} \\ + 0 & U_{1, 1} & U_{1, 2} & \cdots & U_{1, n-1} \\ + 0 & 0 & U_{2, 2} & \cdots & U_{2, n-1} \\ + \vdots & \vdots & \vdots & \ddots & \vdots \\ + 0 & 0 & 0 & \cdots & U_{n-1, n-1} + \end{bmatrix} + + And for a matrix with more rows than the columns, + the decomposition would look like: + + .. math:: + L = \begin{bmatrix} + 1 & 0 & 0 & \cdots & 0 & 0 & \cdots & 0 \\ + L_{1, 0} & 1 & 0 & \cdots & 0 & 0 & \cdots & 0 \\ + L_{2, 0} & L_{2, 1} & 1 & \cdots & 0 & 0 & \cdots & 0 \\ + \vdots & \vdots & \vdots & \ddots & \vdots & \vdots & \ddots + & \vdots \\ + L_{n-1, 0} & L_{n-1, 1} & L_{n-1, 2} & \cdots & 1 & 0 + & \cdots & 0 \\ + L_{n, 0} & L_{n, 1} & L_{n, 2} & \cdots & L_{n, n-1} & 1 + & \cdots & 0 \\ + \vdots & \vdots & \vdots & \ddots & \vdots & \vdots + & \ddots & \vdots \\ + L_{m-1, 0} & L_{m-1, 1} & L_{m-1, 2} & \cdots & L_{m-1, n-1} + & 0 & \cdots & 1 \\ + \end{bmatrix} + + .. math:: + U = \begin{bmatrix} + U_{0, 0} & U_{0, 1} & U_{0, 2} & \cdots & U_{0, n-1} \\ + 0 & U_{1, 1} & U_{1, 2} & \cdots & U_{1, n-1} \\ + 0 & 0 & U_{2, 2} & \cdots & U_{2, n-1} \\ + \vdots & \vdots & \vdots & \ddots & \vdots \\ + 0 & 0 & 0 & \cdots & U_{n-1, n-1} \\ + 0 & 0 & 0 & \cdots & 0 \\ + \vdots & \vdots & \vdots & \ddots & \vdots \\ + 0 & 0 & 0 & \cdots & 0 + \end{bmatrix} + + Finally, for a matrix with more columns than the rows, the + decomposition would look like: + + .. math:: + L = \begin{bmatrix} + 1 & 0 & 0 & \cdots & 0 \\ + L_{1, 0} & 1 & 0 & \cdots & 0 \\ + L_{2, 0} & L_{2, 1} & 1 & \cdots & 0 \\ + \vdots & \vdots & \vdots & \ddots & \vdots \\ + L_{m-1, 0} & L_{m-1, 1} & L_{m-1, 2} & \cdots & 1 + \end{bmatrix} + + .. math:: + U = \begin{bmatrix} + U_{0, 0} & U_{0, 1} & U_{0, 2} & \cdots & U_{0, m-1} + & \cdots & U_{0, n-1} \\ + 0 & U_{1, 1} & U_{1, 2} & \cdots & U_{1, m-1} + & \cdots & U_{1, n-1} \\ + 0 & 0 & U_{2, 2} & \cdots & U_{2, m-1} + & \cdots & U_{2, n-1} \\ + \vdots & \vdots & \vdots & \ddots & \vdots + & \cdots & \vdots \\ + 0 & 0 & 0 & \cdots & U_{m-1, m-1} + & \cdots & U_{m-1, n-1} \\ + \end{bmatrix} + + About the compressed LU storage: + + The results of the decomposition are often stored in compressed + forms rather than returning $L$ and $U$ matrices individually. + + It may be less intiuitive, but it is commonly used for a lot of + numeric libraries because of the efficiency. + + The storage matrix is defined as following for this specific + method: + + * The subdiagonal elements of $L$ are stored in the subdiagonal + portion of $LU$, that is $LU_{i, j} = L_{i, j}$ whenever + $i > j$. + * The elements on the diagonal of $L$ are all 1, and are not + explicitly stored. + * $U$ is stored in the upper triangular portion of $LU$, that is + $LU_{i, j} = U_{i, j}$ whenever $i <= j$. + * For a case of $m > n$, the right side of the $L$ matrix is + trivial to store. + * For a case of $m < n$, the below side of the $U$ matrix is + trivial to store. + + So, for a square matrix, the compressed output matrix would be: + + .. math:: + LU = \begin{bmatrix} + U_{0, 0} & U_{0, 1} & U_{0, 2} & \cdots & U_{0, n-1} \\ + L_{1, 0} & U_{1, 1} & U_{1, 2} & \cdots & U_{1, n-1} \\ + L_{2, 0} & L_{2, 1} & U_{2, 2} & \cdots & U_{2, n-1} \\ + \vdots & \vdots & \vdots & \ddots & \vdots \\ + L_{n-1, 0} & L_{n-1, 1} & L_{n-1, 2} & \cdots & U_{n-1, n-1} + \end{bmatrix} + + For a matrix with more rows than the columns, the compressed + output matrix would be: + + .. math:: + LU = \begin{bmatrix} + U_{0, 0} & U_{0, 1} & U_{0, 2} & \cdots & U_{0, n-1} \\ + L_{1, 0} & U_{1, 1} & U_{1, 2} & \cdots & U_{1, n-1} \\ + L_{2, 0} & L_{2, 1} & U_{2, 2} & \cdots & U_{2, n-1} \\ + \vdots & \vdots & \vdots & \ddots & \vdots \\ + L_{n-1, 0} & L_{n-1, 1} & L_{n-1, 2} & \cdots + & U_{n-1, n-1} \\ + \vdots & \vdots & \vdots & \ddots & \vdots \\ + L_{m-1, 0} & L_{m-1, 1} & L_{m-1, 2} & \cdots + & L_{m-1, n-1} \\ + \end{bmatrix} + + For a matrix with more columns than the rows, the compressed + output matrix would be: + + .. math:: + LU = \begin{bmatrix} + U_{0, 0} & U_{0, 1} & U_{0, 2} & \cdots & U_{0, m-1} + & \cdots & U_{0, n-1} \\ + L_{1, 0} & U_{1, 1} & U_{1, 2} & \cdots & U_{1, m-1} + & \cdots & U_{1, n-1} \\ + L_{2, 0} & L_{2, 1} & U_{2, 2} & \cdots & U_{2, m-1} + & \cdots & U_{2, n-1} \\ + \vdots & \vdots & \vdots & \ddots & \vdots + & \cdots & \vdots \\ + L_{m-1, 0} & L_{m-1, 1} & L_{m-1, 2} & \cdots & U_{m-1, m-1} + & \cdots & U_{m-1, n-1} \\ + \end{bmatrix} + + About the pivot searching algorithm: + + When a matrix contains symbolic entries, the pivot search algorithm + differs from the case where every entry can be categorized as zero or + nonzero. + The algorithm searches column by column through the submatrix whose + top left entry coincides with the pivot position. + If it exists, the pivot is the first entry in the current search + column that iszerofunc guarantees is nonzero. + If no such candidate exists, then each candidate pivot is simplified + if simpfunc is not None. + The search is repeated, with the difference that a candidate may be + the pivot if ``iszerofunc()`` cannot guarantee that it is nonzero. + In the second search the pivot is the first candidate that + iszerofunc can guarantee is nonzero. + If no such candidate exists, then the pivot is the first candidate + for which iszerofunc returns None. + If no such candidate exists, then the search is repeated in the next + column to the right. + The pivot search algorithm differs from the one in ``rref()``, which + relies on ``_find_reasonable_pivot()``. + Future versions of ``LUdecomposition_simple()`` may use + ``_find_reasonable_pivot()``. + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.LUdecomposition + LUdecompositionFF + LUsolve + """ + + if rankcheck: + # https://github.com/sympy/sympy/issues/9796 + pass + + if S.Zero in M.shape: + # Define LU decomposition of a matrix with no entries as a matrix + # of the same dimensions with all zero entries. + return M.zeros(M.rows, M.cols), [] + + dps = _get_intermediate_simp() + lu = M.as_mutable() + row_swaps = [] + + pivot_col = 0 + + for pivot_row in range(0, lu.rows - 1): + # Search for pivot. Prefer entry that iszeropivot determines + # is nonzero, over entry that iszeropivot cannot guarantee + # is zero. + # XXX ``_find_reasonable_pivot`` uses slow zero testing. Blocked by bug #10279 + # Future versions of LUdecomposition_simple can pass iszerofunc and simpfunc + # to _find_reasonable_pivot(). + # In pass 3 of _find_reasonable_pivot(), the predicate in ``if x.equals(S.Zero):`` + # calls sympy.simplify(), and not the simplification function passed in via + # the keyword argument simpfunc. + iszeropivot = True + + while pivot_col != M.cols and iszeropivot: + sub_col = (lu[r, pivot_col] for r in range(pivot_row, M.rows)) + + pivot_row_offset, pivot_value, is_assumed_non_zero, ind_simplified_pairs =\ + _find_reasonable_pivot_naive(sub_col, iszerofunc, simpfunc) + + iszeropivot = pivot_value is None + + if iszeropivot: + # All candidate pivots in this column are zero. + # Proceed to next column. + pivot_col += 1 + + if rankcheck and pivot_col != pivot_row: + # All entries including and below the pivot position are + # zero, which indicates that the rank of the matrix is + # strictly less than min(num rows, num cols) + # Mimic behavior of previous implementation, by throwing a + # ValueError. + raise ValueError("Rank of matrix is strictly less than" + " number of rows or columns." + " Pass keyword argument" + " rankcheck=False to compute" + " the LU decomposition of this matrix.") + + candidate_pivot_row = None if pivot_row_offset is None else pivot_row + pivot_row_offset + + if candidate_pivot_row is None and iszeropivot: + # If candidate_pivot_row is None and iszeropivot is True + # after pivot search has completed, then the submatrix + # below and to the right of (pivot_row, pivot_col) is + # all zeros, indicating that Gaussian elimination is + # complete. + return lu, row_swaps + + # Update entries simplified during pivot search. + for offset, val in ind_simplified_pairs: + lu[pivot_row + offset, pivot_col] = val + + if pivot_row != candidate_pivot_row: + # Row swap book keeping: + # Record which rows were swapped. + # Update stored portion of L factor by multiplying L on the + # left and right with the current permutation. + # Swap rows of U. + row_swaps.append([pivot_row, candidate_pivot_row]) + + # Update L. + lu[pivot_row, 0:pivot_row], lu[candidate_pivot_row, 0:pivot_row] = \ + lu[candidate_pivot_row, 0:pivot_row], lu[pivot_row, 0:pivot_row] + + # Swap pivot row of U with candidate pivot row. + lu[pivot_row, pivot_col:lu.cols], lu[candidate_pivot_row, pivot_col:lu.cols] = \ + lu[candidate_pivot_row, pivot_col:lu.cols], lu[pivot_row, pivot_col:lu.cols] + + # Introduce zeros below the pivot by adding a multiple of the + # pivot row to a row under it, and store the result in the + # row under it. + # Only entries in the target row whose index is greater than + # start_col may be nonzero. + start_col = pivot_col + 1 + + for row in range(pivot_row + 1, lu.rows): + # Store factors of L in the subcolumn below + # (pivot_row, pivot_row). + lu[row, pivot_row] = \ + dps(lu[row, pivot_col]/lu[pivot_row, pivot_col]) + + # Form the linear combination of the pivot row and the current + # row below the pivot row that zeros the entries below the pivot. + # Employing slicing instead of a loop here raises + # NotImplementedError: Cannot add Zero to MutableSparseMatrix + # in sympy/matrices/tests/test_sparse.py. + # c = pivot_row + 1 if pivot_row == pivot_col else pivot_col + for c in range(start_col, lu.cols): + lu[row, c] = dps(lu[row, c] - lu[row, pivot_row]*lu[pivot_row, c]) + + if pivot_row != pivot_col: + # matrix rank < min(num rows, num cols), + # so factors of L are not stored directly below the pivot. + # These entries are zero by construction, so don't bother + # computing them. + for row in range(pivot_row + 1, lu.rows): + lu[row, pivot_col] = M.zero + + pivot_col += 1 + + if pivot_col == lu.cols: + # All candidate pivots are zero implies that Gaussian + # elimination is complete. + return lu, row_swaps + + if rankcheck: + if iszerofunc( + lu[Min(lu.rows, lu.cols) - 1, Min(lu.rows, lu.cols) - 1]): + raise ValueError("Rank of matrix is strictly less than" + " number of rows or columns." + " Pass keyword argument" + " rankcheck=False to compute" + " the LU decomposition of this matrix.") + + return lu, row_swaps + +def _LUdecompositionFF(M): + """Compute a fraction-free LU decomposition. + + Returns 4 matrices P, L, D, U such that PA = L D**-1 U. + If the elements of the matrix belong to some integral domain I, then all + elements of L, D and U are guaranteed to belong to I. + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.LUdecomposition + LUdecomposition_Simple + LUsolve + + References + ========== + + .. [1] W. Zhou & D.J. Jeffrey, "Fraction-free matrix factors: new forms + for LU and QR factors". Frontiers in Computer Science in China, + Vol 2, no. 1, pp. 67-80, 2008. + """ + + from sympy.matrices import SparseMatrix + + zeros = SparseMatrix.zeros + eye = SparseMatrix.eye + n, m = M.rows, M.cols + U, L, P = M.as_mutable(), eye(n), eye(n) + DD = zeros(n, n) + oldpivot = 1 + + for k in range(n - 1): + if U[k, k] == 0: + for kpivot in range(k + 1, n): + if U[kpivot, k]: + break + else: + raise ValueError("Matrix is not full rank") + + U[k, k:], U[kpivot, k:] = U[kpivot, k:], U[k, k:] + L[k, :k], L[kpivot, :k] = L[kpivot, :k], L[k, :k] + P[k, :], P[kpivot, :] = P[kpivot, :], P[k, :] + + L [k, k] = Ukk = U[k, k] + DD[k, k] = oldpivot * Ukk + + for i in range(k + 1, n): + L[i, k] = Uik = U[i, k] + + for j in range(k + 1, m): + U[i, j] = (Ukk * U[i, j] - U[k, j] * Uik) / oldpivot + + U[i, k] = 0 + + oldpivot = Ukk + + DD[n - 1, n - 1] = oldpivot + + return P, L, DD, U + +def _singular_value_decomposition(A): + r"""Returns a Condensed Singular Value decomposition. + + Explanation + =========== + + A Singular Value decomposition is a decomposition in the form $A = U \Sigma V^H$ + where + + - $U, V$ are column orthogonal matrix. + - $\Sigma$ is a diagonal matrix, where the main diagonal contains singular + values of matrix A. + + A column orthogonal matrix satisfies + $\mathbb{I} = U^H U$ while a full orthogonal matrix satisfies + relation $\mathbb{I} = U U^H = U^H U$ where $\mathbb{I}$ is an identity + matrix with matching dimensions. + + For matrices which are not square or are rank-deficient, it is + sufficient to return a column orthogonal matrix because augmenting + them may introduce redundant computations. + In condensed Singular Value Decomposition we only return column orthogonal + matrices because of this reason + + If you want to augment the results to return a full orthogonal + decomposition, you should use the following procedures. + + - Augment the $U, V$ matrices with columns that are orthogonal to every + other columns and make it square. + - Augment the $\Sigma$ matrix with zero rows to make it have the same + shape as the original matrix. + + The procedure will be illustrated in the examples section. + + Examples + ======== + + we take a full rank matrix first: + + >>> from sympy import Matrix + >>> A = Matrix([[1, 2],[2,1]]) + >>> U, S, V = A.singular_value_decomposition() + >>> U + Matrix([ + [ sqrt(2)/2, sqrt(2)/2], + [-sqrt(2)/2, sqrt(2)/2]]) + >>> S + Matrix([ + [1, 0], + [0, 3]]) + >>> V + Matrix([ + [-sqrt(2)/2, sqrt(2)/2], + [ sqrt(2)/2, sqrt(2)/2]]) + + If a matrix if square and full rank both U, V + are orthogonal in both directions + + >>> U * U.H + Matrix([ + [1, 0], + [0, 1]]) + >>> U.H * U + Matrix([ + [1, 0], + [0, 1]]) + + >>> V * V.H + Matrix([ + [1, 0], + [0, 1]]) + >>> V.H * V + Matrix([ + [1, 0], + [0, 1]]) + >>> A == U * S * V.H + True + + >>> C = Matrix([ + ... [1, 0, 0, 0, 2], + ... [0, 0, 3, 0, 0], + ... [0, 0, 0, 0, 0], + ... [0, 2, 0, 0, 0], + ... ]) + >>> U, S, V = C.singular_value_decomposition() + + >>> V.H * V + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + >>> V * V.H + Matrix([ + [1/5, 0, 0, 0, 2/5], + [ 0, 1, 0, 0, 0], + [ 0, 0, 1, 0, 0], + [ 0, 0, 0, 0, 0], + [2/5, 0, 0, 0, 4/5]]) + + If you want to augment the results to be a full orthogonal + decomposition, you should augment $V$ with an another orthogonal + column. + + You are able to append an arbitrary standard basis that are linearly + independent to every other columns and you can run the Gram-Schmidt + process to make them augmented as orthogonal basis. + + >>> V_aug = V.row_join(Matrix([[0,0,0,0,1], + ... [0,0,0,1,0]]).H) + >>> V_aug = V_aug.QRdecomposition()[0] + >>> V_aug + Matrix([ + [0, sqrt(5)/5, 0, -2*sqrt(5)/5, 0], + [1, 0, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 1], + [0, 2*sqrt(5)/5, 0, sqrt(5)/5, 0]]) + >>> V_aug.H * V_aug + Matrix([ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1]]) + >>> V_aug * V_aug.H + Matrix([ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1]]) + + Similarly we augment U + + >>> U_aug = U.row_join(Matrix([0,0,1,0])) + >>> U_aug = U_aug.QRdecomposition()[0] + >>> U_aug + Matrix([ + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1], + [1, 0, 0, 0]]) + + >>> U_aug.H * U_aug + Matrix([ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]) + >>> U_aug * U_aug.H + Matrix([ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]) + + We add 2 zero columns and one row to S + + >>> S_aug = S.col_join(Matrix([[0,0,0]])) + >>> S_aug = S_aug.row_join(Matrix([[0,0,0,0], + ... [0,0,0,0]]).H) + >>> S_aug + Matrix([ + [2, 0, 0, 0, 0], + [0, sqrt(5), 0, 0, 0], + [0, 0, 3, 0, 0], + [0, 0, 0, 0, 0]]) + + + + >>> U_aug * S_aug * V_aug.H == C + True + + """ + + AH = A.H + m, n = A.shape + if m >= n: + V, S = (AH * A).diagonalize() + + ranked = [] + for i, x in enumerate(S.diagonal()): + if not x.is_zero: + ranked.append(i) + + V = V[:, ranked] + + Singular_vals = [sqrt(S[i, i]) for i in range(S.rows) if i in ranked] + + S = S.diag(*Singular_vals) + V, _ = V.QRdecomposition() + U = A * V * S.inv() + else: + U, S = (A * AH).diagonalize() + + ranked = [] + for i, x in enumerate(S.diagonal()): + if not x.is_zero: + ranked.append(i) + + U = U[:, ranked] + Singular_vals = [sqrt(S[i, i]) for i in range(S.rows) if i in ranked] + + S = S.diag(*Singular_vals) + U, _ = U.QRdecomposition() + V = AH * U * S.inv() + + return U, S, V + +def _QRdecomposition_optional(M, normalize=True): + def dot(u, v): + return u.dot(v, hermitian=True) + + dps = _get_intermediate_simp(expand_mul, expand_mul) + + A = M.as_mutable() + ranked = [] + + Q = A + R = A.zeros(A.cols) + + for j in range(A.cols): + for i in range(j): + if Q[:, i].is_zero_matrix: + continue + + R[i, j] = dot(Q[:, i], Q[:, j]) / dot(Q[:, i], Q[:, i]) + R[i, j] = dps(R[i, j]) + Q[:, j] -= Q[:, i] * R[i, j] + + Q[:, j] = dps(Q[:, j]) + if Q[:, j].is_zero_matrix is not True: + ranked.append(j) + R[j, j] = M.one + + Q = Q.extract(range(Q.rows), ranked) + R = R.extract(ranked, range(R.cols)) + + if normalize: + # Normalization + for i in range(Q.cols): + norm = Q[:, i].norm() + Q[:, i] /= norm + R[i, :] *= norm + + return M.__class__(Q), M.__class__(R) + + +def _QRdecomposition(M): + r"""Returns a QR decomposition. + + Explanation + =========== + + A QR decomposition is a decomposition in the form $A = Q R$ + where + + - $Q$ is a column orthogonal matrix. + - $R$ is a upper triangular (trapezoidal) matrix. + + A column orthogonal matrix satisfies + $\mathbb{I} = Q^H Q$ while a full orthogonal matrix satisfies + relation $\mathbb{I} = Q Q^H = Q^H Q$ where $I$ is an identity + matrix with matching dimensions. + + For matrices which are not square or are rank-deficient, it is + sufficient to return a column orthogonal matrix because augmenting + them may introduce redundant computations. + And an another advantage of this is that you can easily inspect the + matrix rank by counting the number of columns of $Q$. + + If you want to augment the results to return a full orthogonal + decomposition, you should use the following procedures. + + - Augment the $Q$ matrix with columns that are orthogonal to every + other columns and make it square. + - Augment the $R$ matrix with zero rows to make it have the same + shape as the original matrix. + + The procedure will be illustrated in the examples section. + + Examples + ======== + + A full rank matrix example: + + >>> from sympy import Matrix + >>> A = Matrix([[12, -51, 4], [6, 167, -68], [-4, 24, -41]]) + >>> Q, R = A.QRdecomposition() + >>> Q + Matrix([ + [ 6/7, -69/175, -58/175], + [ 3/7, 158/175, 6/175], + [-2/7, 6/35, -33/35]]) + >>> R + Matrix([ + [14, 21, -14], + [ 0, 175, -70], + [ 0, 0, 35]]) + + If the matrix is square and full rank, the $Q$ matrix becomes + orthogonal in both directions, and needs no augmentation. + + >>> Q * Q.H + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + >>> Q.H * Q + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + + >>> A == Q*R + True + + A rank deficient matrix example: + + >>> A = Matrix([[12, -51, 0], [6, 167, 0], [-4, 24, 0]]) + >>> Q, R = A.QRdecomposition() + >>> Q + Matrix([ + [ 6/7, -69/175], + [ 3/7, 158/175], + [-2/7, 6/35]]) + >>> R + Matrix([ + [14, 21, 0], + [ 0, 175, 0]]) + + QRdecomposition might return a matrix Q that is rectangular. + In this case the orthogonality condition might be satisfied as + $\mathbb{I} = Q.H*Q$ but not in the reversed product + $\mathbb{I} = Q * Q.H$. + + >>> Q.H * Q + Matrix([ + [1, 0], + [0, 1]]) + >>> Q * Q.H + Matrix([ + [27261/30625, 348/30625, -1914/6125], + [ 348/30625, 30589/30625, 198/6125], + [ -1914/6125, 198/6125, 136/1225]]) + + If you want to augment the results to be a full orthogonal + decomposition, you should augment $Q$ with an another orthogonal + column. + + You are able to append an identity matrix, + and you can run the Gram-Schmidt + process to make them augmented as orthogonal basis. + + >>> Q_aug = Q.row_join(Matrix.eye(3)) + >>> Q_aug = Q_aug.QRdecomposition()[0] + >>> Q_aug + Matrix([ + [ 6/7, -69/175, 58/175], + [ 3/7, 158/175, -6/175], + [-2/7, 6/35, 33/35]]) + >>> Q_aug.H * Q_aug + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + >>> Q_aug * Q_aug.H + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + + Augmenting the $R$ matrix with zero row is straightforward. + + >>> R_aug = R.col_join(Matrix([[0, 0, 0]])) + >>> R_aug + Matrix([ + [14, 21, 0], + [ 0, 175, 0], + [ 0, 0, 0]]) + >>> Q_aug * R_aug == A + True + + A zero matrix example: + + >>> from sympy import Matrix + >>> A = Matrix.zeros(3, 4) + >>> Q, R = A.QRdecomposition() + + They may return matrices with zero rows and columns. + + >>> Q + Matrix(3, 0, []) + >>> R + Matrix(0, 4, []) + >>> Q*R + Matrix([ + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]]) + + As the same augmentation rule described above, $Q$ can be augmented + with columns of an identity matrix and $R$ can be augmented with + rows of a zero matrix. + + >>> Q_aug = Q.row_join(Matrix.eye(3)) + >>> R_aug = R.col_join(Matrix.zeros(3, 4)) + >>> Q_aug * Q_aug.T + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + >>> R_aug + Matrix([ + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]]) + >>> Q_aug * R_aug == A + True + + See Also + ======== + + sympy.matrices.dense.DenseMatrix.cholesky + sympy.matrices.dense.DenseMatrix.LDLdecomposition + sympy.matrices.matrixbase.MatrixBase.LUdecomposition + QRsolve + """ + return _QRdecomposition_optional(M, normalize=True) + +def _upper_hessenberg_decomposition(A): + """Converts a matrix into Hessenberg matrix H. + + Returns 2 matrices H, P s.t. + $P H P^{T} = A$, where H is an upper hessenberg matrix + and P is an orthogonal matrix + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([ + ... [1,2,3], + ... [-3,5,6], + ... [4,-8,9], + ... ]) + >>> H, P = A.upper_hessenberg_decomposition() + >>> H + Matrix([ + [1, 6/5, 17/5], + [5, 213/25, -134/25], + [0, 216/25, 137/25]]) + >>> P + Matrix([ + [1, 0, 0], + [0, -3/5, 4/5], + [0, 4/5, 3/5]]) + >>> P * H * P.H == A + True + + + References + ========== + + .. [#] https://mathworld.wolfram.com/HessenbergDecomposition.html + """ + + M = A.as_mutable() + + if not M.is_square: + raise NonSquareMatrixError("Matrix must be square.") + + n = M.cols + P = M.eye(n) + H = M + + for j in range(n - 2): + + u = H[j + 1:, j] + + if u[1:, :].is_zero_matrix: + continue + + if sign(u[0]) != 0: + u[0] = u[0] + sign(u[0]) * u.norm() + else: + u[0] = u[0] + u.norm() + + v = u / u.norm() + + H[j + 1:, :] = H[j + 1:, :] - 2 * v * (v.H * H[j + 1:, :]) + H[:, j + 1:] = H[:, j + 1:] - (H[:, j + 1:] * (2 * v)) * v.H + P[:, j + 1:] = P[:, j + 1:] - (P[:, j + 1:] * (2 * v)) * v.H + + return H, P diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/dense.py b/.venv/lib/python3.13/site-packages/sympy/matrices/dense.py new file mode 100644 index 0000000000000000000000000000000000000000..98bf9931df54f67abfd9c4dc810b46fdcf70288f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/dense.py @@ -0,0 +1,1094 @@ +from __future__ import annotations +import random + +from sympy.core.basic import Basic +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.utilities.decorator import doctest_depends_on +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import is_sequence + +from .exceptions import ShapeError +from .decompositions import _cholesky, _LDLdecomposition +from .matrixbase import MatrixBase +from .repmatrix import MutableRepMatrix, RepMatrix +from .solvers import _lower_triangular_solve, _upper_triangular_solve + + +__doctest_requires__ = {('symarray',): ['numpy']} + + +def _iszero(x): + """Returns True if x is zero.""" + return x.is_zero + + +class DenseMatrix(RepMatrix): + """Matrix implementation based on DomainMatrix as the internal representation""" + + # + # DenseMatrix is a superclass for both MutableDenseMatrix and + # ImmutableDenseMatrix. Methods shared by both classes but not for the + # Sparse classes should be implemented here. + # + + is_MatrixExpr: bool = False + + _op_priority = 10.01 + _class_priority = 4 + + @property + def _mat(self): + sympy_deprecation_warning( + """ + The private _mat attribute of Matrix is deprecated. Use the + .flat() method instead. + """, + deprecated_since_version="1.9", + active_deprecations_target="deprecated-private-matrix-attributes" + ) + + return self.flat() + + def _eval_inverse(self, **kwargs): + return self.inv(method=kwargs.get('method', 'GE'), + iszerofunc=kwargs.get('iszerofunc', _iszero), + try_block_diag=kwargs.get('try_block_diag', False)) + + def as_immutable(self): + """Returns an Immutable version of this Matrix + """ + from .immutable import ImmutableDenseMatrix as cls + return cls._fromrep(self._rep.copy()) + + def as_mutable(self): + """Returns a mutable version of this matrix + + Examples + ======== + + >>> from sympy import ImmutableMatrix + >>> X = ImmutableMatrix([[1, 2], [3, 4]]) + >>> Y = X.as_mutable() + >>> Y[1, 1] = 5 # Can set values in Y + >>> Y + Matrix([ + [1, 2], + [3, 5]]) + """ + return Matrix(self) + + def cholesky(self, hermitian=True): + return _cholesky(self, hermitian=hermitian) + + def LDLdecomposition(self, hermitian=True): + return _LDLdecomposition(self, hermitian=hermitian) + + def lower_triangular_solve(self, rhs): + return _lower_triangular_solve(self, rhs) + + def upper_triangular_solve(self, rhs): + return _upper_triangular_solve(self, rhs) + + cholesky.__doc__ = _cholesky.__doc__ + LDLdecomposition.__doc__ = _LDLdecomposition.__doc__ + lower_triangular_solve.__doc__ = _lower_triangular_solve.__doc__ + upper_triangular_solve.__doc__ = _upper_triangular_solve.__doc__ + + +def _force_mutable(x): + """Return a matrix as a Matrix, otherwise return x.""" + if getattr(x, 'is_Matrix', False): + return x.as_mutable() + elif isinstance(x, Basic): + return x + elif hasattr(x, '__array__'): + a = x.__array__() + if len(a.shape) == 0: + return sympify(a) + return Matrix(x) + return x + + +class MutableDenseMatrix(DenseMatrix, MutableRepMatrix): + + def simplify(self, **kwargs): + """Applies simplify to the elements of a matrix in place. + + This is a shortcut for M.applyfunc(lambda x: simplify(x, ratio, measure)) + + See Also + ======== + + sympy.simplify.simplify.simplify + """ + from sympy.simplify.simplify import simplify as _simplify + for (i, j), element in self.todok().items(): + self[i, j] = _simplify(element, **kwargs) + + +MutableMatrix = Matrix = MutableDenseMatrix + +########### +# Numpy Utility Functions: +# list2numpy, matrix2numpy, symmarray +########### + + +def list2numpy(l, dtype=object): # pragma: no cover + """Converts Python list of SymPy expressions to a NumPy array. + + See Also + ======== + + matrix2numpy + """ + from numpy import empty + a = empty(len(l), dtype) + for i, s in enumerate(l): + a[i] = s + return a + + +def matrix2numpy(m, dtype=object): # pragma: no cover + """Converts SymPy's matrix to a NumPy array. + + See Also + ======== + + list2numpy + """ + from numpy import empty + a = empty(m.shape, dtype) + for i in range(m.rows): + for j in range(m.cols): + a[i, j] = m[i, j] + return a + + +########### +# Rotation matrices: +# rot_givens, rot_axis[123], rot_ccw_axis[123] +########### + + +def rot_givens(i, j, theta, dim=3): + r"""Returns a a Givens rotation matrix, a a rotation in the + plane spanned by two coordinates axes. + + Explanation + =========== + + The Givens rotation corresponds to a generalization of rotation + matrices to any number of dimensions, given by: + + .. math:: + G(i, j, \theta) = + \begin{bmatrix} + 1 & \cdots & 0 & \cdots & 0 & \cdots & 0 \\ + \vdots & \ddots & \vdots & & \vdots & & \vdots \\ + 0 & \cdots & c & \cdots & -s & \cdots & 0 \\ + \vdots & & \vdots & \ddots & \vdots & & \vdots \\ + 0 & \cdots & s & \cdots & c & \cdots & 0 \\ + \vdots & & \vdots & & \vdots & \ddots & \vdots \\ + 0 & \cdots & 0 & \cdots & 0 & \cdots & 1 + \end{bmatrix} + + Where $c = \cos(\theta)$ and $s = \sin(\theta)$ appear at the intersections + ``i``\th and ``j``\th rows and columns. + + For fixed ``i > j``\, the non-zero elements of a Givens matrix are + given by: + + - $g_{kk} = 1$ for $k \ne i,\,j$ + - $g_{kk} = c$ for $k = i,\,j$ + - $g_{ji} = -g_{ij} = -s$ + + Parameters + ========== + + i : int between ``0`` and ``dim - 1`` + Represents first axis + j : int between ``0`` and ``dim - 1`` + Represents second axis + dim : int bigger than 1 + Number of dimensions. Defaults to 3. + + Examples + ======== + + >>> from sympy import pi, rot_givens + + A counterclockwise rotation of pi/3 (60 degrees) around + the third axis (z-axis): + + >>> rot_givens(1, 0, pi/3) + Matrix([ + [ 1/2, -sqrt(3)/2, 0], + [sqrt(3)/2, 1/2, 0], + [ 0, 0, 1]]) + + If we rotate by pi/2 (90 degrees): + + >>> rot_givens(1, 0, pi/2) + Matrix([ + [0, -1, 0], + [1, 0, 0], + [0, 0, 1]]) + + This can be generalized to any number + of dimensions: + + >>> rot_givens(1, 0, pi/2, dim=4) + Matrix([ + [0, -1, 0, 0], + [1, 0, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Givens_rotation + + See Also + ======== + + rot_axis1: Returns a rotation matrix for a rotation of theta (in radians) + about the 1-axis (clockwise around the x axis) + rot_axis2: Returns a rotation matrix for a rotation of theta (in radians) + about the 2-axis (clockwise around the y axis) + rot_axis3: Returns a rotation matrix for a rotation of theta (in radians) + about the 3-axis (clockwise around the z axis) + rot_ccw_axis1: Returns a rotation matrix for a rotation of theta (in radians) + about the 1-axis (counterclockwise around the x axis) + rot_ccw_axis2: Returns a rotation matrix for a rotation of theta (in radians) + about the 2-axis (counterclockwise around the y axis) + rot_ccw_axis3: Returns a rotation matrix for a rotation of theta (in radians) + about the 3-axis (counterclockwise around the z axis) + """ + if not isinstance(dim, int) or dim < 2: + raise ValueError('dim must be an integer biggen than one, ' + 'got {}.'.format(dim)) + + if i == j: + raise ValueError('i and j must be different, ' + 'got ({}, {})'.format(i, j)) + + for ij in [i, j]: + if not isinstance(ij, int) or ij < 0 or ij > dim - 1: + raise ValueError('i and j must be integers between 0 and ' + '{}, got i={} and j={}.'.format(dim-1, i, j)) + + theta = sympify(theta) + c = cos(theta) + s = sin(theta) + M = eye(dim) + M[i, i] = c + M[j, j] = c + M[i, j] = s + M[j, i] = -s + return M + + +def rot_axis3(theta): + r"""Returns a rotation matrix for a rotation of theta (in radians) + about the 3-axis. + + Explanation + =========== + + For a right-handed coordinate system, this corresponds to a + clockwise rotation around the `z`-axis, given by: + + .. math:: + + R = \begin{bmatrix} + \cos(\theta) & \sin(\theta) & 0 \\ + -\sin(\theta) & \cos(\theta) & 0 \\ + 0 & 0 & 1 + \end{bmatrix} + + Examples + ======== + + >>> from sympy import pi, rot_axis3 + + A rotation of pi/3 (60 degrees): + + >>> theta = pi/3 + >>> rot_axis3(theta) + Matrix([ + [ 1/2, sqrt(3)/2, 0], + [-sqrt(3)/2, 1/2, 0], + [ 0, 0, 1]]) + + If we rotate by pi/2 (90 degrees): + + >>> rot_axis3(pi/2) + Matrix([ + [ 0, 1, 0], + [-1, 0, 0], + [ 0, 0, 1]]) + + See Also + ======== + + rot_givens: Returns a Givens rotation matrix (generalized rotation for + any number of dimensions) + rot_ccw_axis3: Returns a rotation matrix for a rotation of theta (in radians) + about the 3-axis (counterclockwise around the z axis) + rot_axis1: Returns a rotation matrix for a rotation of theta (in radians) + about the 1-axis (clockwise around the x axis) + rot_axis2: Returns a rotation matrix for a rotation of theta (in radians) + about the 2-axis (clockwise around the y axis) + """ + return rot_givens(0, 1, theta, dim=3) + + +def rot_axis2(theta): + r"""Returns a rotation matrix for a rotation of theta (in radians) + about the 2-axis. + + Explanation + =========== + + For a right-handed coordinate system, this corresponds to a + clockwise rotation around the `y`-axis, given by: + + .. math:: + + R = \begin{bmatrix} + \cos(\theta) & 0 & -\sin(\theta) \\ + 0 & 1 & 0 \\ + \sin(\theta) & 0 & \cos(\theta) + \end{bmatrix} + + Examples + ======== + + >>> from sympy import pi, rot_axis2 + + A rotation of pi/3 (60 degrees): + + >>> theta = pi/3 + >>> rot_axis2(theta) + Matrix([ + [ 1/2, 0, -sqrt(3)/2], + [ 0, 1, 0], + [sqrt(3)/2, 0, 1/2]]) + + If we rotate by pi/2 (90 degrees): + + >>> rot_axis2(pi/2) + Matrix([ + [0, 0, -1], + [0, 1, 0], + [1, 0, 0]]) + + See Also + ======== + + rot_givens: Returns a Givens rotation matrix (generalized rotation for + any number of dimensions) + rot_ccw_axis2: Returns a rotation matrix for a rotation of theta (in radians) + about the 2-axis (clockwise around the y axis) + rot_axis1: Returns a rotation matrix for a rotation of theta (in radians) + about the 1-axis (counterclockwise around the x axis) + rot_axis3: Returns a rotation matrix for a rotation of theta (in radians) + about the 3-axis (counterclockwise around the z axis) + """ + return rot_givens(2, 0, theta, dim=3) + + +def rot_axis1(theta): + r"""Returns a rotation matrix for a rotation of theta (in radians) + about the 1-axis. + + Explanation + =========== + + For a right-handed coordinate system, this corresponds to a + clockwise rotation around the `x`-axis, given by: + + .. math:: + + R = \begin{bmatrix} + 1 & 0 & 0 \\ + 0 & \cos(\theta) & \sin(\theta) \\ + 0 & -\sin(\theta) & \cos(\theta) + \end{bmatrix} + + Examples + ======== + + >>> from sympy import pi, rot_axis1 + + A rotation of pi/3 (60 degrees): + + >>> theta = pi/3 + >>> rot_axis1(theta) + Matrix([ + [1, 0, 0], + [0, 1/2, sqrt(3)/2], + [0, -sqrt(3)/2, 1/2]]) + + If we rotate by pi/2 (90 degrees): + + >>> rot_axis1(pi/2) + Matrix([ + [1, 0, 0], + [0, 0, 1], + [0, -1, 0]]) + + See Also + ======== + + rot_givens: Returns a Givens rotation matrix (generalized rotation for + any number of dimensions) + rot_ccw_axis1: Returns a rotation matrix for a rotation of theta (in radians) + about the 1-axis (counterclockwise around the x axis) + rot_axis2: Returns a rotation matrix for a rotation of theta (in radians) + about the 2-axis (clockwise around the y axis) + rot_axis3: Returns a rotation matrix for a rotation of theta (in radians) + about the 3-axis (clockwise around the z axis) + """ + return rot_givens(1, 2, theta, dim=3) + + +def rot_ccw_axis3(theta): + r"""Returns a rotation matrix for a rotation of theta (in radians) + about the 3-axis. + + Explanation + =========== + + For a right-handed coordinate system, this corresponds to a + counterclockwise rotation around the `z`-axis, given by: + + .. math:: + + R = \begin{bmatrix} + \cos(\theta) & -\sin(\theta) & 0 \\ + \sin(\theta) & \cos(\theta) & 0 \\ + 0 & 0 & 1 + \end{bmatrix} + + Examples + ======== + + >>> from sympy import pi, rot_ccw_axis3 + + A rotation of pi/3 (60 degrees): + + >>> theta = pi/3 + >>> rot_ccw_axis3(theta) + Matrix([ + [ 1/2, -sqrt(3)/2, 0], + [sqrt(3)/2, 1/2, 0], + [ 0, 0, 1]]) + + If we rotate by pi/2 (90 degrees): + + >>> rot_ccw_axis3(pi/2) + Matrix([ + [0, -1, 0], + [1, 0, 0], + [0, 0, 1]]) + + See Also + ======== + + rot_givens: Returns a Givens rotation matrix (generalized rotation for + any number of dimensions) + rot_axis3: Returns a rotation matrix for a rotation of theta (in radians) + about the 3-axis (clockwise around the z axis) + rot_ccw_axis1: Returns a rotation matrix for a rotation of theta (in radians) + about the 1-axis (counterclockwise around the x axis) + rot_ccw_axis2: Returns a rotation matrix for a rotation of theta (in radians) + about the 2-axis (counterclockwise around the y axis) + """ + return rot_givens(1, 0, theta, dim=3) + + +def rot_ccw_axis2(theta): + r"""Returns a rotation matrix for a rotation of theta (in radians) + about the 2-axis. + + Explanation + =========== + + For a right-handed coordinate system, this corresponds to a + counterclockwise rotation around the `y`-axis, given by: + + .. math:: + + R = \begin{bmatrix} + \cos(\theta) & 0 & \sin(\theta) \\ + 0 & 1 & 0 \\ + -\sin(\theta) & 0 & \cos(\theta) + \end{bmatrix} + + Examples + ======== + + >>> from sympy import pi, rot_ccw_axis2 + + A rotation of pi/3 (60 degrees): + + >>> theta = pi/3 + >>> rot_ccw_axis2(theta) + Matrix([ + [ 1/2, 0, sqrt(3)/2], + [ 0, 1, 0], + [-sqrt(3)/2, 0, 1/2]]) + + If we rotate by pi/2 (90 degrees): + + >>> rot_ccw_axis2(pi/2) + Matrix([ + [ 0, 0, 1], + [ 0, 1, 0], + [-1, 0, 0]]) + + See Also + ======== + + rot_givens: Returns a Givens rotation matrix (generalized rotation for + any number of dimensions) + rot_axis2: Returns a rotation matrix for a rotation of theta (in radians) + about the 2-axis (clockwise around the y axis) + rot_ccw_axis1: Returns a rotation matrix for a rotation of theta (in radians) + about the 1-axis (counterclockwise around the x axis) + rot_ccw_axis3: Returns a rotation matrix for a rotation of theta (in radians) + about the 3-axis (counterclockwise around the z axis) + """ + return rot_givens(0, 2, theta, dim=3) + + +def rot_ccw_axis1(theta): + r"""Returns a rotation matrix for a rotation of theta (in radians) + about the 1-axis. + + Explanation + =========== + + For a right-handed coordinate system, this corresponds to a + counterclockwise rotation around the `x`-axis, given by: + + .. math:: + + R = \begin{bmatrix} + 1 & 0 & 0 \\ + 0 & \cos(\theta) & -\sin(\theta) \\ + 0 & \sin(\theta) & \cos(\theta) + \end{bmatrix} + + Examples + ======== + + >>> from sympy import pi, rot_ccw_axis1 + + A rotation of pi/3 (60 degrees): + + >>> theta = pi/3 + >>> rot_ccw_axis1(theta) + Matrix([ + [1, 0, 0], + [0, 1/2, -sqrt(3)/2], + [0, sqrt(3)/2, 1/2]]) + + If we rotate by pi/2 (90 degrees): + + >>> rot_ccw_axis1(pi/2) + Matrix([ + [1, 0, 0], + [0, 0, -1], + [0, 1, 0]]) + + See Also + ======== + + rot_givens: Returns a Givens rotation matrix (generalized rotation for + any number of dimensions) + rot_axis1: Returns a rotation matrix for a rotation of theta (in radians) + about the 1-axis (clockwise around the x axis) + rot_ccw_axis2: Returns a rotation matrix for a rotation of theta (in radians) + about the 2-axis (counterclockwise around the y axis) + rot_ccw_axis3: Returns a rotation matrix for a rotation of theta (in radians) + about the 3-axis (counterclockwise around the z axis) + """ + return rot_givens(2, 1, theta, dim=3) + + +@doctest_depends_on(modules=('numpy',)) +def symarray(prefix, shape, **kwargs): # pragma: no cover + r"""Create a numpy ndarray of symbols (as an object array). + + The created symbols are named ``prefix_i1_i2_``... You should thus provide a + non-empty prefix if you want your symbols to be unique for different output + arrays, as SymPy symbols with identical names are the same object. + + Parameters + ---------- + + prefix : string + A prefix prepended to the name of every symbol. + + shape : int or tuple + Shape of the created array. If an int, the array is one-dimensional; for + more than one dimension the shape must be a tuple. + + \*\*kwargs : dict + keyword arguments passed on to Symbol + + Examples + ======== + These doctests require numpy. + + >>> from sympy import symarray + >>> symarray('', 3) + [_0 _1 _2] + + If you want multiple symarrays to contain distinct symbols, you *must* + provide unique prefixes: + + >>> a = symarray('', 3) + >>> b = symarray('', 3) + >>> a[0] == b[0] + True + >>> a = symarray('a', 3) + >>> b = symarray('b', 3) + >>> a[0] == b[0] + False + + Creating symarrays with a prefix: + + >>> symarray('a', 3) + [a_0 a_1 a_2] + + For more than one dimension, the shape must be given as a tuple: + + >>> symarray('a', (2, 3)) + [[a_0_0 a_0_1 a_0_2] + [a_1_0 a_1_1 a_1_2]] + >>> symarray('a', (2, 3, 2)) + [[[a_0_0_0 a_0_0_1] + [a_0_1_0 a_0_1_1] + [a_0_2_0 a_0_2_1]] + + [[a_1_0_0 a_1_0_1] + [a_1_1_0 a_1_1_1] + [a_1_2_0 a_1_2_1]]] + + For setting assumptions of the underlying Symbols: + + >>> [s.is_real for s in symarray('a', 2, real=True)] + [True, True] + """ + from numpy import empty, ndindex + arr = empty(shape, dtype=object) + for index in ndindex(shape): + arr[index] = Symbol('%s_%s' % (prefix, '_'.join(map(str, index))), + **kwargs) + return arr + + +############### +# Functions +############### + +def casoratian(seqs, n, zero=True): + """Given linear difference operator L of order 'k' and homogeneous + equation Ly = 0 we want to compute kernel of L, which is a set + of 'k' sequences: a(n), b(n), ... z(n). + + Solutions of L are linearly independent iff their Casoratian, + denoted as C(a, b, ..., z), do not vanish for n = 0. + + Casoratian is defined by k x k determinant:: + + + a(n) b(n) . . . z(n) + + | a(n+1) b(n+1) . . . z(n+1) | + | . . . . | + | . . . . | + | . . . . | + + a(n+k-1) b(n+k-1) . . . z(n+k-1) + + + It proves very useful in rsolve_hyper() where it is applied + to a generating set of a recurrence to factor out linearly + dependent solutions and return a basis: + + >>> from sympy import Symbol, casoratian, factorial + >>> n = Symbol('n', integer=True) + + Exponential and factorial are linearly independent: + + >>> casoratian([2**n, factorial(n)], n) != 0 + True + + """ + + seqs = list(map(sympify, seqs)) + + if not zero: + f = lambda i, j: seqs[j].subs(n, n + i) + else: + f = lambda i, j: seqs[j].subs(n, i) + + k = len(seqs) + + return Matrix(k, k, f).det() + + +def eye(*args, **kwargs): + """Create square identity matrix n x n + + See Also + ======== + + diag + zeros + ones + """ + + return Matrix.eye(*args, **kwargs) + + +def diag(*values, strict=True, unpack=False, **kwargs): + """Returns a matrix with the provided values placed on the + diagonal. If non-square matrices are included, they will + produce a block-diagonal matrix. + + Examples + ======== + + This version of diag is a thin wrapper to Matrix.diag that differs + in that it treats all lists like matrices -- even when a single list + is given. If this is not desired, either put a `*` before the list or + set `unpack=True`. + + >>> from sympy import diag + + >>> diag([1, 2, 3], unpack=True) # = diag(1,2,3) or diag(*[1,2,3]) + Matrix([ + [1, 0, 0], + [0, 2, 0], + [0, 0, 3]]) + + >>> diag([1, 2, 3]) # a column vector + Matrix([ + [1], + [2], + [3]]) + + See Also + ======== + .matrixbase.MatrixBase.eye + .matrixbase.MatrixBase.diagonal + .matrixbase.MatrixBase.diag + .expressions.blockmatrix.BlockMatrix + """ + return Matrix.diag(*values, strict=strict, unpack=unpack, **kwargs) + + +def GramSchmidt(vlist, orthonormal=False): + """Apply the Gram-Schmidt process to a set of vectors. + + Parameters + ========== + + vlist : List of Matrix + Vectors to be orthogonalized for. + + orthonormal : Bool, optional + If true, return an orthonormal basis. + + Returns + ======= + + vlist : List of Matrix + Orthogonalized vectors + + Notes + ===== + + This routine is mostly duplicate from ``Matrix.orthogonalize``, + except for some difference that this always raises error when + linearly dependent vectors are found, and the keyword ``normalize`` + has been named as ``orthonormal`` in this function. + + See Also + ======== + + .matrixbase.MatrixBase.orthogonalize + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process + """ + return MutableDenseMatrix.orthogonalize( + *vlist, normalize=orthonormal, rankcheck=True + ) + + +def hessian(f, varlist, constraints=()): + """Compute Hessian matrix for a function f wrt parameters in varlist + which may be given as a sequence or a row/column vector. A list of + constraints may optionally be given. + + Examples + ======== + + >>> from sympy import Function, hessian, pprint + >>> from sympy.abc import x, y + >>> f = Function('f')(x, y) + >>> g1 = Function('g')(x, y) + >>> g2 = x**2 + 3*y + >>> pprint(hessian(f, (x, y), [g1, g2])) + [ d d ] + [ 0 0 --(g(x, y)) --(g(x, y)) ] + [ dx dy ] + [ ] + [ 0 0 2*x 3 ] + [ ] + [ 2 2 ] + [d d d ] + [--(g(x, y)) 2*x ---(f(x, y)) -----(f(x, y))] + [dx 2 dy dx ] + [ dx ] + [ ] + [ 2 2 ] + [d d d ] + [--(g(x, y)) 3 -----(f(x, y)) ---(f(x, y)) ] + [dy dy dx 2 ] + [ dy ] + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hessian_matrix + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.jacobian + wronskian + """ + # f is the expression representing a function f, return regular matrix + if isinstance(varlist, MatrixBase): + if 1 not in varlist.shape: + raise ShapeError("`varlist` must be a column or row vector.") + if varlist.cols == 1: + varlist = varlist.T + varlist = varlist.tolist()[0] + if is_sequence(varlist): + n = len(varlist) + if not n: + raise ShapeError("`len(varlist)` must not be zero.") + else: + raise ValueError("Improper variable list in hessian function") + if not getattr(f, 'diff'): + # check differentiability + raise ValueError("Function `f` (%s) is not differentiable" % f) + m = len(constraints) + N = m + n + out = zeros(N) + for k, g in enumerate(constraints): + if not getattr(g, 'diff'): + # check differentiability + raise ValueError("Function `f` (%s) is not differentiable" % f) + for i in range(n): + out[k, i + m] = g.diff(varlist[i]) + for i in range(n): + for j in range(i, n): + out[i + m, j + m] = f.diff(varlist[i]).diff(varlist[j]) + for i in range(N): + for j in range(i + 1, N): + out[j, i] = out[i, j] + return out + + +def jordan_cell(eigenval, n): + """ + Create a Jordan block: + + Examples + ======== + + >>> from sympy import jordan_cell + >>> from sympy.abc import x + >>> jordan_cell(x, 4) + Matrix([ + [x, 1, 0, 0], + [0, x, 1, 0], + [0, 0, x, 1], + [0, 0, 0, x]]) + """ + + return Matrix.jordan_block(size=n, eigenvalue=eigenval) + + +def matrix_multiply_elementwise(A, B): + """Return the Hadamard product (elementwise product) of A and B + + >>> from sympy import Matrix, matrix_multiply_elementwise + >>> A = Matrix([[0, 1, 2], [3, 4, 5]]) + >>> B = Matrix([[1, 10, 100], [100, 10, 1]]) + >>> matrix_multiply_elementwise(A, B) + Matrix([ + [ 0, 10, 200], + [300, 40, 5]]) + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.__mul__ + """ + return A.multiply_elementwise(B) + + +def ones(*args, **kwargs): + """Returns a matrix of ones with ``rows`` rows and ``cols`` columns; + if ``cols`` is omitted a square matrix will be returned. + + See Also + ======== + + zeros + eye + diag + """ + + if 'c' in kwargs: + kwargs['cols'] = kwargs.pop('c') + + return Matrix.ones(*args, **kwargs) + + +def randMatrix(r, c=None, min=0, max=99, seed=None, symmetric=False, + percent=100, prng=None): + """Create random matrix with dimensions ``r`` x ``c``. If ``c`` is omitted + the matrix will be square. If ``symmetric`` is True the matrix must be + square. If ``percent`` is less than 100 then only approximately the given + percentage of elements will be non-zero. + + The pseudo-random number generator used to generate matrix is chosen in the + following way. + + * If ``prng`` is supplied, it will be used as random number generator. + It should be an instance of ``random.Random``, or at least have + ``randint`` and ``shuffle`` methods with same signatures. + * if ``prng`` is not supplied but ``seed`` is supplied, then new + ``random.Random`` with given ``seed`` will be created; + * otherwise, a new ``random.Random`` with default seed will be used. + + Examples + ======== + + >>> from sympy import randMatrix + >>> randMatrix(3) # doctest:+SKIP + [25, 45, 27] + [44, 54, 9] + [23, 96, 46] + >>> randMatrix(3, 2) # doctest:+SKIP + [87, 29] + [23, 37] + [90, 26] + >>> randMatrix(3, 3, 0, 2) # doctest:+SKIP + [0, 2, 0] + [2, 0, 1] + [0, 0, 1] + >>> randMatrix(3, symmetric=True) # doctest:+SKIP + [85, 26, 29] + [26, 71, 43] + [29, 43, 57] + >>> A = randMatrix(3, seed=1) + >>> B = randMatrix(3, seed=2) + >>> A == B + False + >>> A == randMatrix(3, seed=1) + True + >>> randMatrix(3, symmetric=True, percent=50) # doctest:+SKIP + [77, 70, 0], + [70, 0, 0], + [ 0, 0, 88] + """ + # Note that ``Random()`` is equivalent to ``Random(None)`` + prng = prng or random.Random(seed) + + if c is None: + c = r + + if symmetric and r != c: + raise ValueError('For symmetric matrices, r must equal c, but %i != %i' % (r, c)) + + ij = range(r * c) + if percent != 100: + ij = prng.sample(ij, int(len(ij)*percent // 100)) + + m = zeros(r, c) + + if not symmetric: + for ijk in ij: + i, j = divmod(ijk, c) + m[i, j] = prng.randint(min, max) + else: + for ijk in ij: + i, j = divmod(ijk, c) + if i <= j: + m[i, j] = m[j, i] = prng.randint(min, max) + + return m + + +def wronskian(functions, var, method='bareiss'): + """ + Compute Wronskian for [] of functions + + :: + + | f1 f2 ... fn | + | f1' f2' ... fn' | + | . . . . | + W(f1, ..., fn) = | . . . . | + | . . . . | + | (n) (n) (n) | + | D (f1) D (f2) ... D (fn) | + + see: https://en.wikipedia.org/wiki/Wronskian + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.jacobian + hessian + """ + + functions = [sympify(f) for f in functions] + n = len(functions) + if n == 0: + return S.One + W = Matrix(n, n, lambda i, j: functions[i].diff(var, j)) + return W.det(method) + + +def zeros(*args, **kwargs): + """Returns a matrix of zeros with ``rows`` rows and ``cols`` columns; + if ``cols`` is omitted a square matrix will be returned. + + See Also + ======== + + ones + eye + diag + """ + + if 'c' in kwargs: + kwargs['cols'] = kwargs.pop('c') + + return Matrix.zeros(*args, **kwargs) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/determinant.py b/.venv/lib/python3.13/site-packages/sympy/matrices/determinant.py new file mode 100644 index 0000000000000000000000000000000000000000..9206c0714999ebe0cde5c4300d9b3293939177df --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/determinant.py @@ -0,0 +1,1021 @@ +from types import FunctionType + +from sympy.core.cache import cacheit +from sympy.core.numbers import Float, Integer +from sympy.core.singleton import S +from sympy.core.symbol import uniquely_named_symbol +from sympy.core.mul import Mul +from sympy.polys import PurePoly, cancel +from sympy.functions.combinatorial.numbers import nC +from sympy.polys.matrices.domainmatrix import DomainMatrix +from sympy.polys.matrices.ddm import DDM + +from .exceptions import NonSquareMatrixError +from .utilities import ( + _get_intermediate_simp, _get_intermediate_simp_bool, + _iszero, _is_zero_after_expand_mul, _dotprodsimp, _simplify) + + +def _find_reasonable_pivot(col, iszerofunc=_iszero, simpfunc=_simplify): + """ Find the lowest index of an item in ``col`` that is + suitable for a pivot. If ``col`` consists only of + Floats, the pivot with the largest norm is returned. + Otherwise, the first element where ``iszerofunc`` returns + False is used. If ``iszerofunc`` does not return false, + items are simplified and retested until a suitable + pivot is found. + + Returns a 4-tuple + (pivot_offset, pivot_val, assumed_nonzero, newly_determined) + where pivot_offset is the index of the pivot, pivot_val is + the (possibly simplified) value of the pivot, assumed_nonzero + is True if an assumption that the pivot was non-zero + was made without being proved, and newly_determined are + elements that were simplified during the process of pivot + finding.""" + + newly_determined = [] + col = list(col) + # a column that contains a mix of floats and integers + # but at least one float is considered a numerical + # column, and so we do partial pivoting + if all(isinstance(x, (Float, Integer)) for x in col) and any( + isinstance(x, Float) for x in col): + col_abs = [abs(x) for x in col] + max_value = max(col_abs) + if iszerofunc(max_value): + # just because iszerofunc returned True, doesn't + # mean the value is numerically zero. Make sure + # to replace all entries with numerical zeros + if max_value != 0: + newly_determined = [(i, 0) for i, x in enumerate(col) if x != 0] + return (None, None, False, newly_determined) + index = col_abs.index(max_value) + return (index, col[index], False, newly_determined) + + # PASS 1 (iszerofunc directly) + possible_zeros = [] + for i, x in enumerate(col): + is_zero = iszerofunc(x) + # is someone wrote a custom iszerofunc, it may return + # BooleanFalse or BooleanTrue instead of True or False, + # so use == for comparison instead of `is` + if is_zero == False: + # we found something that is definitely not zero + return (i, x, False, newly_determined) + possible_zeros.append(is_zero) + + # by this point, we've found no certain non-zeros + if all(possible_zeros): + # if everything is definitely zero, we have + # no pivot + return (None, None, False, newly_determined) + + # PASS 2 (iszerofunc after simplify) + # we haven't found any for-sure non-zeros, so + # go through the elements iszerofunc couldn't + # make a determination about and opportunistically + # simplify to see if we find something + for i, x in enumerate(col): + if possible_zeros[i] is not None: + continue + simped = simpfunc(x) + is_zero = iszerofunc(simped) + if is_zero in (True, False): + newly_determined.append((i, simped)) + if is_zero == False: + return (i, simped, False, newly_determined) + possible_zeros[i] = is_zero + + # after simplifying, some things that were recognized + # as zeros might be zeros + if all(possible_zeros): + # if everything is definitely zero, we have + # no pivot + return (None, None, False, newly_determined) + + # PASS 3 (.equals(0)) + # some expressions fail to simplify to zero, but + # ``.equals(0)`` evaluates to True. As a last-ditch + # attempt, apply ``.equals`` to these expressions + for i, x in enumerate(col): + if possible_zeros[i] is not None: + continue + if x.equals(S.Zero): + # ``.iszero`` may return False with + # an implicit assumption (e.g., ``x.equals(0)`` + # when ``x`` is a symbol), so only treat it + # as proved when ``.equals(0)`` returns True + possible_zeros[i] = True + newly_determined.append((i, S.Zero)) + + if all(possible_zeros): + return (None, None, False, newly_determined) + + # at this point there is nothing that could definitely + # be a pivot. To maintain compatibility with existing + # behavior, we'll assume that an illdetermined thing is + # non-zero. We should probably raise a warning in this case + i = possible_zeros.index(None) + return (i, col[i], True, newly_determined) + + +def _find_reasonable_pivot_naive(col, iszerofunc=_iszero, simpfunc=None): + """ + Helper that computes the pivot value and location from a + sequence of contiguous matrix column elements. As a side effect + of the pivot search, this function may simplify some of the elements + of the input column. A list of these simplified entries and their + indices are also returned. + This function mimics the behavior of _find_reasonable_pivot(), + but does less work trying to determine if an indeterminate candidate + pivot simplifies to zero. This more naive approach can be much faster, + with the trade-off that it may erroneously return a pivot that is zero. + + ``col`` is a sequence of contiguous column entries to be searched for + a suitable pivot. + ``iszerofunc`` is a callable that returns a Boolean that indicates + if its input is zero, or None if no such determination can be made. + ``simpfunc`` is a callable that simplifies its input. It must return + its input if it does not simplify its input. Passing in + ``simpfunc=None`` indicates that the pivot search should not attempt + to simplify any candidate pivots. + + Returns a 4-tuple: + (pivot_offset, pivot_val, assumed_nonzero, newly_determined) + ``pivot_offset`` is the sequence index of the pivot. + ``pivot_val`` is the value of the pivot. + pivot_val and col[pivot_index] are equivalent, but will be different + when col[pivot_index] was simplified during the pivot search. + ``assumed_nonzero`` is a boolean indicating if the pivot cannot be + guaranteed to be zero. If assumed_nonzero is true, then the pivot + may or may not be non-zero. If assumed_nonzero is false, then + the pivot is non-zero. + ``newly_determined`` is a list of index-value pairs of pivot candidates + that were simplified during the pivot search. + """ + + # indeterminates holds the index-value pairs of each pivot candidate + # that is neither zero or non-zero, as determined by iszerofunc(). + # If iszerofunc() indicates that a candidate pivot is guaranteed + # non-zero, or that every candidate pivot is zero then the contents + # of indeterminates are unused. + # Otherwise, the only viable candidate pivots are symbolic. + # In this case, indeterminates will have at least one entry, + # and all but the first entry are ignored when simpfunc is None. + indeterminates = [] + for i, col_val in enumerate(col): + col_val_is_zero = iszerofunc(col_val) + if col_val_is_zero == False: + # This pivot candidate is non-zero. + return i, col_val, False, [] + elif col_val_is_zero is None: + # The candidate pivot's comparison with zero + # is indeterminate. + indeterminates.append((i, col_val)) + + if len(indeterminates) == 0: + # All candidate pivots are guaranteed to be zero, i.e. there is + # no pivot. + return None, None, False, [] + + if simpfunc is None: + # Caller did not pass in a simplification function that might + # determine if an indeterminate pivot candidate is guaranteed + # to be nonzero, so assume the first indeterminate candidate + # is non-zero. + return indeterminates[0][0], indeterminates[0][1], True, [] + + # newly_determined holds index-value pairs of candidate pivots + # that were simplified during the search for a non-zero pivot. + newly_determined = [] + for i, col_val in indeterminates: + tmp_col_val = simpfunc(col_val) + if id(col_val) != id(tmp_col_val): + # simpfunc() simplified this candidate pivot. + newly_determined.append((i, tmp_col_val)) + if iszerofunc(tmp_col_val) == False: + # Candidate pivot simplified to a guaranteed non-zero value. + return i, tmp_col_val, False, newly_determined + + return indeterminates[0][0], indeterminates[0][1], True, newly_determined + + +# This functions is a candidate for caching if it gets implemented for matrices. +def _berkowitz_toeplitz_matrix(M): + """Return (A,T) where T the Toeplitz matrix used in the Berkowitz algorithm + corresponding to ``M`` and A is the first principal submatrix. + """ + + # the 0 x 0 case is trivial + if M.rows == 0 and M.cols == 0: + return M._new(1,1, [M.one]) + + # + # Partition M = [ a_11 R ] + # [ C A ] + # + + a, R = M[0,0], M[0, 1:] + C, A = M[1:, 0], M[1:,1:] + + # + # The Toeplitz matrix looks like + # + # [ 1 ] + # [ -a 1 ] + # [ -RC -a 1 ] + # [ -RAC -RC -a 1 ] + # [ -RA**2C -RAC -RC -a 1 ] + # etc. + + # Compute the diagonal entries. + # Because multiplying matrix times vector is so much + # more efficient than matrix times matrix, recursively + # compute -R * A**n * C. + diags = [C] + for i in range(M.rows - 2): + diags.append(A.multiply(diags[i], dotprodsimp=None)) + diags = [(-R).multiply(d, dotprodsimp=None)[0, 0] for d in diags] + diags = [M.one, -a] + diags + + def entry(i,j): + if j > i: + return M.zero + return diags[i - j] + + toeplitz = M._new(M.cols + 1, M.rows, entry) + return (A, toeplitz) + + +# This functions is a candidate for caching if it gets implemented for matrices. +def _berkowitz_vector(M): + """ Run the Berkowitz algorithm and return a vector whose entries + are the coefficients of the characteristic polynomial of ``M``. + + Given N x N matrix, efficiently compute + coefficients of characteristic polynomials of ``M`` + without division in the ground domain. + + This method is particularly useful for computing determinant, + principal minors and characteristic polynomial when ``M`` + has complicated coefficients e.g. polynomials. Semi-direct + usage of this algorithm is also important in computing + efficiently sub-resultant PRS. + + Assuming that M is a square matrix of dimension N x N and + I is N x N identity matrix, then the Berkowitz vector is + an N x 1 vector whose entries are coefficients of the + polynomial + + charpoly(M) = det(t*I - M) + + As a consequence, all polynomials generated by Berkowitz + algorithm are monic. + + For more information on the implemented algorithm refer to: + + [1] S.J. Berkowitz, On computing the determinant in small + parallel time using a small number of processors, ACM, + Information Processing Letters 18, 1984, pp. 147-150 + + [2] M. Keber, Division-Free computation of sub-resultants + using Bezout matrices, Tech. Report MPI-I-2006-1-006, + Saarbrucken, 2006 + """ + + # handle the trivial cases + if M.rows == 0 and M.cols == 0: + return M._new(1, 1, [M.one]) + elif M.rows == 1 and M.cols == 1: + return M._new(2, 1, [M.one, -M[0,0]]) + + submat, toeplitz = _berkowitz_toeplitz_matrix(M) + + return toeplitz.multiply(_berkowitz_vector(submat), dotprodsimp=None) + + +def _adjugate(M, method="berkowitz"): + """Returns the adjugate, or classical adjoint, of + a matrix. That is, the transpose of the matrix of cofactors. + + https://en.wikipedia.org/wiki/Adjugate + + Parameters + ========== + + method : string, optional + Method to use to find the cofactors, can be "bareiss", "berkowitz", + "bird", "laplace" or "lu". + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2], [3, 4]]) + >>> M.adjugate() + Matrix([ + [ 4, -2], + [-3, 1]]) + + See Also + ======== + + cofactor_matrix + sympy.matrices.matrixbase.MatrixBase.transpose + """ + + return M.cofactor_matrix(method=method).transpose() + + +# This functions is a candidate for caching if it gets implemented for matrices. +def _charpoly(M, x='lambda', simplify=_simplify): + """Computes characteristic polynomial det(x*I - M) where I is + the identity matrix. + + A PurePoly is returned, so using different variables for ``x`` does + not affect the comparison or the polynomials: + + Parameters + ========== + + x : string, optional + Name for the "lambda" variable, defaults to "lambda". + + simplify : function, optional + Simplification function to use on the characteristic polynomial + calculated. Defaults to ``simplify``. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x, y + >>> M = Matrix([[1, 3], [2, 0]]) + >>> M.charpoly() + PurePoly(lambda**2 - lambda - 6, lambda, domain='ZZ') + >>> M.charpoly(x) == M.charpoly(y) + True + >>> M.charpoly(x) == M.charpoly(y) + True + + Specifying ``x`` is optional; a symbol named ``lambda`` is used by + default (which looks good when pretty-printed in unicode): + + >>> M.charpoly().as_expr() + lambda**2 - lambda - 6 + + And if ``x`` clashes with an existing symbol, underscores will + be prepended to the name to make it unique: + + >>> M = Matrix([[1, 2], [x, 0]]) + >>> M.charpoly(x).as_expr() + _x**2 - _x - 2*x + + Whether you pass a symbol or not, the generator can be obtained + with the gen attribute since it may not be the same as the symbol + that was passed: + + >>> M.charpoly(x).gen + _x + >>> M.charpoly(x).gen == x + False + + Notes + ===== + + The Samuelson-Berkowitz algorithm is used to compute + the characteristic polynomial efficiently and without any + division operations. Thus the characteristic polynomial over any + commutative ring without zero divisors can be computed. + + If the determinant det(x*I - M) can be found out easily as + in the case of an upper or a lower triangular matrix, then + instead of Samuelson-Berkowitz algorithm, eigenvalues are computed + and the characteristic polynomial with their help. + + See Also + ======== + + det + """ + + if not M.is_square: + raise NonSquareMatrixError() + + # Use DomainMatrix. We are already going to convert this to a Poly so there + # is no need to worry about expanding powers etc. Also since this algorithm + # does not require division or zero detection it is fine to use EX. + # + # M.to_DM() will fall back on EXRAW rather than EX. EXRAW is a lot faster + # for elementary arithmetic because it does not call cancel for each + # operation but it generates large unsimplified results that are slow in + # the subsequent call to simplify. Using EX instead is faster overall + # but at least in some cases EXRAW+simplify gives a simpler result so we + # preserve that existing behaviour of charpoly for now... + dM = M.to_DM() + + K = dM.domain + + cp = dM.charpoly() + + x = uniquely_named_symbol(x, [M], modify=lambda s: '_' + s) + + if K.is_EXRAW or simplify is not _simplify: + # XXX: Converting back to Expr is expensive. We only do it if the + # caller supplied a custom simplify function for backwards + # compatibility or otherwise if the domain was EX. For any other domain + # there should be no benefit in simplifying at this stage because Poly + # will put everything into canonical form anyway. + berk_vector = [K.to_sympy(c) for c in cp] + berk_vector = [simplify(a) for a in berk_vector] + p = PurePoly(berk_vector, x) + + else: + # Convert from the list of domain elements directly to Poly. + p = PurePoly(cp, x, domain=K) + + return p + + +def _cofactor(M, i, j, method="berkowitz"): + """Calculate the cofactor of an element. + + Parameters + ========== + + method : string, optional + Method to use to find the cofactors, can be "bareiss", "berkowitz", + "bird", "laplace" or "lu". + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2], [3, 4]]) + >>> M.cofactor(0, 1) + -3 + + See Also + ======== + + cofactor_matrix + minor + minor_submatrix + """ + + if not M.is_square or M.rows < 1: + raise NonSquareMatrixError() + + return S.NegativeOne**((i + j) % 2) * M.minor(i, j, method) + + +def _cofactor_matrix(M, method="berkowitz"): + """Return a matrix containing the cofactor of each element. + + Parameters + ========== + + method : string, optional + Method to use to find the cofactors, can be "bareiss", "berkowitz", + "bird", "laplace" or "lu". + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2], [3, 4]]) + >>> M.cofactor_matrix() + Matrix([ + [ 4, -3], + [-2, 1]]) + + See Also + ======== + + cofactor + minor + minor_submatrix + """ + + if not M.is_square: + raise NonSquareMatrixError() + + return M._new(M.rows, M.cols, + lambda i, j: M.cofactor(i, j, method)) + +def _per(M): + """Returns the permanent of a matrix. Unlike determinant, + permanent is defined for both square and non-square matrices. + + For an m x n matrix, with m less than or equal to n, + it is given as the sum over the permutations s of size + less than or equal to m on [1, 2, . . . n] of the product + from i = 1 to m of M[i, s[i]]. Taking the transpose will + not affect the value of the permanent. + + In the case of a square matrix, this is the same as the permutation + definition of the determinant, but it does not take the sign of the + permutation into account. Computing the permanent with this definition + is quite inefficient, so here the Ryser formula is used. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> M.per() + 450 + >>> M = Matrix([1, 5, 7]) + >>> M.per() + 13 + + References + ========== + + .. [1] Prof. Frank Ben's notes: https://math.berkeley.edu/~bernd/ban275.pdf + .. [2] Wikipedia article on Permanent: https://en.wikipedia.org/wiki/Permanent_%28mathematics%29 + .. [3] https://reference.wolfram.com/language/ref/Permanent.html + .. [4] Permanent of a rectangular matrix : https://arxiv.org/pdf/0904.3251.pdf + """ + import itertools + + m, n = M.shape + if m > n: + M = M.T + m, n = n, m + s = list(range(n)) + + subsets = [] + for i in range(1, m + 1): + subsets += list(map(list, itertools.combinations(s, i))) + + perm = 0 + for subset in subsets: + prod = 1 + sub_len = len(subset) + for i in range(m): + prod *= sum(M[i, j] for j in subset) + perm += prod * S.NegativeOne**sub_len * nC(n - sub_len, m - sub_len) + perm *= S.NegativeOne**m + return perm.simplify() + +def _det_DOM(M): + DOM = DomainMatrix.from_Matrix(M, field=True, extension=True) + K = DOM.domain + return K.to_sympy(DOM.det()) + +# This functions is a candidate for caching if it gets implemented for matrices. +def _det(M, method="bareiss", iszerofunc=None): + """Computes the determinant of a matrix if ``M`` is a concrete matrix object + otherwise return an expressions ``Determinant(M)`` if ``M`` is a + ``MatrixSymbol`` or other expression. + + Parameters + ========== + + method : string, optional + Specifies the algorithm used for computing the matrix determinant. + + If the matrix is at most 3x3, a hard-coded formula is used and the + specified method is ignored. Otherwise, it defaults to + ``'bareiss'``. + + Also, if the matrix is an upper or a lower triangular matrix, determinant + is computed by simple multiplication of diagonal elements, and the + specified method is ignored. + + If it is set to ``'domain-ge'``, then Gaussian elimination method will + be used via using DomainMatrix. + + If it is set to ``'bareiss'``, Bareiss' fraction-free algorithm will + be used. + + If it is set to ``'berkowitz'``, Berkowitz' algorithm will be used. + + If it is set to ``'bird'``, Bird's algorithm will be used [1]_. + + If it is set to ``'laplace'``, Laplace's algorithm will be used. + + Otherwise, if it is set to ``'lu'``, LU decomposition will be used. + + .. note:: + For backward compatibility, legacy keys like "bareis" and + "det_lu" can still be used to indicate the corresponding + methods. + And the keys are also case-insensitive for now. However, it is + suggested to use the precise keys for specifying the method. + + iszerofunc : FunctionType or None, optional + If it is set to ``None``, it will be defaulted to ``_iszero`` if the + method is set to ``'bareiss'``, and ``_is_zero_after_expand_mul`` if + the method is set to ``'lu'``. + + It can also accept any user-specified zero testing function, if it + is formatted as a function which accepts a single symbolic argument + and returns ``True`` if it is tested as zero and ``False`` if it + tested as non-zero, and also ``None`` if it is undecidable. + + Returns + ======= + + det : Basic + Result of determinant. + + Raises + ====== + + ValueError + If unrecognized keys are given for ``method`` or ``iszerofunc``. + + NonSquareMatrixError + If attempted to calculate determinant from a non-square matrix. + + Examples + ======== + + >>> from sympy import Matrix, eye, det + >>> I3 = eye(3) + >>> det(I3) + 1 + >>> M = Matrix([[1, 2], [3, 4]]) + >>> det(M) + -2 + >>> det(M) == M.det() + True + >>> M.det(method="domain-ge") + -2 + + References + ========== + + .. [1] Bird, R. S. (2011). A simple division-free algorithm for computing + determinants. Inf. Process. Lett., 111(21), 1072-1074. doi: + 10.1016/j.ipl.2011.08.006 + """ + + # sanitize `method` + method = method.lower() + + if method == "bareis": + method = "bareiss" + elif method == "det_lu": + method = "lu" + + if method not in ("bareiss", "berkowitz", "lu", "domain-ge", "bird", + "laplace"): + raise ValueError("Determinant method '%s' unrecognized" % method) + + if iszerofunc is None: + if method == "bareiss": + iszerofunc = _is_zero_after_expand_mul + elif method == "lu": + iszerofunc = _iszero + + elif not isinstance(iszerofunc, FunctionType): + raise ValueError("Zero testing method '%s' unrecognized" % iszerofunc) + + n = M.rows + + if n == M.cols: # square check is done in individual method functions + if n == 0: + return M.one + elif n == 1: + return M[0, 0] + elif n == 2: + m = M[0, 0] * M[1, 1] - M[0, 1] * M[1, 0] + return _get_intermediate_simp(_dotprodsimp)(m) + elif n == 3: + m = (M[0, 0] * M[1, 1] * M[2, 2] + + M[0, 1] * M[1, 2] * M[2, 0] + + M[0, 2] * M[1, 0] * M[2, 1] + - M[0, 2] * M[1, 1] * M[2, 0] + - M[0, 0] * M[1, 2] * M[2, 1] + - M[0, 1] * M[1, 0] * M[2, 2]) + return _get_intermediate_simp(_dotprodsimp)(m) + + dets = [] + for b in M.strongly_connected_components(): + if method == "domain-ge": # uses DomainMatrix to evaluate determinant + det = _det_DOM(M[b, b]) + elif method == "bareiss": + det = M[b, b]._eval_det_bareiss(iszerofunc=iszerofunc) + elif method == "berkowitz": + det = M[b, b]._eval_det_berkowitz() + elif method == "lu": + det = M[b, b]._eval_det_lu(iszerofunc=iszerofunc) + elif method == "bird": + det = M[b, b]._eval_det_bird() + elif method == "laplace": + det = M[b, b]._eval_det_laplace() + dets.append(det) + return Mul(*dets) + + +# This functions is a candidate for caching if it gets implemented for matrices. +def _det_bareiss(M, iszerofunc=_is_zero_after_expand_mul): + """Compute matrix determinant using Bareiss' fraction-free + algorithm which is an extension of the well known Gaussian + elimination method. This approach is best suited for dense + symbolic matrices and will result in a determinant with + minimal number of fractions. It means that less term + rewriting is needed on resulting formulae. + + Parameters + ========== + + iszerofunc : function, optional + The function to use to determine zeros when doing an LU decomposition. + Defaults to ``lambda x: x.is_zero``. + + TODO: Implement algorithm for sparse matrices (SFF), + http://www.eecis.udel.edu/~saunders/papers/sffge/it5.ps. + """ + + # Recursively implemented Bareiss' algorithm as per Deanna Richelle Leggett's + # thesis http://www.math.usm.edu/perry/Research/Thesis_DRL.pdf + def bareiss(mat, cumm=1): + if mat.rows == 0: + return mat.one + elif mat.rows == 1: + return mat[0, 0] + + # find a pivot and extract the remaining matrix + # With the default iszerofunc, _find_reasonable_pivot slows down + # the computation by the factor of 2.5 in one test. + # Relevant issues: #10279 and #13877. + pivot_pos, pivot_val, _, _ = _find_reasonable_pivot(mat[:, 0], iszerofunc=iszerofunc) + if pivot_pos is None: + return mat.zero + + # if we have a valid pivot, we'll do a "row swap", so keep the + # sign of the det + sign = (-1) ** (pivot_pos % 2) + + # we want every row but the pivot row and every column + rows = [i for i in range(mat.rows) if i != pivot_pos] + cols = list(range(mat.cols)) + tmp_mat = mat.extract(rows, cols) + + def entry(i, j): + ret = (pivot_val*tmp_mat[i, j + 1] - mat[pivot_pos, j + 1]*tmp_mat[i, 0]) / cumm + if _get_intermediate_simp_bool(True): + return _dotprodsimp(ret) + elif not ret.is_Atom: + return cancel(ret) + return ret + + return sign*bareiss(M._new(mat.rows - 1, mat.cols - 1, entry), pivot_val) + + if not M.is_square: + raise NonSquareMatrixError() + + if M.rows == 0: + return M.one + # sympy/matrices/tests/test_matrices.py contains a test that + # suggests that the determinant of a 0 x 0 matrix is one, by + # convention. + + return bareiss(M) + + +def _det_berkowitz(M): + """ Use the Berkowitz algorithm to compute the determinant.""" + + if not M.is_square: + raise NonSquareMatrixError() + + if M.rows == 0: + return M.one + # sympy/matrices/tests/test_matrices.py contains a test that + # suggests that the determinant of a 0 x 0 matrix is one, by + # convention. + + berk_vector = _berkowitz_vector(M) + return (-1)**(len(berk_vector) - 1) * berk_vector[-1] + + +# This functions is a candidate for caching if it gets implemented for matrices. +def _det_LU(M, iszerofunc=_iszero, simpfunc=None): + """ Computes the determinant of a matrix from its LU decomposition. + This function uses the LU decomposition computed by + LUDecomposition_Simple(). + + The keyword arguments iszerofunc and simpfunc are passed to + LUDecomposition_Simple(). + iszerofunc is a callable that returns a boolean indicating if its + input is zero, or None if it cannot make the determination. + simpfunc is a callable that simplifies its input. + The default is simpfunc=None, which indicate that the pivot search + algorithm should not attempt to simplify any candidate pivots. + If simpfunc fails to simplify its input, then it must return its input + instead of a copy. + + Parameters + ========== + + iszerofunc : function, optional + The function to use to determine zeros when doing an LU decomposition. + Defaults to ``lambda x: x.is_zero``. + + simpfunc : function, optional + The simplification function to use when looking for zeros for pivots. + """ + + if not M.is_square: + raise NonSquareMatrixError() + + if M.rows == 0: + return M.one + # sympy/matrices/tests/test_matrices.py contains a test that + # suggests that the determinant of a 0 x 0 matrix is one, by + # convention. + + lu, row_swaps = M.LUdecomposition_Simple(iszerofunc=iszerofunc, + simpfunc=simpfunc) + # P*A = L*U => det(A) = det(L)*det(U)/det(P) = det(P)*det(U). + # Lower triangular factor L encoded in lu has unit diagonal => det(L) = 1. + # P is a permutation matrix => det(P) in {-1, 1} => 1/det(P) = det(P). + # LUdecomposition_Simple() returns a list of row exchange index pairs, rather + # than a permutation matrix, but det(P) = (-1)**len(row_swaps). + + # Avoid forming the potentially time consuming product of U's diagonal entries + # if the product is zero. + # Bottom right entry of U is 0 => det(A) = 0. + # It may be impossible to determine if this entry of U is zero when it is symbolic. + if iszerofunc(lu[lu.rows-1, lu.rows-1]): + return M.zero + + # Compute det(P) + det = -M.one if len(row_swaps)%2 else M.one + + # Compute det(U) by calculating the product of U's diagonal entries. + # The upper triangular portion of lu is the upper triangular portion of the + # U factor in the LU decomposition. + for k in range(lu.rows): + det *= lu[k, k] + + # return det(P)*det(U) + return det + + +@cacheit +def __det_laplace(M): + """Compute the determinant of a matrix using Laplace expansion. + + This is a recursive function, and it should not be called directly. + Use _det_laplace() instead. The reason for splitting this function + into two is to allow caching of determinants of submatrices. While + one could also define this function inside _det_laplace(), that + would remove the advantage of using caching in Cramer Solve. + """ + n = M.shape[0] + if n == 1: + return M[0] + elif n == 2: + return M[0, 0] * M[1, 1] - M[0, 1] * M[1, 0] + else: + return sum((-1) ** i * M[0, i] * + __det_laplace(M.minor_submatrix(0, i)) for i in range(n)) + + +def _det_laplace(M): + """Compute the determinant of a matrix using Laplace expansion. + + While Laplace expansion is not the most efficient method of computing + a determinant, it is a simple one, and it has the advantage of + being division free. To improve efficiency, this function uses + caching to avoid recomputing determinants of submatrices. + """ + if not M.is_square: + raise NonSquareMatrixError() + if M.shape[0] == 0: + return M.one + # sympy/matrices/tests/test_matrices.py contains a test that + # suggests that the determinant of a 0 x 0 matrix is one, by + # convention. + return __det_laplace(M.as_immutable()) + + +def _det_bird(M): + r"""Compute the determinant of a matrix using Bird's algorithm. + + Bird's algorithm is a simple division-free algorithm for computing, which + is of lower order than the Laplace's algorithm. It is described in [1]_. + + References + ========== + + .. [1] Bird, R. S. (2011). A simple division-free algorithm for computing + determinants. Inf. Process. Lett., 111(21), 1072-1074. doi: + 10.1016/j.ipl.2011.08.006 + """ + def mu(X): + n = X.shape[0] + zero = X.domain.zero + + total = zero + diag_sums = [zero] + for i in reversed(range(1, n)): + total -= X[i][i] + diag_sums.append(total) + diag_sums = diag_sums[::-1] + + elems = [[zero] * i + [diag_sums[i]] + X_i[i + 1:] for i, X_i in + enumerate(X)] + return DDM(elems, X.shape, X.domain) + + Mddm = M._rep.to_ddm() + n = M.shape[0] + if n == 0: + return M.one + # sympy/matrices/tests/test_matrices.py contains a test that + # suggests that the determinant of a 0 x 0 matrix is one, by + # convention. + Fn1 = Mddm + for _ in range(n - 1): + Fn1 = mu(Fn1).matmul(Mddm) + detA = Fn1[0][0] + if n % 2 == 0: + detA = -detA + + return Mddm.domain.to_sympy(detA) + + +def _minor(M, i, j, method="berkowitz"): + """Return the (i,j) minor of ``M``. That is, + return the determinant of the matrix obtained by deleting + the `i`th row and `j`th column from ``M``. + + Parameters + ========== + + i, j : int + The row and column to exclude to obtain the submatrix. + + method : string, optional + Method to use to find the determinant of the submatrix, can be + "bareiss", "berkowitz", "bird", "laplace" or "lu". + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> M.minor(1, 1) + -12 + + See Also + ======== + + minor_submatrix + cofactor + det + """ + + if not M.is_square: + raise NonSquareMatrixError() + + return M.minor_submatrix(i, j).det(method=method) + + +def _minor_submatrix(M, i, j): + """Return the submatrix obtained by removing the `i`th row + and `j`th column from ``M`` (works with Pythonic negative indices). + + Parameters + ========== + + i, j : int + The row and column to exclude to obtain the submatrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> M.minor_submatrix(1, 1) + Matrix([ + [1, 3], + [7, 9]]) + + See Also + ======== + + minor + cofactor + """ + + if i < 0: + i += M.rows + if j < 0: + j += M.cols + + if not 0 <= i < M.rows or not 0 <= j < M.cols: + raise ValueError("`i` and `j` must satisfy 0 <= i < ``M.rows`` " + "(%d)" % M.rows + "and 0 <= j < ``M.cols`` (%d)." % M.cols) + + rows = [a for a in range(M.rows) if a != i] + cols = [a for a in range(M.cols) if a != j] + + return M.extract(rows, cols) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/eigen.py b/.venv/lib/python3.13/site-packages/sympy/matrices/eigen.py new file mode 100644 index 0000000000000000000000000000000000000000..87b2418efcece1c0b158ec56995bb011286feb3c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/eigen.py @@ -0,0 +1,1346 @@ +from types import FunctionType +from collections import Counter + +from mpmath import mp, workprec +from mpmath.libmp.libmpf import prec_to_dps + +from sympy.core.sorting import default_sort_key +from sympy.core.evalf import DEFAULT_MAXPREC, PrecisionExhausted +from sympy.core.logic import fuzzy_and, fuzzy_or +from sympy.core.numbers import Float +from sympy.core.sympify import _sympify +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.polys import roots, CRootOf, ZZ, QQ, EX +from sympy.polys.matrices import DomainMatrix +from sympy.polys.matrices.eigen import dom_eigenvects, dom_eigenvects_to_sympy +from sympy.polys.polytools import gcd + +from .exceptions import MatrixError, NonSquareMatrixError +from .determinant import _find_reasonable_pivot + +from .utilities import _iszero, _simplify + + +__doctest_requires__ = { + ('_is_indefinite', + '_is_negative_definite', + '_is_negative_semidefinite', + '_is_positive_definite', + '_is_positive_semidefinite'): ['matplotlib'], +} + + +def _eigenvals_eigenvects_mpmath(M): + norm2 = lambda v: mp.sqrt(sum(i**2 for i in v)) + + v1 = None + prec = max(x._prec for x in M.atoms(Float)) + eps = 2**-prec + + while prec < DEFAULT_MAXPREC: + with workprec(prec): + A = mp.matrix(M.evalf(n=prec_to_dps(prec))) + E, ER = mp.eig(A) + v2 = norm2([i for e in E for i in (mp.re(e), mp.im(e))]) + if v1 is not None and mp.fabs(v1 - v2) < eps: + return E, ER + v1 = v2 + prec *= 2 + + # we get here because the next step would have taken us + # past MAXPREC or because we never took a step; in case + # of the latter, we refuse to send back a solution since + # it would not have been verified; we also resist taking + # a small step to arrive exactly at MAXPREC since then + # the two calculations might be artificially close. + raise PrecisionExhausted + + +def _eigenvals_mpmath(M, multiple=False): + """Compute eigenvalues using mpmath""" + E, _ = _eigenvals_eigenvects_mpmath(M) + result = [_sympify(x) for x in E] + if multiple: + return result + return dict(Counter(result)) + + +def _eigenvects_mpmath(M): + E, ER = _eigenvals_eigenvects_mpmath(M) + result = [] + for i in range(M.rows): + eigenval = _sympify(E[i]) + eigenvect = _sympify(ER[:, i]) + result.append((eigenval, 1, [eigenvect])) + + return result + + +# This function is a candidate for caching if it gets implemented for matrices. +def _eigenvals( + M, error_when_incomplete=True, *, simplify=False, multiple=False, + rational=False, **flags): + r"""Compute eigenvalues of the matrix. + + Parameters + ========== + + error_when_incomplete : bool, optional + If it is set to ``True``, it will raise an error if not all + eigenvalues are computed. This is caused by ``roots`` not returning + a full list of eigenvalues. + + simplify : bool or function, optional + If it is set to ``True``, it attempts to return the most + simplified form of expressions returned by applying default + simplification method in every routine. + + If it is set to ``False``, it will skip simplification in this + particular routine to save computation resources. + + If a function is passed to, it will attempt to apply + the particular function as simplification method. + + rational : bool, optional + If it is set to ``True``, every floating point numbers would be + replaced with rationals before computation. It can solve some + issues of ``roots`` routine not working well with floats. + + multiple : bool, optional + If it is set to ``True``, the result will be in the form of a + list. + + If it is set to ``False``, the result will be in the form of a + dictionary. + + Returns + ======= + + eigs : list or dict + Eigenvalues of a matrix. The return format would be specified by + the key ``multiple``. + + Raises + ====== + + MatrixError + If not enough roots had got computed. + + NonSquareMatrixError + If attempted to compute eigenvalues from a non-square matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix(3, 3, [0, 1, 1, 1, 0, 0, 1, 1, 1]) + >>> M.eigenvals() + {-1: 1, 0: 1, 2: 1} + + See Also + ======== + + MatrixBase.charpoly + eigenvects + + Notes + ===== + + Eigenvalues of a matrix $A$ can be computed by solving a matrix + equation $\det(A - \lambda I) = 0$ + + It's not always possible to return radical solutions for + eigenvalues for matrices larger than $4, 4$ shape due to + Abel-Ruffini theorem. + + If there is no radical solution is found for the eigenvalue, + it may return eigenvalues in the form of + :class:`sympy.polys.rootoftools.ComplexRootOf`. + """ + if not M: + if multiple: + return [] + return {} + + if not M.is_square: + raise NonSquareMatrixError("{} must be a square matrix.".format(M)) + + if M._rep.domain not in (ZZ, QQ): + # Skip this check for ZZ/QQ because it can be slow + if all(x.is_number for x in M) and M.has(Float): + return _eigenvals_mpmath(M, multiple=multiple) + + if rational: + from sympy.simplify import nsimplify + M = M.applyfunc( + lambda x: nsimplify(x, rational=True) if x.has(Float) else x) + + if multiple: + return _eigenvals_list( + M, error_when_incomplete=error_when_incomplete, simplify=simplify, + **flags) + return _eigenvals_dict( + M, error_when_incomplete=error_when_incomplete, simplify=simplify, + **flags) + + +eigenvals_error_message = \ +"It is not always possible to express the eigenvalues of a matrix " + \ +"of size 5x5 or higher in radicals. " + \ +"We have CRootOf, but domains other than the rationals are not " + \ +"currently supported. " + \ +"If there are no symbols in the matrix, " + \ +"it should still be possible to compute numeric approximations " + \ +"of the eigenvalues using " + \ +"M.evalf().eigenvals() or M.charpoly().nroots()." + + +def _eigenvals_list( + M, error_when_incomplete=True, simplify=False, **flags): + iblocks = M.strongly_connected_components() + all_eigs = [] + is_dom = M._rep.domain in (ZZ, QQ) + for b in iblocks: + + # Fast path for a 1x1 block: + if is_dom and len(b) == 1: + index = b[0] + val = M[index, index] + all_eigs.append(val) + continue + + block = M[b, b] + + if isinstance(simplify, FunctionType): + charpoly = block.charpoly(simplify=simplify) + else: + charpoly = block.charpoly() + + eigs = roots(charpoly, multiple=True, **flags) + + if len(eigs) != block.rows: + try: + eigs = charpoly.all_roots(multiple=True) + except NotImplementedError: + if error_when_incomplete: + raise MatrixError(eigenvals_error_message) + else: + eigs = [] + + all_eigs += eigs + + if not simplify: + return all_eigs + if not isinstance(simplify, FunctionType): + simplify = _simplify + return [simplify(value) for value in all_eigs] + + +def _eigenvals_dict( + M, error_when_incomplete=True, simplify=False, **flags): + iblocks = M.strongly_connected_components() + all_eigs = {} + is_dom = M._rep.domain in (ZZ, QQ) + for b in iblocks: + + # Fast path for a 1x1 block: + if is_dom and len(b) == 1: + index = b[0] + val = M[index, index] + all_eigs[val] = all_eigs.get(val, 0) + 1 + continue + + block = M[b, b] + + if isinstance(simplify, FunctionType): + charpoly = block.charpoly(simplify=simplify) + else: + charpoly = block.charpoly() + + eigs = roots(charpoly, multiple=False, **flags) + + if sum(eigs.values()) != block.rows: + try: + eigs = dict(charpoly.all_roots(multiple=False)) + except NotImplementedError: + if error_when_incomplete: + raise MatrixError(eigenvals_error_message) + else: + eigs = {} + + for k, v in eigs.items(): + if k in all_eigs: + all_eigs[k] += v + else: + all_eigs[k] = v + + if not simplify: + return all_eigs + if not isinstance(simplify, FunctionType): + simplify = _simplify + return {simplify(key): value for key, value in all_eigs.items()} + + +def _eigenspace(M, eigenval, iszerofunc=_iszero, simplify=False): + """Get a basis for the eigenspace for a particular eigenvalue""" + m = M - M.eye(M.rows) * eigenval + ret = m.nullspace(iszerofunc=iszerofunc) + + # The nullspace for a real eigenvalue should be non-trivial. + # If we didn't find an eigenvector, try once more a little harder + if len(ret) == 0 and simplify: + ret = m.nullspace(iszerofunc=iszerofunc, simplify=True) + if len(ret) == 0: + raise NotImplementedError( + "Can't evaluate eigenvector for eigenvalue {}".format(eigenval)) + return ret + + +def _eigenvects_DOM(M, **kwargs): + DOM = DomainMatrix.from_Matrix(M, field=True, extension=True) + DOM = DOM.to_dense() + + if DOM.domain != EX: + rational, algebraic = dom_eigenvects(DOM) + eigenvects = dom_eigenvects_to_sympy( + rational, algebraic, M.__class__, **kwargs) + eigenvects = sorted(eigenvects, key=lambda x: default_sort_key(x[0])) + + return eigenvects + return None + + +def _eigenvects_sympy(M, iszerofunc, simplify=True, **flags): + eigenvals = M.eigenvals(rational=False, **flags) + + # Make sure that we have all roots in radical form + for x in eigenvals: + if x.has(CRootOf): + raise MatrixError( + "Eigenvector computation is not implemented if the matrix have " + "eigenvalues in CRootOf form") + + eigenvals = sorted(eigenvals.items(), key=default_sort_key) + ret = [] + for val, mult in eigenvals: + vects = _eigenspace(M, val, iszerofunc=iszerofunc, simplify=simplify) + ret.append((val, mult, vects)) + return ret + + +# This functions is a candidate for caching if it gets implemented for matrices. +def _eigenvects(M, error_when_incomplete=True, iszerofunc=_iszero, *, chop=False, **flags): + """Compute eigenvectors of the matrix. + + Parameters + ========== + + error_when_incomplete : bool, optional + Raise an error when not all eigenvalues are computed. This is + caused by ``roots`` not returning a full list of eigenvalues. + + iszerofunc : function, optional + Specifies a zero testing function to be used in ``rref``. + + Default value is ``_iszero``, which uses SymPy's naive and fast + default assumption handler. + + It can also accept any user-specified zero testing function, if it + is formatted as a function which accepts a single symbolic argument + and returns ``True`` if it is tested as zero and ``False`` if it + is tested as non-zero, and ``None`` if it is undecidable. + + simplify : bool or function, optional + If ``True``, ``as_content_primitive()`` will be used to tidy up + normalization artifacts. + + It will also be used by the ``nullspace`` routine. + + chop : bool or positive number, optional + If the matrix contains any Floats, they will be changed to Rationals + for computation purposes, but the answers will be returned after + being evaluated with evalf. The ``chop`` flag is passed to ``evalf``. + When ``chop=True`` a default precision will be used; a number will + be interpreted as the desired level of precision. + + Returns + ======= + + ret : [(eigenval, multiplicity, eigenspace), ...] + A ragged list containing tuples of data obtained by ``eigenvals`` + and ``nullspace``. + + ``eigenspace`` is a list containing the ``eigenvector`` for each + eigenvalue. + + ``eigenvector`` is a vector in the form of a ``Matrix``. e.g. + a vector of length 3 is returned as ``Matrix([a_1, a_2, a_3])``. + + Raises + ====== + + NotImplementedError + If failed to compute nullspace. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix(3, 3, [0, 1, 1, 1, 0, 0, 1, 1, 1]) + >>> M.eigenvects() + [(-1, 1, [Matrix([ + [-1], + [ 1], + [ 0]])]), (0, 1, [Matrix([ + [ 0], + [-1], + [ 1]])]), (2, 1, [Matrix([ + [2/3], + [1/3], + [ 1]])])] + + See Also + ======== + + eigenvals + MatrixBase.nullspace + """ + simplify = flags.get('simplify', True) + primitive = flags.get('simplify', False) + flags.pop('simplify', None) # remove this if it's there + flags.pop('multiple', None) # remove this if it's there + + if not isinstance(simplify, FunctionType): + simpfunc = _simplify if simplify else lambda x: x + + has_floats = M.has(Float) + if has_floats: + if all(x.is_number for x in M): + return _eigenvects_mpmath(M) + from sympy.simplify import nsimplify + M = M.applyfunc(lambda x: nsimplify(x, rational=True)) + + ret = _eigenvects_DOM(M) + if ret is None: + ret = _eigenvects_sympy(M, iszerofunc, simplify=simplify, **flags) + + if primitive: + # if the primitive flag is set, get rid of any common + # integer denominators + def denom_clean(l): + return [(v / gcd(list(v))).applyfunc(simpfunc) for v in l] + + ret = [(val, mult, denom_clean(es)) for val, mult, es in ret] + + if has_floats: + # if we had floats to start with, turn the eigenvectors to floats + ret = [(val.evalf(chop=chop), mult, [v.evalf(chop=chop) for v in es]) + for val, mult, es in ret] + + return ret + + +def _is_diagonalizable_with_eigen(M, reals_only=False): + """See _is_diagonalizable. This function returns the bool along with the + eigenvectors to avoid calculating them again in functions like + ``diagonalize``.""" + + if not M.is_square: + return False, [] + + eigenvecs = M.eigenvects(simplify=True) + + for val, mult, basis in eigenvecs: + if reals_only and not val.is_real: # if we have a complex eigenvalue + return False, eigenvecs + + if mult != len(basis): # if the geometric multiplicity doesn't equal the algebraic + return False, eigenvecs + + return True, eigenvecs + +def _is_diagonalizable(M, reals_only=False, **kwargs): + """Returns ``True`` if a matrix is diagonalizable. + + Parameters + ========== + + reals_only : bool, optional + If ``True``, it tests whether the matrix can be diagonalized + to contain only real numbers on the diagonal. + + + If ``False``, it tests whether the matrix can be diagonalized + at all, even with numbers that may not be real. + + Examples + ======== + + Example of a diagonalizable matrix: + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2, 0], [0, 3, 0], [2, -4, 2]]) + >>> M.is_diagonalizable() + True + + Example of a non-diagonalizable matrix: + + >>> M = Matrix([[0, 1], [0, 0]]) + >>> M.is_diagonalizable() + False + + Example of a matrix that is diagonalized in terms of non-real entries: + + >>> M = Matrix([[0, 1], [-1, 0]]) + >>> M.is_diagonalizable(reals_only=False) + True + >>> M.is_diagonalizable(reals_only=True) + False + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.is_diagonal + diagonalize + """ + if not M.is_square: + return False + + if all(e.is_real for e in M) and M.is_symmetric(): + return True + + if all(e.is_complex for e in M) and M.is_hermitian: + return True + + return _is_diagonalizable_with_eigen(M, reals_only=reals_only)[0] + + +#G&VL, Matrix Computations, Algo 5.4.2 +def _householder_vector(x): + if not x.cols == 1: + raise ValueError("Input must be a column matrix") + v = x.copy() + v_plus = x.copy() + v_minus = x.copy() + q = x[0, 0] / abs(x[0, 0]) + norm_x = x.norm() + v_plus[0, 0] = x[0, 0] + q * norm_x + v_minus[0, 0] = x[0, 0] - q * norm_x + if x[1:, 0].norm() == 0: + bet = 0 + v[0, 0] = 1 + else: + if v_plus.norm() <= v_minus.norm(): + v = v_plus + else: + v = v_minus + v = v / v[0] + bet = 2 / (v.norm() ** 2) + return v, bet + + +def _bidiagonal_decmp_hholder(M): + m = M.rows + n = M.cols + A = M.as_mutable() + U, V = A.eye(m), A.eye(n) + for i in range(min(m, n)): + v, bet = _householder_vector(A[i:, i]) + hh_mat = A.eye(m - i) - bet * v * v.H + A[i:, i:] = hh_mat * A[i:, i:] + temp = A.eye(m) + temp[i:, i:] = hh_mat + U = U * temp + if i + 1 <= n - 2: + v, bet = _householder_vector(A[i, i+1:].T) + hh_mat = A.eye(n - i - 1) - bet * v * v.H + A[i:, i+1:] = A[i:, i+1:] * hh_mat + temp = A.eye(n) + temp[i+1:, i+1:] = hh_mat + V = temp * V + return U, A, V + + +def _eval_bidiag_hholder(M): + m = M.rows + n = M.cols + A = M.as_mutable() + for i in range(min(m, n)): + v, bet = _householder_vector(A[i:, i]) + hh_mat = A.eye(m-i) - bet * v * v.H + A[i:, i:] = hh_mat * A[i:, i:] + if i + 1 <= n - 2: + v, bet = _householder_vector(A[i, i+1:].T) + hh_mat = A.eye(n - i - 1) - bet * v * v.H + A[i:, i+1:] = A[i:, i+1:] * hh_mat + return A + + +def _bidiagonal_decomposition(M, upper=True): + """ + Returns $(U,B,V.H)$ for + + $$A = UBV^{H}$$ + + where $A$ is the input matrix, and $B$ is its Bidiagonalized form + + Note: Bidiagonal Computation can hang for symbolic matrices. + + Parameters + ========== + + upper : bool. Whether to do upper bidiagnalization or lower. + True for upper and False for lower. + + References + ========== + + .. [1] Algorithm 5.4.2, Matrix computations by Golub and Van Loan, 4th edition + .. [2] Complex Matrix Bidiagonalization, https://github.com/vslobody/Householder-Bidiagonalization + + """ + + if not isinstance(upper, bool): + raise ValueError("upper must be a boolean") + + if upper: + return _bidiagonal_decmp_hholder(M) + + X = _bidiagonal_decmp_hholder(M.H) + return X[2].H, X[1].H, X[0].H + + +def _bidiagonalize(M, upper=True): + """ + Returns $B$, the Bidiagonalized form of the input matrix. + + Note: Bidiagonal Computation can hang for symbolic matrices. + + Parameters + ========== + + upper : bool. Whether to do upper bidiagnalization or lower. + True for upper and False for lower. + + References + ========== + + .. [1] Algorithm 5.4.2, Matrix computations by Golub and Van Loan, 4th edition + .. [2] Complex Matrix Bidiagonalization : https://github.com/vslobody/Householder-Bidiagonalization + + """ + + if not isinstance(upper, bool): + raise ValueError("upper must be a boolean") + + if upper: + return _eval_bidiag_hholder(M) + return _eval_bidiag_hholder(M.H).H + + +def _diagonalize(M, reals_only=False, sort=False, normalize=False): + """ + Return (P, D), where D is diagonal and + + D = P^-1 * M * P + + where M is current matrix. + + Parameters + ========== + + reals_only : bool. Whether to throw an error if complex numbers are need + to diagonalize. (Default: False) + + sort : bool. Sort the eigenvalues along the diagonal. (Default: False) + + normalize : bool. If True, normalize the columns of P. (Default: False) + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix(3, 3, [1, 2, 0, 0, 3, 0, 2, -4, 2]) + >>> M + Matrix([ + [1, 2, 0], + [0, 3, 0], + [2, -4, 2]]) + >>> (P, D) = M.diagonalize() + >>> D + Matrix([ + [1, 0, 0], + [0, 2, 0], + [0, 0, 3]]) + >>> P + Matrix([ + [-1, 0, -1], + [ 0, 0, -1], + [ 2, 1, 2]]) + >>> P.inv() * M * P + Matrix([ + [1, 0, 0], + [0, 2, 0], + [0, 0, 3]]) + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.is_diagonal + is_diagonalizable + """ + + if not M.is_square: + raise NonSquareMatrixError() + + is_diagonalizable, eigenvecs = _is_diagonalizable_with_eigen(M, + reals_only=reals_only) + + if not is_diagonalizable: + raise MatrixError("Matrix is not diagonalizable") + + if sort: + eigenvecs = sorted(eigenvecs, key=default_sort_key) + + p_cols, diag = [], [] + + for val, mult, basis in eigenvecs: + diag += [val] * mult + p_cols += basis + + if normalize: + p_cols = [v / v.norm() for v in p_cols] + + return M.hstack(*p_cols), M.diag(*diag) + + +def _fuzzy_positive_definite(M): + positive_diagonals = M._has_positive_diagonals() + if positive_diagonals is False: + return False + + if positive_diagonals and M.is_strongly_diagonally_dominant: + return True + + return None + + +def _fuzzy_positive_semidefinite(M): + nonnegative_diagonals = M._has_nonnegative_diagonals() + if nonnegative_diagonals is False: + return False + + if nonnegative_diagonals and M.is_weakly_diagonally_dominant: + return True + + return None + + +def _is_positive_definite(M): + if not M.is_hermitian: + if not M.is_square: + return False + M = M + M.H + + fuzzy = _fuzzy_positive_definite(M) + if fuzzy is not None: + return fuzzy + + return _is_positive_definite_GE(M) + + +def _is_positive_semidefinite(M): + if not M.is_hermitian: + if not M.is_square: + return False + M = M + M.H + + fuzzy = _fuzzy_positive_semidefinite(M) + if fuzzy is not None: + return fuzzy + + return _is_positive_semidefinite_cholesky(M) + + +def _is_negative_definite(M): + return _is_positive_definite(-M) + + +def _is_negative_semidefinite(M): + return _is_positive_semidefinite(-M) + + +def _is_indefinite(M): + if M.is_hermitian: + eigen = M.eigenvals() + args1 = [x.is_positive for x in eigen.keys()] + any_positive = fuzzy_or(args1) + args2 = [x.is_negative for x in eigen.keys()] + any_negative = fuzzy_or(args2) + + return fuzzy_and([any_positive, any_negative]) + + elif M.is_square: + return (M + M.H).is_indefinite + + return False + + +def _is_positive_definite_GE(M): + """A division-free gaussian elimination method for testing + positive-definiteness.""" + M = M.as_mutable() + size = M.rows + + for i in range(size): + is_positive = M[i, i].is_positive + if is_positive is not True: + return is_positive + for j in range(i+1, size): + M[j, i+1:] = M[i, i] * M[j, i+1:] - M[j, i] * M[i, i+1:] + return True + + +def _is_positive_semidefinite_cholesky(M): + """Uses Cholesky factorization with complete pivoting + + References + ========== + + .. [1] http://eprints.ma.man.ac.uk/1199/1/covered/MIMS_ep2008_116.pdf + + .. [2] https://www.value-at-risk.net/cholesky-factorization/ + """ + M = M.as_mutable() + for k in range(M.rows): + diags = [M[i, i] for i in range(k, M.rows)] + pivot, pivot_val, nonzero, _ = _find_reasonable_pivot(diags) + + if nonzero: + return None + + if pivot is None: + for i in range(k+1, M.rows): + for j in range(k, M.cols): + iszero = M[i, j].is_zero + if iszero is None: + return None + elif iszero is False: + return False + return True + + if M[k, k].is_negative or pivot_val.is_negative: + return False + elif not (M[k, k].is_nonnegative and pivot_val.is_nonnegative): + return None + + if pivot > 0: + M.col_swap(k, k+pivot) + M.row_swap(k, k+pivot) + + M[k, k] = sqrt(M[k, k]) + M[k, k+1:] /= M[k, k] + M[k+1:, k+1:] -= M[k, k+1:].H * M[k, k+1:] + + return M[-1, -1].is_nonnegative + + +_doc_positive_definite = \ + r"""Finds out the definiteness of a matrix. + + Explanation + =========== + + A square real matrix $A$ is: + + - A positive definite matrix if $x^T A x > 0$ + for all non-zero real vectors $x$. + - A positive semidefinite matrix if $x^T A x \geq 0$ + for all non-zero real vectors $x$. + - A negative definite matrix if $x^T A x < 0$ + for all non-zero real vectors $x$. + - A negative semidefinite matrix if $x^T A x \leq 0$ + for all non-zero real vectors $x$. + - An indefinite matrix if there exists non-zero real vectors + $x, y$ with $x^T A x > 0 > y^T A y$. + + A square complex matrix $A$ is: + + - A positive definite matrix if $\text{re}(x^H A x) > 0$ + for all non-zero complex vectors $x$. + - A positive semidefinite matrix if $\text{re}(x^H A x) \geq 0$ + for all non-zero complex vectors $x$. + - A negative definite matrix if $\text{re}(x^H A x) < 0$ + for all non-zero complex vectors $x$. + - A negative semidefinite matrix if $\text{re}(x^H A x) \leq 0$ + for all non-zero complex vectors $x$. + - An indefinite matrix if there exists non-zero complex vectors + $x, y$ with $\text{re}(x^H A x) > 0 > \text{re}(y^H A y)$. + + A matrix need not be symmetric or hermitian to be positive definite. + + - A real non-symmetric matrix is positive definite if and only if + $\frac{A + A^T}{2}$ is positive definite. + - A complex non-hermitian matrix is positive definite if and only if + $\frac{A + A^H}{2}$ is positive definite. + + And this extension can apply for all the definitions above. + + However, for complex cases, you can restrict the definition of + $\text{re}(x^H A x) > 0$ to $x^H A x > 0$ and require the matrix + to be hermitian. + But we do not present this restriction for computation because you + can check ``M.is_hermitian`` independently with this and use + the same procedure. + + Examples + ======== + + An example of symmetric positive definite matrix: + + .. plot:: + :context: reset + :format: doctest + :include-source: True + + >>> from sympy import Matrix, symbols + >>> from sympy.plotting import plot3d + >>> a, b = symbols('a b') + >>> x = Matrix([a, b]) + + >>> A = Matrix([[1, 0], [0, 1]]) + >>> A.is_positive_definite + True + >>> A.is_positive_semidefinite + True + + >>> p = plot3d((x.T*A*x)[0, 0], (a, -1, 1), (b, -1, 1)) + + An example of symmetric positive semidefinite matrix: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> A = Matrix([[1, -1], [-1, 1]]) + >>> A.is_positive_definite + False + >>> A.is_positive_semidefinite + True + + >>> p = plot3d((x.T*A*x)[0, 0], (a, -1, 1), (b, -1, 1)) + + An example of symmetric negative definite matrix: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> A = Matrix([[-1, 0], [0, -1]]) + >>> A.is_negative_definite + True + >>> A.is_negative_semidefinite + True + >>> A.is_indefinite + False + + >>> p = plot3d((x.T*A*x)[0, 0], (a, -1, 1), (b, -1, 1)) + + An example of symmetric indefinite matrix: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> A = Matrix([[1, 2], [2, -1]]) + >>> A.is_indefinite + True + + >>> p = plot3d((x.T*A*x)[0, 0], (a, -1, 1), (b, -1, 1)) + + An example of non-symmetric positive definite matrix. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> A = Matrix([[1, 2], [-2, 1]]) + >>> A.is_positive_definite + True + >>> A.is_positive_semidefinite + True + + >>> p = plot3d((x.T*A*x)[0, 0], (a, -1, 1), (b, -1, 1)) + + Notes + ===== + + Although some people trivialize the definition of positive definite + matrices only for symmetric or hermitian matrices, this restriction + is not correct because it does not classify all instances of + positive definite matrices from the definition $x^T A x > 0$ or + $\text{re}(x^H A x) > 0$. + + For instance, ``Matrix([[1, 2], [-2, 1]])`` presented in + the example above is an example of real positive definite matrix + that is not symmetric. + + However, since the following formula holds true; + + .. math:: + \text{re}(x^H A x) > 0 \iff + \text{re}(x^H \frac{A + A^H}{2} x) > 0 + + We can classify all positive definite matrices that may or may not + be symmetric or hermitian by transforming the matrix to + $\frac{A + A^T}{2}$ or $\frac{A + A^H}{2}$ + (which is guaranteed to be always real symmetric or complex + hermitian) and we can defer most of the studies to symmetric or + hermitian positive definite matrices. + + But it is a different problem for the existence of Cholesky + decomposition. Because even though a non symmetric or a non + hermitian matrix can be positive definite, Cholesky or LDL + decomposition does not exist because the decompositions require the + matrix to be symmetric or hermitian. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Definiteness_of_a_matrix#Eigenvalues + + .. [2] https://mathworld.wolfram.com/PositiveDefiniteMatrix.html + + .. [3] Johnson, C. R. "Positive Definite Matrices." Amer. + Math. Monthly 77, 259-264 1970. + """ + +_is_positive_definite.__doc__ = _doc_positive_definite +_is_positive_semidefinite.__doc__ = _doc_positive_definite +_is_negative_definite.__doc__ = _doc_positive_definite +_is_negative_semidefinite.__doc__ = _doc_positive_definite +_is_indefinite.__doc__ = _doc_positive_definite + + +def _jordan_form(M, calc_transform=True, *, chop=False): + """Return $(P, J)$ where $J$ is a Jordan block + matrix and $P$ is a matrix such that $M = P J P^{-1}$ + + Parameters + ========== + + calc_transform : bool + If ``False``, then only $J$ is returned. + + chop : bool + All matrices are converted to exact types when computing + eigenvalues and eigenvectors. As a result, there may be + approximation errors. If ``chop==True``, these errors + will be truncated. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[ 6, 5, -2, -3], [-3, -1, 3, 3], [ 2, 1, -2, -3], [-1, 1, 5, 5]]) + >>> P, J = M.jordan_form() + >>> J + Matrix([ + [2, 1, 0, 0], + [0, 2, 0, 0], + [0, 0, 2, 1], + [0, 0, 0, 2]]) + + See Also + ======== + + jordan_block + """ + + if not M.is_square: + raise NonSquareMatrixError("Only square matrices have Jordan forms") + + mat = M + has_floats = M.has(Float) + + if has_floats: + try: + max_prec = max(term._prec for term in M.values() if isinstance(term, Float)) + except ValueError: + # if no term in the matrix is explicitly a Float calling max() + # will throw a error so setting max_prec to default value of 53 + max_prec = 53 + + # setting minimum max_dps to 15 to prevent loss of precision in + # matrix containing non evaluated expressions + max_dps = max(prec_to_dps(max_prec), 15) + + def restore_floats(*args): + """If ``has_floats`` is `True`, cast all ``args`` as + matrices of floats.""" + + if has_floats: + args = [m.evalf(n=max_dps, chop=chop) for m in args] + if len(args) == 1: + return args[0] + + return args + + # cache calculations for some speedup + mat_cache = {} + + def eig_mat(val, pow): + """Cache computations of ``(M - val*I)**pow`` for quick + retrieval""" + + if (val, pow) in mat_cache: + return mat_cache[(val, pow)] + + if (val, pow - 1) in mat_cache: + mat_cache[(val, pow)] = mat_cache[(val, pow - 1)].multiply( + mat_cache[(val, 1)], dotprodsimp=None) + else: + mat_cache[(val, pow)] = (mat - val*M.eye(M.rows)).pow(pow) + + return mat_cache[(val, pow)] + + # helper functions + def nullity_chain(val, algebraic_multiplicity): + """Calculate the sequence [0, nullity(E), nullity(E**2), ...] + until it is constant where ``E = M - val*I``""" + + # mat.rank() is faster than computing the null space, + # so use the rank-nullity theorem + cols = M.cols + ret = [0] + nullity = cols - eig_mat(val, 1).rank() + i = 2 + + while nullity != ret[-1]: + ret.append(nullity) + + if nullity == algebraic_multiplicity: + break + + nullity = cols - eig_mat(val, i).rank() + i += 1 + + # Due to issues like #7146 and #15872, SymPy sometimes + # gives the wrong rank. In this case, raise an error + # instead of returning an incorrect matrix + if nullity < ret[-1] or nullity > algebraic_multiplicity: + raise MatrixError( + "SymPy had encountered an inconsistent " + "result while computing Jordan block: " + "{}".format(M)) + + return ret + + def blocks_from_nullity_chain(d): + """Return a list of the size of each Jordan block. + If d_n is the nullity of E**n, then the number + of Jordan blocks of size n is + + 2*d_n - d_(n-1) - d_(n+1)""" + + # d[0] is always the number of columns, so skip past it + mid = [2*d[n] - d[n - 1] - d[n + 1] for n in range(1, len(d) - 1)] + # d is assumed to plateau with "d[ len(d) ] == d[-1]", so + # 2*d_n - d_(n-1) - d_(n+1) == d_n - d_(n-1) + end = [d[-1] - d[-2]] if len(d) > 1 else [d[0]] + + return mid + end + + def pick_vec(small_basis, big_basis): + """Picks a vector from big_basis that isn't in + the subspace spanned by small_basis""" + + if len(small_basis) == 0: + return big_basis[0] + + for v in big_basis: + _, pivots = M.hstack(*(small_basis + [v])).echelon_form( + with_pivots=True) + + if pivots[-1] == len(small_basis): + return v + + # roots doesn't like Floats, so replace them with Rationals + if has_floats: + from sympy.simplify import nsimplify + mat = mat.applyfunc(lambda x: nsimplify(x, rational=True)) + + # first calculate the jordan block structure + eigs = mat.eigenvals() + + # Make sure that we have all roots in radical form + for x in eigs: + if x.has(CRootOf): + raise MatrixError( + "Jordan normal form is not implemented if the matrix have " + "eigenvalues in CRootOf form") + + # most matrices have distinct eigenvalues + # and so are diagonalizable. In this case, don't + # do extra work! + if len(eigs.keys()) == mat.cols: + blocks = sorted(eigs.keys(), key=default_sort_key) + jordan_mat = mat.diag(*blocks) + + if not calc_transform: + return restore_floats(jordan_mat) + + jordan_basis = [eig_mat(eig, 1).nullspace()[0] + for eig in blocks] + basis_mat = mat.hstack(*jordan_basis) + + return restore_floats(basis_mat, jordan_mat) + + block_structure = [] + + for eig in sorted(eigs.keys(), key=default_sort_key): + algebraic_multiplicity = eigs[eig] + chain = nullity_chain(eig, algebraic_multiplicity) + block_sizes = blocks_from_nullity_chain(chain) + + # if block_sizes = = [a, b, c, ...], then the number of + # Jordan blocks of size 1 is a, of size 2 is b, etc. + # create an array that has (eig, block_size) with one + # entry for each block + size_nums = [(i+1, num) for i, num in enumerate(block_sizes)] + + # we expect larger Jordan blocks to come earlier + size_nums.reverse() + + block_structure.extend( + [(eig, size) for size, num in size_nums for _ in range(num)]) + + jordan_form_size = sum(size for eig, size in block_structure) + + if jordan_form_size != M.rows: + raise MatrixError( + "SymPy had encountered an inconsistent result while " + "computing Jordan block. : {}".format(M)) + + blocks = (mat.jordan_block(size=size, eigenvalue=eig) for eig, size in block_structure) + jordan_mat = mat.diag(*blocks) + + if not calc_transform: + return restore_floats(jordan_mat) + + # For each generalized eigenspace, calculate a basis. + # We start by looking for a vector in null( (A - eig*I)**n ) + # which isn't in null( (A - eig*I)**(n-1) ) where n is + # the size of the Jordan block + # + # Ideally we'd just loop through block_structure and + # compute each generalized eigenspace. However, this + # causes a lot of unneeded computation. Instead, we + # go through the eigenvalues separately, since we know + # their generalized eigenspaces must have bases that + # are linearly independent. + jordan_basis = [] + + for eig in sorted(eigs.keys(), key=default_sort_key): + eig_basis = [] + + for block_eig, size in block_structure: + if block_eig != eig: + continue + + null_big = (eig_mat(eig, size)).nullspace() + null_small = (eig_mat(eig, size - 1)).nullspace() + + # we want to pick something that is in the big basis + # and not the small, but also something that is independent + # of any other generalized eigenvectors from a different + # generalized eigenspace sharing the same eigenvalue. + vec = pick_vec(null_small + eig_basis, null_big) + new_vecs = [eig_mat(eig, i).multiply(vec, dotprodsimp=None) + for i in range(size)] + + eig_basis.extend(new_vecs) + jordan_basis.extend(reversed(new_vecs)) + + basis_mat = mat.hstack(*jordan_basis) + + return restore_floats(basis_mat, jordan_mat) + + +def _left_eigenvects(M, **flags): + """Returns left eigenvectors and eigenvalues. + + This function returns the list of triples (eigenval, multiplicity, + basis) for the left eigenvectors. Options are the same as for + eigenvects(), i.e. the ``**flags`` arguments gets passed directly to + eigenvects(). + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[0, 1, 1], [1, 0, 0], [1, 1, 1]]) + >>> M.eigenvects() + [(-1, 1, [Matrix([ + [-1], + [ 1], + [ 0]])]), (0, 1, [Matrix([ + [ 0], + [-1], + [ 1]])]), (2, 1, [Matrix([ + [2/3], + [1/3], + [ 1]])])] + >>> M.left_eigenvects() + [(-1, 1, [Matrix([[-2, 1, 1]])]), (0, 1, [Matrix([[-1, -1, 1]])]), (2, + 1, [Matrix([[1, 1, 1]])])] + + """ + + eigs = M.transpose().eigenvects(**flags) + + return [(val, mult, [l.transpose() for l in basis]) for val, mult, basis in eigs] + + +def _singular_values(M): + """Compute the singular values of a Matrix + + Examples + ======== + + >>> from sympy import Matrix, Symbol + >>> x = Symbol('x', real=True) + >>> M = Matrix([[0, 1, 0], [0, x, 0], [-1, 0, 0]]) + >>> M.singular_values() + [sqrt(x**2 + 1), 1, 0] + + See Also + ======== + + condition_number + """ + + if M.rows >= M.cols: + valmultpairs = M.H.multiply(M).eigenvals() + else: + valmultpairs = M.multiply(M.H).eigenvals() + + # Expands result from eigenvals into a simple list + vals = [] + + for k, v in valmultpairs.items(): + vals += [sqrt(k)] * v # dangerous! same k in several spots! + + # Pad with zeros if singular values are computed in reverse way, + # to give consistent format. + if len(vals) < M.cols: + vals += [M.zero] * (M.cols - len(vals)) + + # sort them in descending order + vals.sort(reverse=True, key=default_sort_key) + + return vals diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/exceptions.py b/.venv/lib/python3.13/site-packages/sympy/matrices/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..bfc7cfa0bdffd59ff2bc5a9cd85cf9b04ed1a63d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/exceptions.py @@ -0,0 +1,26 @@ +""" +Exceptions raised by the matrix module. +""" + + +class MatrixError(Exception): + pass + + +class ShapeError(ValueError, MatrixError): + """Wrong matrix shape""" + pass + + +class NonSquareMatrixError(ShapeError): + pass + + +class NonInvertibleMatrixError(ValueError, MatrixError): + """The matrix in not invertible (division by multidimensional zero error).""" + pass + + +class NonPositiveDefiniteMatrixError(ValueError, MatrixError): + """The matrix is not a positive-definite matrix.""" + pass diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/__init__.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f4ab203ab74165d1003cdedd83945ea3fcf8f47 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/__init__.py @@ -0,0 +1,62 @@ +""" A module which handles Matrix Expressions """ + +from .slice import MatrixSlice +from .blockmatrix import BlockMatrix, BlockDiagMatrix, block_collapse, blockcut +from .companion import CompanionMatrix +from .funcmatrix import FunctionMatrix +from .inverse import Inverse +from .matadd import MatAdd +from .matexpr import MatrixExpr, MatrixSymbol, matrix_symbols +from .matmul import MatMul +from .matpow import MatPow +from .trace import Trace, trace +from .determinant import Determinant, det, Permanent, per +from .transpose import Transpose +from .adjoint import Adjoint +from .hadamard import hadamard_product, HadamardProduct, hadamard_power, HadamardPower +from .diagonal import DiagonalMatrix, DiagonalOf, DiagMatrix, diagonalize_vector +from .dotproduct import DotProduct +from .kronecker import kronecker_product, KroneckerProduct, combine_kronecker +from .permutation import PermutationMatrix, MatrixPermute +from .sets import MatrixSet +from .special import ZeroMatrix, Identity, OneMatrix + +__all__ = [ + 'MatrixSlice', + + 'BlockMatrix', 'BlockDiagMatrix', 'block_collapse', 'blockcut', + 'FunctionMatrix', + + 'CompanionMatrix', + + 'Inverse', + + 'MatAdd', + + 'Identity', 'MatrixExpr', 'MatrixSymbol', 'ZeroMatrix', 'OneMatrix', + 'matrix_symbols', 'MatrixSet', + + 'MatMul', + + 'MatPow', + + 'Trace', 'trace', + + 'Determinant', 'det', + + 'Transpose', + + 'Adjoint', + + 'hadamard_product', 'HadamardProduct', 'hadamard_power', 'HadamardPower', + + 'DiagonalMatrix', 'DiagonalOf', 'DiagMatrix', 'diagonalize_vector', + + 'DotProduct', + + 'kronecker_product', 'KroneckerProduct', 'combine_kronecker', + + 'PermutationMatrix', 'MatrixPermute', + + 'Permanent', 'per' +] diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/_shape.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/_shape.py new file mode 100644 index 0000000000000000000000000000000000000000..a95d481bf8e1edf4c62992044cd50563b335caac --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/_shape.py @@ -0,0 +1,102 @@ +from sympy.core.relational import Eq +from sympy.core.expr import Expr +from sympy.core.numbers import Integer +from sympy.logic.boolalg import Boolean, And +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.matrices.exceptions import ShapeError +from typing import Union + + +def is_matadd_valid(*args: MatrixExpr) -> Boolean: + """Return the symbolic condition how ``MatAdd``, ``HadamardProduct`` + makes sense. + + Parameters + ========== + + args + The list of arguments of matrices to be tested for. + + Examples + ======== + + >>> from sympy import MatrixSymbol, symbols + >>> from sympy.matrices.expressions._shape import is_matadd_valid + + >>> m, n, p, q = symbols('m n p q') + >>> A = MatrixSymbol('A', m, n) + >>> B = MatrixSymbol('B', p, q) + >>> is_matadd_valid(A, B) + Eq(m, p) & Eq(n, q) + """ + rows, cols = zip(*(arg.shape for arg in args)) + return And( + *(Eq(i, j) for i, j in zip(rows[:-1], rows[1:])), + *(Eq(i, j) for i, j in zip(cols[:-1], cols[1:])), + ) + + +def is_matmul_valid(*args: Union[MatrixExpr, Expr]) -> Boolean: + """Return the symbolic condition how ``MatMul`` makes sense + + Parameters + ========== + + args + The list of arguments of matrices and scalar expressions to be tested + for. + + Examples + ======== + + >>> from sympy import MatrixSymbol, symbols + >>> from sympy.matrices.expressions._shape import is_matmul_valid + + >>> m, n, p, q = symbols('m n p q') + >>> A = MatrixSymbol('A', m, n) + >>> B = MatrixSymbol('B', p, q) + >>> is_matmul_valid(A, B) + Eq(n, p) + """ + rows, cols = zip(*(arg.shape for arg in args if isinstance(arg, MatrixExpr))) + return And(*(Eq(i, j) for i, j in zip(cols[:-1], rows[1:]))) + + +def is_square(arg: MatrixExpr, /) -> Boolean: + """Return the symbolic condition how the matrix is assumed to be square + + Parameters + ========== + + arg + The matrix to be tested for. + + Examples + ======== + + >>> from sympy import MatrixSymbol, symbols + >>> from sympy.matrices.expressions._shape import is_square + + >>> m, n = symbols('m n') + >>> A = MatrixSymbol('A', m, n) + >>> is_square(A) + Eq(m, n) + """ + return Eq(arg.rows, arg.cols) + + +def validate_matadd_integer(*args: MatrixExpr) -> None: + """Validate matrix shape for addition only for integer values""" + rows, cols = zip(*(x.shape for x in args)) + if len(set(filter(lambda x: isinstance(x, (int, Integer)), rows))) > 1: + raise ShapeError(f"Matrices have mismatching shape: {rows}") + if len(set(filter(lambda x: isinstance(x, (int, Integer)), cols))) > 1: + raise ShapeError(f"Matrices have mismatching shape: {cols}") + + +def validate_matmul_integer(*args: MatrixExpr) -> None: + """Validate matrix shape for multiplication only for integer values""" + for A, B in zip(args[:-1], args[1:]): + i, j = A.cols, B.rows + if isinstance(i, (int, Integer)) and isinstance(j, (int, Integer)) and i != j: + raise ShapeError("Matrices are not aligned", i, j) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/adjoint.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/adjoint.py new file mode 100644 index 0000000000000000000000000000000000000000..2039a7b2eb8eeacb02435979121c4133a11d8e02 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/adjoint.py @@ -0,0 +1,60 @@ +from sympy.core import Basic +from sympy.functions import adjoint, conjugate +from sympy.matrices.expressions.matexpr import MatrixExpr + + +class Adjoint(MatrixExpr): + """ + The Hermitian adjoint of a matrix expression. + + This is a symbolic object that simply stores its argument without + evaluating it. To actually compute the adjoint, use the ``adjoint()`` + function. + + Examples + ======== + + >>> from sympy import MatrixSymbol, Adjoint, adjoint + >>> A = MatrixSymbol('A', 3, 5) + >>> B = MatrixSymbol('B', 5, 3) + >>> Adjoint(A*B) + Adjoint(A*B) + >>> adjoint(A*B) + Adjoint(B)*Adjoint(A) + >>> adjoint(A*B) == Adjoint(A*B) + False + >>> adjoint(A*B) == Adjoint(A*B).doit() + True + """ + is_Adjoint = True + + def doit(self, **hints): + arg = self.arg + if hints.get('deep', True) and isinstance(arg, Basic): + return adjoint(arg.doit(**hints)) + else: + return adjoint(self.arg) + + @property + def arg(self): + return self.args[0] + + @property + def shape(self): + return self.arg.shape[::-1] + + def _entry(self, i, j, **kwargs): + return conjugate(self.arg._entry(j, i, **kwargs)) + + def _eval_adjoint(self): + return self.arg + + def _eval_transpose(self): + return self.arg.conjugate() + + def _eval_conjugate(self): + return self.arg.transpose() + + def _eval_trace(self): + from sympy.matrices.expressions.trace import Trace + return conjugate(Trace(self.arg)) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/applyfunc.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/applyfunc.py new file mode 100644 index 0000000000000000000000000000000000000000..c0363658447a8dc37a152b30e45533bac582b10c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/applyfunc.py @@ -0,0 +1,204 @@ +from sympy.core.expr import ExprBuilder +from sympy.core.function import (Function, FunctionClass, Lambda) +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify, _sympify +from sympy.matrices.expressions import MatrixExpr +from sympy.matrices.matrixbase import MatrixBase + + +class ElementwiseApplyFunction(MatrixExpr): + r""" + Apply function to a matrix elementwise without evaluating. + + Examples + ======== + + It can be created by calling ``.applyfunc()`` on a matrix + expression: + + >>> from sympy import MatrixSymbol + >>> from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction + >>> from sympy import exp + >>> X = MatrixSymbol("X", 3, 3) + >>> X.applyfunc(exp) + Lambda(_d, exp(_d)).(X) + + Otherwise using the class constructor: + + >>> from sympy import eye + >>> expr = ElementwiseApplyFunction(exp, eye(3)) + >>> expr + Lambda(_d, exp(_d)).(Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]])) + >>> expr.doit() + Matrix([ + [E, 1, 1], + [1, E, 1], + [1, 1, E]]) + + Notice the difference with the real mathematical functions: + + >>> exp(eye(3)) + Matrix([ + [E, 0, 0], + [0, E, 0], + [0, 0, E]]) + """ + + def __new__(cls, function, expr): + expr = _sympify(expr) + if not expr.is_Matrix: + raise ValueError("{} must be a matrix instance.".format(expr)) + + if expr.shape == (1, 1): + # Check if the function returns a matrix, in that case, just apply + # the function instead of creating an ElementwiseApplyFunc object: + ret = function(expr) + if isinstance(ret, MatrixExpr): + return ret + + if not isinstance(function, (FunctionClass, Lambda)): + d = Dummy('d') + function = Lambda(d, function(d)) + + function = sympify(function) + if not isinstance(function, (FunctionClass, Lambda)): + raise ValueError( + "{} should be compatible with SymPy function classes." + .format(function)) + + if 1 not in function.nargs: + raise ValueError( + '{} should be able to accept 1 arguments.'.format(function)) + + if not isinstance(function, Lambda): + d = Dummy('d') + function = Lambda(d, function(d)) + + obj = MatrixExpr.__new__(cls, function, expr) + return obj + + @property + def function(self): + return self.args[0] + + @property + def expr(self): + return self.args[1] + + @property + def shape(self): + return self.expr.shape + + def doit(self, **hints): + deep = hints.get("deep", True) + expr = self.expr + if deep: + expr = expr.doit(**hints) + function = self.function + if isinstance(function, Lambda) and function.is_identity: + # This is a Lambda containing the identity function. + return expr + if isinstance(expr, MatrixBase): + return expr.applyfunc(self.function) + elif isinstance(expr, ElementwiseApplyFunction): + return ElementwiseApplyFunction( + lambda x: self.function(expr.function(x)), + expr.expr + ).doit(**hints) + else: + return self + + def _entry(self, i, j, **kwargs): + return self.function(self.expr._entry(i, j, **kwargs)) + + def _get_function_fdiff(self): + d = Dummy("d") + function = self.function(d) + fdiff = function.diff(d) + if isinstance(fdiff, Function): + fdiff = type(fdiff) + else: + fdiff = Lambda(d, fdiff) + return fdiff + + def _eval_derivative(self, x): + from sympy.matrices.expressions.hadamard import hadamard_product + dexpr = self.expr.diff(x) + fdiff = self._get_function_fdiff() + return hadamard_product( + dexpr, + ElementwiseApplyFunction(fdiff, self.expr) + ) + + def _eval_derivative_matrix_lines(self, x): + from sympy.matrices.expressions.special import Identity + from sympy.tensor.array.expressions.array_expressions import ArrayContraction + from sympy.tensor.array.expressions.array_expressions import ArrayDiagonal + from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct + + fdiff = self._get_function_fdiff() + lr = self.expr._eval_derivative_matrix_lines(x) + ewdiff = ElementwiseApplyFunction(fdiff, self.expr) + if 1 in x.shape: + # Vector: + iscolumn = self.shape[1] == 1 + for i in lr: + if iscolumn: + ptr1 = i.first_pointer + ptr2 = Identity(self.shape[1]) + else: + ptr1 = Identity(self.shape[0]) + ptr2 = i.second_pointer + + subexpr = ExprBuilder( + ArrayDiagonal, + [ + ExprBuilder( + ArrayTensorProduct, + [ + ewdiff, + ptr1, + ptr2, + ] + ), + (0, 2) if iscolumn else (1, 4) + ], + validator=ArrayDiagonal._validate + ) + i._lines = [subexpr] + i._first_pointer_parent = subexpr.args[0].args + i._first_pointer_index = 1 + i._second_pointer_parent = subexpr.args[0].args + i._second_pointer_index = 2 + else: + # Matrix case: + for i in lr: + ptr1 = i.first_pointer + ptr2 = i.second_pointer + newptr1 = Identity(ptr1.shape[1]) + newptr2 = Identity(ptr2.shape[1]) + subexpr = ExprBuilder( + ArrayContraction, + [ + ExprBuilder( + ArrayTensorProduct, + [ptr1, newptr1, ewdiff, ptr2, newptr2] + ), + (1, 2, 4), + (5, 7, 8), + ], + validator=ArrayContraction._validate + ) + i._first_pointer_parent = subexpr.args[0].args + i._first_pointer_index = 1 + i._second_pointer_parent = subexpr.args[0].args + i._second_pointer_index = 4 + i._lines = [subexpr] + return lr + + def _eval_transpose(self): + from sympy.matrices.expressions.transpose import Transpose + return self.func(self.function, Transpose(self.expr).doit()) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/blockmatrix.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/blockmatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..0125d6233ba7cf8c0b590fbb655d9c7c447e0bd4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/blockmatrix.py @@ -0,0 +1,975 @@ +from sympy.assumptions.ask import (Q, ask) +from sympy.core import Basic, Add, Mul, S +from sympy.core.sympify import _sympify +from sympy.functions.elementary.complexes import re, im +from sympy.strategies import typed, exhaust, condition, do_one, unpack +from sympy.strategies.traverse import bottom_up +from sympy.utilities.iterables import is_sequence, sift +from sympy.utilities.misc import filldedent + +from sympy.matrices import Matrix, ShapeError +from sympy.matrices.exceptions import NonInvertibleMatrixError +from sympy.matrices.expressions.determinant import det, Determinant +from sympy.matrices.expressions.inverse import Inverse +from sympy.matrices.expressions.matadd import MatAdd +from sympy.matrices.expressions.matexpr import MatrixExpr, MatrixElement +from sympy.matrices.expressions.matmul import MatMul +from sympy.matrices.expressions.matpow import MatPow +from sympy.matrices.expressions.slice import MatrixSlice +from sympy.matrices.expressions.special import ZeroMatrix, Identity +from sympy.matrices.expressions.trace import trace +from sympy.matrices.expressions.transpose import Transpose, transpose + + +class BlockMatrix(MatrixExpr): + """A BlockMatrix is a Matrix comprised of other matrices. + + The submatrices are stored in a SymPy Matrix object but accessed as part of + a Matrix Expression + + >>> from sympy import (MatrixSymbol, BlockMatrix, symbols, + ... Identity, ZeroMatrix, block_collapse) + >>> n,m,l = symbols('n m l') + >>> X = MatrixSymbol('X', n, n) + >>> Y = MatrixSymbol('Y', m, m) + >>> Z = MatrixSymbol('Z', n, m) + >>> B = BlockMatrix([[X, Z], [ZeroMatrix(m,n), Y]]) + >>> print(B) + Matrix([ + [X, Z], + [0, Y]]) + + >>> C = BlockMatrix([[Identity(n), Z]]) + >>> print(C) + Matrix([[I, Z]]) + + >>> print(block_collapse(C*B)) + Matrix([[X, Z + Z*Y]]) + + Some matrices might be comprised of rows of blocks with + the matrices in each row having the same height and the + rows all having the same total number of columns but + not having the same number of columns for each matrix + in each row. In this case, the matrix is not a block + matrix and should be instantiated by Matrix. + + >>> from sympy import ones, Matrix + >>> dat = [ + ... [ones(3,2), ones(3,3)*2], + ... [ones(2,3)*3, ones(2,2)*4]] + ... + >>> BlockMatrix(dat) + Traceback (most recent call last): + ... + ValueError: + Although this matrix is comprised of blocks, the blocks do not fill + the matrix in a size-symmetric fashion. To create a full matrix from + these arguments, pass them directly to Matrix. + >>> Matrix(dat) + Matrix([ + [1, 1, 2, 2, 2], + [1, 1, 2, 2, 2], + [1, 1, 2, 2, 2], + [3, 3, 3, 4, 4], + [3, 3, 3, 4, 4]]) + + See Also + ======== + sympy.matrices.matrixbase.MatrixBase.irregular + """ + def __new__(cls, *args, **kwargs): + from sympy.matrices.immutable import ImmutableDenseMatrix + isMat = lambda i: getattr(i, 'is_Matrix', False) + if len(args) != 1 or \ + not is_sequence(args[0]) or \ + len({isMat(r) for r in args[0]}) != 1: + raise ValueError(filldedent(''' + expecting a sequence of 1 or more rows + containing Matrices.''')) + rows = args[0] if args else [] + if not isMat(rows): + if rows and isMat(rows[0]): + rows = [rows] # rows is not list of lists or [] + # regularity check + # same number of matrices in each row + blocky = ok = len({len(r) for r in rows}) == 1 + if ok: + # same number of rows for each matrix in a row + for r in rows: + ok = len({i.rows for i in r}) == 1 + if not ok: + break + blocky = ok + if ok: + # same number of cols for each matrix in each col + for c in range(len(rows[0])): + ok = len({rows[i][c].cols + for i in range(len(rows))}) == 1 + if not ok: + break + if not ok: + # same total cols in each row + ok = len({ + sum(i.cols for i in r) for r in rows}) == 1 + if blocky and ok: + raise ValueError(filldedent(''' + Although this matrix is comprised of blocks, + the blocks do not fill the matrix in a + size-symmetric fashion. To create a full matrix + from these arguments, pass them directly to + Matrix.''')) + raise ValueError(filldedent(''' + When there are not the same number of rows in each + row's matrices or there are not the same number of + total columns in each row, the matrix is not a + block matrix. If this matrix is known to consist of + blocks fully filling a 2-D space then see + Matrix.irregular.''')) + mat = ImmutableDenseMatrix(rows, evaluate=False) + obj = Basic.__new__(cls, mat) + return obj + + @property + def shape(self): + numrows = numcols = 0 + M = self.blocks + for i in range(M.shape[0]): + numrows += M[i, 0].shape[0] + for i in range(M.shape[1]): + numcols += M[0, i].shape[1] + return (numrows, numcols) + + @property + def blockshape(self): + return self.blocks.shape + + @property + def blocks(self): + return self.args[0] + + @property + def rowblocksizes(self): + return [self.blocks[i, 0].rows for i in range(self.blockshape[0])] + + @property + def colblocksizes(self): + return [self.blocks[0, i].cols for i in range(self.blockshape[1])] + + def structurally_equal(self, other): + return (isinstance(other, BlockMatrix) + and self.shape == other.shape + and self.blockshape == other.blockshape + and self.rowblocksizes == other.rowblocksizes + and self.colblocksizes == other.colblocksizes) + + def _blockmul(self, other): + if (isinstance(other, BlockMatrix) and + self.colblocksizes == other.rowblocksizes): + return BlockMatrix(self.blocks*other.blocks) + + return self * other + + def _blockadd(self, other): + if (isinstance(other, BlockMatrix) + and self.structurally_equal(other)): + return BlockMatrix(self.blocks + other.blocks) + + return self + other + + def _eval_transpose(self): + # Flip all the individual matrices + matrices = [transpose(matrix) for matrix in self.blocks] + # Make a copy + M = Matrix(self.blockshape[0], self.blockshape[1], matrices) + # Transpose the block structure + M = M.transpose() + return BlockMatrix(M) + + def _eval_adjoint(self): + return BlockMatrix( + Matrix(self.blockshape[0], self.blockshape[1], self.blocks).adjoint() + ) + + def _eval_trace(self): + if self.rowblocksizes == self.colblocksizes: + blocks = [self.blocks[i, i] for i in range(self.blockshape[0])] + return Add(*[trace(block) for block in blocks]) + + def _eval_determinant(self): + if self.blockshape == (1, 1): + return det(self.blocks[0, 0]) + if self.blockshape == (2, 2): + [[A, B], + [C, D]] = self.blocks.tolist() + if ask(Q.invertible(A)): + return det(A)*det(D - C*A.I*B) + elif ask(Q.invertible(D)): + return det(D)*det(A - B*D.I*C) + return Determinant(self) + + def _eval_as_real_imag(self): + real_matrices = [re(matrix) for matrix in self.blocks] + real_matrices = Matrix(self.blockshape[0], self.blockshape[1], real_matrices) + + im_matrices = [im(matrix) for matrix in self.blocks] + im_matrices = Matrix(self.blockshape[0], self.blockshape[1], im_matrices) + + return (BlockMatrix(real_matrices), BlockMatrix(im_matrices)) + + def _eval_derivative(self, x): + return BlockMatrix(self.blocks.diff(x)) + + def transpose(self): + """Return transpose of matrix. + + Examples + ======== + + >>> from sympy import MatrixSymbol, BlockMatrix, ZeroMatrix + >>> from sympy.abc import m, n + >>> X = MatrixSymbol('X', n, n) + >>> Y = MatrixSymbol('Y', m, m) + >>> Z = MatrixSymbol('Z', n, m) + >>> B = BlockMatrix([[X, Z], [ZeroMatrix(m,n), Y]]) + >>> B.transpose() + Matrix([ + [X.T, 0], + [Z.T, Y.T]]) + >>> _.transpose() + Matrix([ + [X, Z], + [0, Y]]) + """ + return self._eval_transpose() + + def schur(self, mat = 'A', generalized = False): + """Return the Schur Complement of the 2x2 BlockMatrix + + Parameters + ========== + + mat : String, optional + The matrix with respect to which the + Schur Complement is calculated. 'A' is + used by default + + generalized : bool, optional + If True, returns the generalized Schur + Component which uses Moore-Penrose Inverse + + Examples + ======== + + >>> from sympy import symbols, MatrixSymbol, BlockMatrix + >>> m, n = symbols('m n') + >>> A = MatrixSymbol('A', n, n) + >>> B = MatrixSymbol('B', n, m) + >>> C = MatrixSymbol('C', m, n) + >>> D = MatrixSymbol('D', m, m) + >>> X = BlockMatrix([[A, B], [C, D]]) + + The default Schur Complement is evaluated with "A" + + >>> X.schur() + -C*A**(-1)*B + D + >>> X.schur('D') + A - B*D**(-1)*C + + Schur complement with non-invertible matrices is not + defined. Instead, the generalized Schur complement can + be calculated which uses the Moore-Penrose Inverse. To + achieve this, `generalized` must be set to `True` + + >>> X.schur('B', generalized=True) + C - D*(B.T*B)**(-1)*B.T*A + >>> X.schur('C', generalized=True) + -A*(C.T*C)**(-1)*C.T*D + B + + Returns + ======= + + M : Matrix + The Schur Complement Matrix + + Raises + ====== + + ShapeError + If the block matrix is not a 2x2 matrix + + NonInvertibleMatrixError + If given matrix is non-invertible + + References + ========== + + .. [1] Wikipedia Article on Schur Component : https://en.wikipedia.org/wiki/Schur_complement + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.pinv + """ + + if self.blockshape == (2, 2): + [[A, B], + [C, D]] = self.blocks.tolist() + d={'A' : A, 'B' : B, 'C' : C, 'D' : D} + try: + inv = (d[mat].T*d[mat]).inv()*d[mat].T if generalized else d[mat].inv() + if mat == 'A': + return D - C * inv * B + elif mat == 'B': + return C - D * inv * A + elif mat == 'C': + return B - A * inv * D + elif mat == 'D': + return A - B * inv * C + #For matrices where no sub-matrix is square + return self + except NonInvertibleMatrixError: + raise NonInvertibleMatrixError('The given matrix is not invertible. Please set generalized=True \ + to compute the generalized Schur Complement which uses Moore-Penrose Inverse') + else: + raise ShapeError('Schur Complement can only be calculated for 2x2 block matrices') + + def LDUdecomposition(self): + """Returns the Block LDU decomposition of + a 2x2 Block Matrix + + Returns + ======= + + (L, D, U) : Matrices + L : Lower Diagonal Matrix + D : Diagonal Matrix + U : Upper Diagonal Matrix + + Examples + ======== + + >>> from sympy import symbols, MatrixSymbol, BlockMatrix, block_collapse + >>> m, n = symbols('m n') + >>> A = MatrixSymbol('A', n, n) + >>> B = MatrixSymbol('B', n, m) + >>> C = MatrixSymbol('C', m, n) + >>> D = MatrixSymbol('D', m, m) + >>> X = BlockMatrix([[A, B], [C, D]]) + >>> L, D, U = X.LDUdecomposition() + >>> block_collapse(L*D*U) + Matrix([ + [A, B], + [C, D]]) + + Raises + ====== + + ShapeError + If the block matrix is not a 2x2 matrix + + NonInvertibleMatrixError + If the matrix "A" is non-invertible + + See Also + ======== + sympy.matrices.expressions.blockmatrix.BlockMatrix.UDLdecomposition + sympy.matrices.expressions.blockmatrix.BlockMatrix.LUdecomposition + """ + if self.blockshape == (2,2): + [[A, B], + [C, D]] = self.blocks.tolist() + try: + AI = A.I + except NonInvertibleMatrixError: + raise NonInvertibleMatrixError('Block LDU decomposition cannot be calculated when\ + "A" is singular') + Ip = Identity(B.shape[0]) + Iq = Identity(B.shape[1]) + Z = ZeroMatrix(*B.shape) + L = BlockMatrix([[Ip, Z], [C*AI, Iq]]) + D = BlockDiagMatrix(A, self.schur()) + U = BlockMatrix([[Ip, AI*B],[Z.T, Iq]]) + return L, D, U + else: + raise ShapeError("Block LDU decomposition is supported only for 2x2 block matrices") + + def UDLdecomposition(self): + """Returns the Block UDL decomposition of + a 2x2 Block Matrix + + Returns + ======= + + (U, D, L) : Matrices + U : Upper Diagonal Matrix + D : Diagonal Matrix + L : Lower Diagonal Matrix + + Examples + ======== + + >>> from sympy import symbols, MatrixSymbol, BlockMatrix, block_collapse + >>> m, n = symbols('m n') + >>> A = MatrixSymbol('A', n, n) + >>> B = MatrixSymbol('B', n, m) + >>> C = MatrixSymbol('C', m, n) + >>> D = MatrixSymbol('D', m, m) + >>> X = BlockMatrix([[A, B], [C, D]]) + >>> U, D, L = X.UDLdecomposition() + >>> block_collapse(U*D*L) + Matrix([ + [A, B], + [C, D]]) + + Raises + ====== + + ShapeError + If the block matrix is not a 2x2 matrix + + NonInvertibleMatrixError + If the matrix "D" is non-invertible + + See Also + ======== + sympy.matrices.expressions.blockmatrix.BlockMatrix.LDUdecomposition + sympy.matrices.expressions.blockmatrix.BlockMatrix.LUdecomposition + """ + if self.blockshape == (2,2): + [[A, B], + [C, D]] = self.blocks.tolist() + try: + DI = D.I + except NonInvertibleMatrixError: + raise NonInvertibleMatrixError('Block UDL decomposition cannot be calculated when\ + "D" is singular') + Ip = Identity(A.shape[0]) + Iq = Identity(B.shape[1]) + Z = ZeroMatrix(*B.shape) + U = BlockMatrix([[Ip, B*DI], [Z.T, Iq]]) + D = BlockDiagMatrix(self.schur('D'), D) + L = BlockMatrix([[Ip, Z],[DI*C, Iq]]) + return U, D, L + else: + raise ShapeError("Block UDL decomposition is supported only for 2x2 block matrices") + + def LUdecomposition(self): + """Returns the Block LU decomposition of + a 2x2 Block Matrix + + Returns + ======= + + (L, U) : Matrices + L : Lower Diagonal Matrix + U : Upper Diagonal Matrix + + Examples + ======== + + >>> from sympy import symbols, MatrixSymbol, BlockMatrix, block_collapse + >>> m, n = symbols('m n') + >>> A = MatrixSymbol('A', n, n) + >>> B = MatrixSymbol('B', n, m) + >>> C = MatrixSymbol('C', m, n) + >>> D = MatrixSymbol('D', m, m) + >>> X = BlockMatrix([[A, B], [C, D]]) + >>> L, U = X.LUdecomposition() + >>> block_collapse(L*U) + Matrix([ + [A, B], + [C, D]]) + + Raises + ====== + + ShapeError + If the block matrix is not a 2x2 matrix + + NonInvertibleMatrixError + If the matrix "A" is non-invertible + + See Also + ======== + sympy.matrices.expressions.blockmatrix.BlockMatrix.UDLdecomposition + sympy.matrices.expressions.blockmatrix.BlockMatrix.LDUdecomposition + """ + if self.blockshape == (2,2): + [[A, B], + [C, D]] = self.blocks.tolist() + try: + A = A**S.Half + AI = A.I + except NonInvertibleMatrixError: + raise NonInvertibleMatrixError('Block LU decomposition cannot be calculated when\ + "A" is singular') + Z = ZeroMatrix(*B.shape) + Q = self.schur()**S.Half + L = BlockMatrix([[A, Z], [C*AI, Q]]) + U = BlockMatrix([[A, AI*B],[Z.T, Q]]) + return L, U + else: + raise ShapeError("Block LU decomposition is supported only for 2x2 block matrices") + + def _entry(self, i, j, **kwargs): + # Find row entry + orig_i, orig_j = i, j + for row_block, numrows in enumerate(self.rowblocksizes): + cmp = i < numrows + if cmp == True: + break + elif cmp == False: + i -= numrows + elif row_block < self.blockshape[0] - 1: + # Can't tell which block and it's not the last one, return unevaluated + return MatrixElement(self, orig_i, orig_j) + for col_block, numcols in enumerate(self.colblocksizes): + cmp = j < numcols + if cmp == True: + break + elif cmp == False: + j -= numcols + elif col_block < self.blockshape[1] - 1: + return MatrixElement(self, orig_i, orig_j) + return self.blocks[row_block, col_block][i, j] + + @property + def is_Identity(self): + if self.blockshape[0] != self.blockshape[1]: + return False + for i in range(self.blockshape[0]): + for j in range(self.blockshape[1]): + if i==j and not self.blocks[i, j].is_Identity: + return False + if i!=j and not self.blocks[i, j].is_ZeroMatrix: + return False + return True + + @property + def is_structurally_symmetric(self): + return self.rowblocksizes == self.colblocksizes + + def equals(self, other): + if self == other: + return True + if (isinstance(other, BlockMatrix) and self.blocks == other.blocks): + return True + return super().equals(other) + + +class BlockDiagMatrix(BlockMatrix): + """A sparse matrix with block matrices along its diagonals + + Examples + ======== + + >>> from sympy import MatrixSymbol, BlockDiagMatrix, symbols + >>> n, m, l = symbols('n m l') + >>> X = MatrixSymbol('X', n, n) + >>> Y = MatrixSymbol('Y', m, m) + >>> BlockDiagMatrix(X, Y) + Matrix([ + [X, 0], + [0, Y]]) + + Notes + ===== + + If you want to get the individual diagonal blocks, use + :meth:`get_diag_blocks`. + + See Also + ======== + + sympy.matrices.dense.diag + """ + def __new__(cls, *mats): + return Basic.__new__(BlockDiagMatrix, *[_sympify(m) for m in mats]) + + @property + def diag(self): + return self.args + + @property + def blocks(self): + from sympy.matrices.immutable import ImmutableDenseMatrix + mats = self.args + data = [[mats[i] if i == j else ZeroMatrix(mats[i].rows, mats[j].cols) + for j in range(len(mats))] + for i in range(len(mats))] + return ImmutableDenseMatrix(data, evaluate=False) + + @property + def shape(self): + return (sum(block.rows for block in self.args), + sum(block.cols for block in self.args)) + + @property + def blockshape(self): + n = len(self.args) + return (n, n) + + @property + def rowblocksizes(self): + return [block.rows for block in self.args] + + @property + def colblocksizes(self): + return [block.cols for block in self.args] + + def _all_square_blocks(self): + """Returns true if all blocks are square""" + return all(mat.is_square for mat in self.args) + + def _eval_determinant(self): + if self._all_square_blocks(): + return Mul(*[det(mat) for mat in self.args]) + # At least one block is non-square. Since the entire matrix must be square we know there must + # be at least two blocks in this matrix, in which case the entire matrix is necessarily rank-deficient + return S.Zero + + def _eval_inverse(self, expand='ignored'): + if self._all_square_blocks(): + return BlockDiagMatrix(*[mat.inverse() for mat in self.args]) + # See comment in _eval_determinant() + raise NonInvertibleMatrixError('Matrix det == 0; not invertible.') + + def _eval_transpose(self): + return BlockDiagMatrix(*[mat.transpose() for mat in self.args]) + + def _blockmul(self, other): + if (isinstance(other, BlockDiagMatrix) and + self.colblocksizes == other.rowblocksizes): + return BlockDiagMatrix(*[a*b for a, b in zip(self.args, other.args)]) + else: + return BlockMatrix._blockmul(self, other) + + def _blockadd(self, other): + if (isinstance(other, BlockDiagMatrix) and + self.blockshape == other.blockshape and + self.rowblocksizes == other.rowblocksizes and + self.colblocksizes == other.colblocksizes): + return BlockDiagMatrix(*[a + b for a, b in zip(self.args, other.args)]) + else: + return BlockMatrix._blockadd(self, other) + + def get_diag_blocks(self): + """Return the list of diagonal blocks of the matrix. + + Examples + ======== + + >>> from sympy import BlockDiagMatrix, Matrix + + >>> A = Matrix([[1, 2], [3, 4]]) + >>> B = Matrix([[5, 6], [7, 8]]) + >>> M = BlockDiagMatrix(A, B) + + How to get diagonal blocks from the block diagonal matrix: + + >>> diag_blocks = M.get_diag_blocks() + >>> diag_blocks[0] + Matrix([ + [1, 2], + [3, 4]]) + >>> diag_blocks[1] + Matrix([ + [5, 6], + [7, 8]]) + """ + return self.args + + +def block_collapse(expr): + """Evaluates a block matrix expression + + >>> from sympy import MatrixSymbol, BlockMatrix, symbols, Identity, ZeroMatrix, block_collapse + >>> n,m,l = symbols('n m l') + >>> X = MatrixSymbol('X', n, n) + >>> Y = MatrixSymbol('Y', m, m) + >>> Z = MatrixSymbol('Z', n, m) + >>> B = BlockMatrix([[X, Z], [ZeroMatrix(m, n), Y]]) + >>> print(B) + Matrix([ + [X, Z], + [0, Y]]) + + >>> C = BlockMatrix([[Identity(n), Z]]) + >>> print(C) + Matrix([[I, Z]]) + + >>> print(block_collapse(C*B)) + Matrix([[X, Z + Z*Y]]) + """ + from sympy.strategies.util import expr_fns + + hasbm = lambda expr: isinstance(expr, MatrixExpr) and expr.has(BlockMatrix) + + conditioned_rl = condition( + hasbm, + typed( + {MatAdd: do_one(bc_matadd, bc_block_plus_ident), + MatMul: do_one(bc_matmul, bc_dist), + MatPow: bc_matmul, + Transpose: bc_transpose, + Inverse: bc_inverse, + BlockMatrix: do_one(bc_unpack, deblock)} + ) + ) + + rule = exhaust( + bottom_up( + exhaust(conditioned_rl), + fns=expr_fns + ) + ) + + result = rule(expr) + doit = getattr(result, 'doit', None) + if doit is not None: + return doit() + else: + return result + +def bc_unpack(expr): + if expr.blockshape == (1, 1): + return expr.blocks[0, 0] + return expr + +def bc_matadd(expr): + args = sift(expr.args, lambda M: isinstance(M, BlockMatrix)) + blocks = args[True] + if not blocks: + return expr + + nonblocks = args[False] + block = blocks[0] + for b in blocks[1:]: + block = block._blockadd(b) + if nonblocks: + return MatAdd(*nonblocks) + block + else: + return block + +def bc_block_plus_ident(expr): + idents = [arg for arg in expr.args if arg.is_Identity] + if not idents: + return expr + + blocks = [arg for arg in expr.args if isinstance(arg, BlockMatrix)] + if (blocks and all(b.structurally_equal(blocks[0]) for b in blocks) + and blocks[0].is_structurally_symmetric): + block_id = BlockDiagMatrix(*[Identity(k) + for k in blocks[0].rowblocksizes]) + rest = [arg for arg in expr.args if not arg.is_Identity and not isinstance(arg, BlockMatrix)] + return MatAdd(block_id * len(idents), *blocks, *rest).doit() + + return expr + +def bc_dist(expr): + """ Turn a*[X, Y] into [a*X, a*Y] """ + factor, mat = expr.as_coeff_mmul() + if factor == 1: + return expr + + unpacked = unpack(mat) + + if isinstance(unpacked, BlockDiagMatrix): + B = unpacked.diag + new_B = [factor * mat for mat in B] + return BlockDiagMatrix(*new_B) + elif isinstance(unpacked, BlockMatrix): + B = unpacked.blocks + new_B = [ + [factor * B[i, j] for j in range(B.cols)] for i in range(B.rows)] + return BlockMatrix(new_B) + return expr + + +def bc_matmul(expr): + if isinstance(expr, MatPow): + if expr.args[1].is_Integer and expr.args[1] > 0: + factor, matrices = 1, [expr.args[0]]*expr.args[1] + else: + return expr + else: + factor, matrices = expr.as_coeff_matrices() + + i = 0 + while (i+1 < len(matrices)): + A, B = matrices[i:i+2] + if isinstance(A, BlockMatrix) and isinstance(B, BlockMatrix): + matrices[i] = A._blockmul(B) + matrices.pop(i+1) + elif isinstance(A, BlockMatrix): + matrices[i] = A._blockmul(BlockMatrix([[B]])) + matrices.pop(i+1) + elif isinstance(B, BlockMatrix): + matrices[i] = BlockMatrix([[A]])._blockmul(B) + matrices.pop(i+1) + else: + i+=1 + return MatMul(factor, *matrices).doit() + +def bc_transpose(expr): + collapse = block_collapse(expr.arg) + return collapse._eval_transpose() + + +def bc_inverse(expr): + if isinstance(expr.arg, BlockDiagMatrix): + return expr.inverse() + + expr2 = blockinverse_1x1(expr) + if expr != expr2: + return expr2 + return blockinverse_2x2(Inverse(reblock_2x2(expr.arg))) + +def blockinverse_1x1(expr): + if isinstance(expr.arg, BlockMatrix) and expr.arg.blockshape == (1, 1): + mat = Matrix([[expr.arg.blocks[0].inverse()]]) + return BlockMatrix(mat) + return expr + + +def blockinverse_2x2(expr): + if isinstance(expr.arg, BlockMatrix) and expr.arg.blockshape == (2, 2): + # See: Inverses of 2x2 Block Matrices, Tzon-Tzer Lu and Sheng-Hua Shiou + [[A, B], + [C, D]] = expr.arg.blocks.tolist() + + formula = _choose_2x2_inversion_formula(A, B, C, D) + if formula != None: + MI = expr.arg.schur(formula).I + if formula == 'A': + AI = A.I + return BlockMatrix([[AI + AI * B * MI * C * AI, -AI * B * MI], [-MI * C * AI, MI]]) + if formula == 'B': + BI = B.I + return BlockMatrix([[-MI * D * BI, MI], [BI + BI * A * MI * D * BI, -BI * A * MI]]) + if formula == 'C': + CI = C.I + return BlockMatrix([[-CI * D * MI, CI + CI * D * MI * A * CI], [MI, -MI * A * CI]]) + if formula == 'D': + DI = D.I + return BlockMatrix([[MI, -MI * B * DI], [-DI * C * MI, DI + DI * C * MI * B * DI]]) + + return expr + + +def _choose_2x2_inversion_formula(A, B, C, D): + """ + Assuming [[A, B], [C, D]] would form a valid square block matrix, find + which of the classical 2x2 block matrix inversion formulas would be + best suited. + + Returns 'A', 'B', 'C', 'D' to represent the algorithm involving inversion + of the given argument or None if the matrix cannot be inverted using + any of those formulas. + """ + # Try to find a known invertible matrix. Note that the Schur complement + # is currently not being considered for this + A_inv = ask(Q.invertible(A)) + if A_inv == True: + return 'A' + B_inv = ask(Q.invertible(B)) + if B_inv == True: + return 'B' + C_inv = ask(Q.invertible(C)) + if C_inv == True: + return 'C' + D_inv = ask(Q.invertible(D)) + if D_inv == True: + return 'D' + # Otherwise try to find a matrix that isn't known to be non-invertible + if A_inv != False: + return 'A' + if B_inv != False: + return 'B' + if C_inv != False: + return 'C' + if D_inv != False: + return 'D' + return None + + +def deblock(B): + """ Flatten a BlockMatrix of BlockMatrices """ + if not isinstance(B, BlockMatrix) or not B.blocks.has(BlockMatrix): + return B + wrap = lambda x: x if isinstance(x, BlockMatrix) else BlockMatrix([[x]]) + bb = B.blocks.applyfunc(wrap) # everything is a block + + try: + MM = Matrix(0, sum(bb[0, i].blocks.shape[1] for i in range(bb.shape[1])), []) + for row in range(0, bb.shape[0]): + M = Matrix(bb[row, 0].blocks) + for col in range(1, bb.shape[1]): + M = M.row_join(bb[row, col].blocks) + MM = MM.col_join(M) + + return BlockMatrix(MM) + except ShapeError: + return B + + +def reblock_2x2(expr): + """ + Reblock a BlockMatrix so that it has 2x2 blocks of block matrices. If + possible in such a way that the matrix continues to be invertible using the + classical 2x2 block inversion formulas. + """ + if not isinstance(expr, BlockMatrix) or not all(d > 2 for d in expr.blockshape): + return expr + + BM = BlockMatrix # for brevity's sake + rowblocks, colblocks = expr.blockshape + blocks = expr.blocks + for i in range(1, rowblocks): + for j in range(1, colblocks): + # try to split rows at i and cols at j + A = bc_unpack(BM(blocks[:i, :j])) + B = bc_unpack(BM(blocks[:i, j:])) + C = bc_unpack(BM(blocks[i:, :j])) + D = bc_unpack(BM(blocks[i:, j:])) + + formula = _choose_2x2_inversion_formula(A, B, C, D) + if formula is not None: + return BlockMatrix([[A, B], [C, D]]) + + # else: nothing worked, just split upper left corner + return BM([[blocks[0, 0], BM(blocks[0, 1:])], + [BM(blocks[1:, 0]), BM(blocks[1:, 1:])]]) + + +def bounds(sizes): + """ Convert sequence of numbers into pairs of low-high pairs + + >>> from sympy.matrices.expressions.blockmatrix import bounds + >>> bounds((1, 10, 50)) + [(0, 1), (1, 11), (11, 61)] + """ + low = 0 + rv = [] + for size in sizes: + rv.append((low, low + size)) + low += size + return rv + +def blockcut(expr, rowsizes, colsizes): + """ Cut a matrix expression into Blocks + + >>> from sympy import ImmutableMatrix, blockcut + >>> M = ImmutableMatrix(4, 4, range(16)) + >>> B = blockcut(M, (1, 3), (1, 3)) + >>> type(B).__name__ + 'BlockMatrix' + >>> ImmutableMatrix(B.blocks[0, 1]) + Matrix([[1, 2, 3]]) + """ + + rowbounds = bounds(rowsizes) + colbounds = bounds(colsizes) + return BlockMatrix([[MatrixSlice(expr, rowbound, colbound) + for colbound in colbounds] + for rowbound in rowbounds]) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/companion.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/companion.py new file mode 100644 index 0000000000000000000000000000000000000000..6969c917f63806cb1f5417804e01ecc1350d1406 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/companion.py @@ -0,0 +1,56 @@ +from sympy.core.singleton import S +from sympy.core.sympify import _sympify +from sympy.polys.polytools import Poly + +from .matexpr import MatrixExpr + + +class CompanionMatrix(MatrixExpr): + """A symbolic companion matrix of a polynomial. + + Examples + ======== + + >>> from sympy import Poly, Symbol, symbols + >>> from sympy.matrices.expressions import CompanionMatrix + >>> x = Symbol('x') + >>> c0, c1, c2, c3, c4 = symbols('c0:5') + >>> p = Poly(c0 + c1*x + c2*x**2 + c3*x**3 + c4*x**4 + x**5, x) + >>> CompanionMatrix(p) + CompanionMatrix(Poly(x**5 + c4*x**4 + c3*x**3 + c2*x**2 + c1*x + c0, + x, domain='ZZ[c0,c1,c2,c3,c4]')) + """ + def __new__(cls, poly): + poly = _sympify(poly) + if not isinstance(poly, Poly): + raise ValueError("{} must be a Poly instance.".format(poly)) + if not poly.is_monic: + raise ValueError("{} must be a monic polynomial.".format(poly)) + if not poly.is_univariate: + raise ValueError( + "{} must be a univariate polynomial.".format(poly)) + if not poly.degree() >= 1: + raise ValueError( + "{} must have degree not less than 1.".format(poly)) + + return super().__new__(cls, poly) + + + @property + def shape(self): + poly = self.args[0] + size = poly.degree() + return size, size + + + def _entry(self, i, j): + if j == self.cols - 1: + return -self.args[0].all_coeffs()[-1 - i] + elif i == j + 1: + return S.One + return S.Zero + + + def as_explicit(self): + from sympy.matrices.immutable import ImmutableDenseMatrix + return ImmutableDenseMatrix.companion(self.args[0]) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/determinant.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/determinant.py new file mode 100644 index 0000000000000000000000000000000000000000..b323b3f93a5a0404bf2205f39d25b931d173b6d9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/determinant.py @@ -0,0 +1,148 @@ +from sympy.core.basic import Basic +from sympy.core.expr import Expr +from sympy.core.singleton import S +from sympy.core.sympify import sympify +from sympy.matrices.exceptions import NonSquareMatrixError +from sympy.matrices.matrixbase import MatrixBase + + +class Determinant(Expr): + """Matrix Determinant + + Represents the determinant of a matrix expression. + + Examples + ======== + + >>> from sympy import MatrixSymbol, Determinant, eye + >>> A = MatrixSymbol('A', 3, 3) + >>> Determinant(A) + Determinant(A) + >>> Determinant(eye(3)).doit() + 1 + """ + is_commutative = True + + def __new__(cls, mat): + mat = sympify(mat) + if not mat.is_Matrix: + raise TypeError("Input to Determinant, %s, not a matrix" % str(mat)) + + if mat.is_square is False: + raise NonSquareMatrixError("Det of a non-square matrix") + + return Basic.__new__(cls, mat) + + @property + def arg(self): + return self.args[0] + + @property + def kind(self): + return self.arg.kind.element_kind + + def doit(self, **hints): + arg = self.arg + if hints.get('deep', True): + arg = arg.doit(**hints) + + result = arg._eval_determinant() + if result is not None: + return result + + return self + + +def det(matexpr): + """ Matrix Determinant + + Examples + ======== + + >>> from sympy import MatrixSymbol, det, eye + >>> A = MatrixSymbol('A', 3, 3) + >>> det(A) + Determinant(A) + >>> det(eye(3)) + 1 + """ + + return Determinant(matexpr).doit() + +class Permanent(Expr): + """Matrix Permanent + + Represents the permanent of a matrix expression. + + Examples + ======== + + >>> from sympy import MatrixSymbol, Permanent, ones + >>> A = MatrixSymbol('A', 3, 3) + >>> Permanent(A) + Permanent(A) + >>> Permanent(ones(3, 3)).doit() + 6 + """ + + def __new__(cls, mat): + mat = sympify(mat) + if not mat.is_Matrix: + raise TypeError("Input to Permanent, %s, not a matrix" % str(mat)) + + return Basic.__new__(cls, mat) + + @property + def arg(self): + return self.args[0] + + def doit(self, expand=False, **hints): + if isinstance(self.arg, MatrixBase): + return self.arg.per() + else: + return self + +def per(matexpr): + """ Matrix Permanent + + Examples + ======== + + >>> from sympy import MatrixSymbol, Matrix, per, ones + >>> A = MatrixSymbol('A', 3, 3) + >>> per(A) + Permanent(A) + >>> per(ones(5, 5)) + 120 + >>> M = Matrix([1, 2, 5]) + >>> per(M) + 8 + """ + + return Permanent(matexpr).doit() + +from sympy.assumptions.ask import ask, Q +from sympy.assumptions.refine import handlers_dict + + +def refine_Determinant(expr, assumptions): + """ + >>> from sympy import MatrixSymbol, Q, assuming, refine, det + >>> X = MatrixSymbol('X', 2, 2) + >>> det(X) + Determinant(X) + >>> with assuming(Q.orthogonal(X)): + ... print(refine(det(X))) + 1 + """ + if ask(Q.orthogonal(expr.arg), assumptions): + return S.One + elif ask(Q.singular(expr.arg), assumptions): + return S.Zero + elif ask(Q.unit_triangular(expr.arg), assumptions): + return S.One + + return expr + + +handlers_dict['Determinant'] = refine_Determinant diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/diagonal.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/diagonal.py new file mode 100644 index 0000000000000000000000000000000000000000..ba8a0216588143e3e251dab84c25f038fad550a4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/diagonal.py @@ -0,0 +1,220 @@ +from sympy.core.sympify import _sympify + +from sympy.matrices.expressions import MatrixExpr +from sympy.core import S, Eq, Ge +from sympy.core.mul import Mul +from sympy.functions.special.tensor_functions import KroneckerDelta + + +class DiagonalMatrix(MatrixExpr): + """DiagonalMatrix(M) will create a matrix expression that + behaves as though all off-diagonal elements, + `M[i, j]` where `i != j`, are zero. + + Examples + ======== + + >>> from sympy import MatrixSymbol, DiagonalMatrix, Symbol + >>> n = Symbol('n', integer=True) + >>> m = Symbol('m', integer=True) + >>> D = DiagonalMatrix(MatrixSymbol('x', 2, 3)) + >>> D[1, 2] + 0 + >>> D[1, 1] + x[1, 1] + + The length of the diagonal -- the lesser of the two dimensions of `M` -- + is accessed through the `diagonal_length` property: + + >>> D.diagonal_length + 2 + >>> DiagonalMatrix(MatrixSymbol('x', n + 1, n)).diagonal_length + n + + When one of the dimensions is symbolic the other will be treated as + though it is smaller: + + >>> tall = DiagonalMatrix(MatrixSymbol('x', n, 3)) + >>> tall.diagonal_length + 3 + >>> tall[10, 1] + 0 + + When the size of the diagonal is not known, a value of None will + be returned: + + >>> DiagonalMatrix(MatrixSymbol('x', n, m)).diagonal_length is None + True + + """ + arg = property(lambda self: self.args[0]) + + shape = property(lambda self: self.arg.shape) # type:ignore + + @property + def diagonal_length(self): + r, c = self.shape + if r.is_Integer and c.is_Integer: + m = min(r, c) + elif r.is_Integer and not c.is_Integer: + m = r + elif c.is_Integer and not r.is_Integer: + m = c + elif r == c: + m = r + else: + try: + m = min(r, c) + except TypeError: + m = None + return m + + def _entry(self, i, j, **kwargs): + if self.diagonal_length is not None: + if Ge(i, self.diagonal_length) is S.true: + return S.Zero + elif Ge(j, self.diagonal_length) is S.true: + return S.Zero + eq = Eq(i, j) + if eq is S.true: + return self.arg[i, i] + elif eq is S.false: + return S.Zero + return self.arg[i, j]*KroneckerDelta(i, j) + + +class DiagonalOf(MatrixExpr): + """DiagonalOf(M) will create a matrix expression that + is equivalent to the diagonal of `M`, represented as + a single column matrix. + + Examples + ======== + + >>> from sympy import MatrixSymbol, DiagonalOf, Symbol + >>> n = Symbol('n', integer=True) + >>> m = Symbol('m', integer=True) + >>> x = MatrixSymbol('x', 2, 3) + >>> diag = DiagonalOf(x) + >>> diag.shape + (2, 1) + + The diagonal can be addressed like a matrix or vector and will + return the corresponding element of the original matrix: + + >>> diag[1, 0] == diag[1] == x[1, 1] + True + + The length of the diagonal -- the lesser of the two dimensions of `M` -- + is accessed through the `diagonal_length` property: + + >>> diag.diagonal_length + 2 + >>> DiagonalOf(MatrixSymbol('x', n + 1, n)).diagonal_length + n + + When only one of the dimensions is symbolic the other will be + treated as though it is smaller: + + >>> dtall = DiagonalOf(MatrixSymbol('x', n, 3)) + >>> dtall.diagonal_length + 3 + + When the size of the diagonal is not known, a value of None will + be returned: + + >>> DiagonalOf(MatrixSymbol('x', n, m)).diagonal_length is None + True + + """ + arg = property(lambda self: self.args[0]) + @property + def shape(self): + r, c = self.arg.shape + if r.is_Integer and c.is_Integer: + m = min(r, c) + elif r.is_Integer and not c.is_Integer: + m = r + elif c.is_Integer and not r.is_Integer: + m = c + elif r == c: + m = r + else: + try: + m = min(r, c) + except TypeError: + m = None + return m, S.One + + @property + def diagonal_length(self): + return self.shape[0] + + def _entry(self, i, j, **kwargs): + return self.arg._entry(i, i, **kwargs) + + +class DiagMatrix(MatrixExpr): + """ + Turn a vector into a diagonal matrix. + """ + def __new__(cls, vector): + vector = _sympify(vector) + obj = MatrixExpr.__new__(cls, vector) + shape = vector.shape + dim = shape[1] if shape[0] == 1 else shape[0] + if vector.shape[0] != 1: + obj._iscolumn = True + else: + obj._iscolumn = False + obj._shape = (dim, dim) + obj._vector = vector + return obj + + @property + def shape(self): + return self._shape + + def _entry(self, i, j, **kwargs): + if self._iscolumn: + result = self._vector._entry(i, 0, **kwargs) + else: + result = self._vector._entry(0, j, **kwargs) + if i != j: + result *= KroneckerDelta(i, j) + return result + + def _eval_transpose(self): + return self + + def as_explicit(self): + from sympy.matrices.dense import diag + return diag(*list(self._vector.as_explicit())) + + def doit(self, **hints): + from sympy.assumptions import ask, Q + from sympy.matrices.expressions.matmul import MatMul + from sympy.matrices.expressions.transpose import Transpose + from sympy.matrices.dense import eye + from sympy.matrices.matrixbase import MatrixBase + vector = self._vector + # This accounts for shape (1, 1) and identity matrices, among others: + if ask(Q.diagonal(vector)): + return vector + if isinstance(vector, MatrixBase): + ret = eye(max(vector.shape)) + for i in range(ret.shape[0]): + ret[i, i] = vector[i] + return type(vector)(ret) + if vector.is_MatMul: + matrices = [arg for arg in vector.args if arg.is_Matrix] + scalars = [arg for arg in vector.args if arg not in matrices] + if scalars: + return Mul.fromiter(scalars)*DiagMatrix(MatMul.fromiter(matrices).doit()).doit() + if isinstance(vector, Transpose): + vector = vector.arg + return DiagMatrix(vector) + + +def diagonalize_vector(vector): + return DiagMatrix(vector).doit() diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/dotproduct.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/dotproduct.py new file mode 100644 index 0000000000000000000000000000000000000000..3a413f8c79a221505f0c082d7f19f78597a2befc --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/dotproduct.py @@ -0,0 +1,55 @@ +from sympy.core import Basic, Expr +from sympy.core.sympify import _sympify +from sympy.matrices.expressions.transpose import transpose + + +class DotProduct(Expr): + """ + Dot product of vector matrices + + The input should be two 1 x n or n x 1 matrices. The output represents the + scalar dotproduct. + + This is similar to using MatrixElement and MatMul, except DotProduct does + not require that one vector to be a row vector and the other vector to be + a column vector. + + >>> from sympy import MatrixSymbol, DotProduct + >>> A = MatrixSymbol('A', 1, 3) + >>> B = MatrixSymbol('B', 1, 3) + >>> DotProduct(A, B) + DotProduct(A, B) + >>> DotProduct(A, B).doit() + A[0, 0]*B[0, 0] + A[0, 1]*B[0, 1] + A[0, 2]*B[0, 2] + """ + + def __new__(cls, arg1, arg2): + arg1, arg2 = _sympify((arg1, arg2)) + + if not arg1.is_Matrix: + raise TypeError("Argument 1 of DotProduct is not a matrix") + if not arg2.is_Matrix: + raise TypeError("Argument 2 of DotProduct is not a matrix") + if not (1 in arg1.shape): + raise TypeError("Argument 1 of DotProduct is not a vector") + if not (1 in arg2.shape): + raise TypeError("Argument 2 of DotProduct is not a vector") + + if set(arg1.shape) != set(arg2.shape): + raise TypeError("DotProduct arguments are not the same length") + + return Basic.__new__(cls, arg1, arg2) + + def doit(self, expand=False, **hints): + if self.args[0].shape == self.args[1].shape: + if self.args[0].shape[0] == 1: + mul = self.args[0]*transpose(self.args[1]) + else: + mul = transpose(self.args[0])*self.args[1] + else: + if self.args[0].shape[0] == 1: + mul = self.args[0]*self.args[1] + else: + mul = transpose(self.args[0])*transpose(self.args[1]) + + return mul[0] diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/factorizations.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/factorizations.py new file mode 100644 index 0000000000000000000000000000000000000000..aff2bb81ecff99d8e733f282ac2dd187d76ce895 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/factorizations.py @@ -0,0 +1,62 @@ +from sympy.matrices.expressions import MatrixExpr +from sympy.assumptions.ask import Q + +class Factorization(MatrixExpr): + arg = property(lambda self: self.args[0]) + shape = property(lambda self: self.arg.shape) # type: ignore + +class LofLU(Factorization): + @property + def predicates(self): + return (Q.lower_triangular,) +class UofLU(Factorization): + @property + def predicates(self): + return (Q.upper_triangular,) + +class LofCholesky(LofLU): pass +class UofCholesky(UofLU): pass + +class QofQR(Factorization): + @property + def predicates(self): + return (Q.orthogonal,) +class RofQR(Factorization): + @property + def predicates(self): + return (Q.upper_triangular,) + +class EigenVectors(Factorization): + @property + def predicates(self): + return (Q.orthogonal,) +class EigenValues(Factorization): + @property + def predicates(self): + return (Q.diagonal,) + +class UofSVD(Factorization): + @property + def predicates(self): + return (Q.orthogonal,) +class SofSVD(Factorization): + @property + def predicates(self): + return (Q.diagonal,) +class VofSVD(Factorization): + @property + def predicates(self): + return (Q.orthogonal,) + + +def lu(expr): + return LofLU(expr), UofLU(expr) + +def qr(expr): + return QofQR(expr), RofQR(expr) + +def eig(expr): + return EigenValues(expr), EigenVectors(expr) + +def svd(expr): + return UofSVD(expr), SofSVD(expr), VofSVD(expr) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/fourier.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/fourier.py new file mode 100644 index 0000000000000000000000000000000000000000..5fa9222c2a9b218f42636267235d5dd44c25f8bb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/fourier.py @@ -0,0 +1,91 @@ +from sympy.core.sympify import _sympify +from sympy.matrices.expressions import MatrixExpr +from sympy.core.numbers import I +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt + + +class DFT(MatrixExpr): + r""" + Returns a discrete Fourier transform matrix. The matrix is scaled + with :math:`\frac{1}{\sqrt{n}}` so that it is unitary. + + Parameters + ========== + + n : integer or Symbol + Size of the transform. + + Examples + ======== + + >>> from sympy.abc import n + >>> from sympy.matrices.expressions.fourier import DFT + >>> DFT(3) + DFT(3) + >>> DFT(3).as_explicit() + Matrix([ + [sqrt(3)/3, sqrt(3)/3, sqrt(3)/3], + [sqrt(3)/3, sqrt(3)*exp(-2*I*pi/3)/3, sqrt(3)*exp(2*I*pi/3)/3], + [sqrt(3)/3, sqrt(3)*exp(2*I*pi/3)/3, sqrt(3)*exp(-2*I*pi/3)/3]]) + >>> DFT(n).shape + (n, n) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/DFT_matrix + + """ + + def __new__(cls, n): + n = _sympify(n) + cls._check_dim(n) + + obj = super().__new__(cls, n) + return obj + + n = property(lambda self: self.args[0]) # type: ignore + shape = property(lambda self: (self.n, self.n)) # type: ignore + + def _entry(self, i, j, **kwargs): + w = exp(-2*S.Pi*I/self.n) + return w**(i*j) / sqrt(self.n) + + def _eval_inverse(self): + return IDFT(self.n) + + +class IDFT(DFT): + r""" + Returns an inverse discrete Fourier transform matrix. The matrix is scaled + with :math:`\frac{1}{\sqrt{n}}` so that it is unitary. + + Parameters + ========== + + n : integer or Symbol + Size of the transform + + Examples + ======== + + >>> from sympy.matrices.expressions.fourier import DFT, IDFT + >>> IDFT(3) + IDFT(3) + >>> IDFT(4)*DFT(4) + I + + See Also + ======== + + DFT + + """ + def _entry(self, i, j, **kwargs): + w = exp(-2*S.Pi*I/self.n) + return w**(-i*j) / sqrt(self.n) + + def _eval_inverse(self): + return DFT(self.n) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/funcmatrix.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/funcmatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..91106edb489b73ac9dd6cb94adc508c0db75d3a5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/funcmatrix.py @@ -0,0 +1,118 @@ +from .matexpr import MatrixExpr +from sympy.core.function import FunctionClass, Lambda +from sympy.core.symbol import Dummy +from sympy.core.sympify import _sympify, sympify +from sympy.matrices import Matrix +from sympy.functions.elementary.complexes import re, im + + +class FunctionMatrix(MatrixExpr): + """Represents a matrix using a function (``Lambda``) which gives + outputs according to the coordinates of each matrix entries. + + Parameters + ========== + + rows : nonnegative integer. Can be symbolic. + + cols : nonnegative integer. Can be symbolic. + + lamda : Function, Lambda or str + If it is a SymPy ``Function`` or ``Lambda`` instance, + it should be able to accept two arguments which represents the + matrix coordinates. + + If it is a pure string containing Python ``lambda`` semantics, + it is interpreted by the SymPy parser and casted into a SymPy + ``Lambda`` instance. + + Examples + ======== + + Creating a ``FunctionMatrix`` from ``Lambda``: + + >>> from sympy import FunctionMatrix, symbols, Lambda, MatPow + >>> i, j, n, m = symbols('i,j,n,m') + >>> FunctionMatrix(n, m, Lambda((i, j), i + j)) + FunctionMatrix(n, m, Lambda((i, j), i + j)) + + Creating a ``FunctionMatrix`` from a SymPy function: + + >>> from sympy import KroneckerDelta + >>> X = FunctionMatrix(3, 3, KroneckerDelta) + >>> X.as_explicit() + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + + Creating a ``FunctionMatrix`` from a SymPy undefined function: + + >>> from sympy import Function + >>> f = Function('f') + >>> X = FunctionMatrix(3, 3, f) + >>> X.as_explicit() + Matrix([ + [f(0, 0), f(0, 1), f(0, 2)], + [f(1, 0), f(1, 1), f(1, 2)], + [f(2, 0), f(2, 1), f(2, 2)]]) + + Creating a ``FunctionMatrix`` from Python ``lambda``: + + >>> FunctionMatrix(n, m, 'lambda i, j: i + j') + FunctionMatrix(n, m, Lambda((i, j), i + j)) + + Example of lazy evaluation of matrix product: + + >>> Y = FunctionMatrix(1000, 1000, Lambda((i, j), i + j)) + >>> isinstance(Y*Y, MatPow) # this is an expression object + True + >>> (Y**2)[10,10] # So this is evaluated lazily + 342923500 + + Notes + ===== + + This class provides an alternative way to represent an extremely + dense matrix with entries in some form of a sequence, in a most + sparse way. + """ + def __new__(cls, rows, cols, lamda): + rows, cols = _sympify(rows), _sympify(cols) + cls._check_dim(rows) + cls._check_dim(cols) + + lamda = sympify(lamda) + if not isinstance(lamda, (FunctionClass, Lambda)): + raise ValueError( + "{} should be compatible with SymPy function classes." + .format(lamda)) + + if 2 not in lamda.nargs: + raise ValueError( + '{} should be able to accept 2 arguments.'.format(lamda)) + + if not isinstance(lamda, Lambda): + i, j = Dummy('i'), Dummy('j') + lamda = Lambda((i, j), lamda(i, j)) + + return super().__new__(cls, rows, cols, lamda) + + @property + def shape(self): + return self.args[0:2] + + @property + def lamda(self): + return self.args[2] + + def _entry(self, i, j, **kwargs): + return self.lamda(i, j) + + def _eval_trace(self): + from sympy.matrices.expressions.trace import Trace + from sympy.concrete.summations import Sum + return Trace(self).rewrite(Sum).doit() + + def _eval_as_real_imag(self): + return (re(Matrix(self)), im(Matrix(self))) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/hadamard.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/hadamard.py new file mode 100644 index 0000000000000000000000000000000000000000..38c9033ebea3a7bfc569223978dc6ef3890206cf --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/hadamard.py @@ -0,0 +1,464 @@ +from collections import Counter + +from sympy.core import Mul, sympify +from sympy.core.add import Add +from sympy.core.expr import ExprBuilder +from sympy.core.sorting import default_sort_key +from sympy.functions.elementary.exponential import log +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.matrices.expressions._shape import validate_matadd_integer as validate +from sympy.matrices.expressions.special import ZeroMatrix, OneMatrix +from sympy.strategies import ( + unpack, flatten, condition, exhaust, rm_id, sort +) +from sympy.utilities.exceptions import sympy_deprecation_warning + + +def hadamard_product(*matrices): + """ + Return the elementwise (aka Hadamard) product of matrices. + + Examples + ======== + + >>> from sympy import hadamard_product, MatrixSymbol + >>> A = MatrixSymbol('A', 2, 3) + >>> B = MatrixSymbol('B', 2, 3) + >>> hadamard_product(A) + A + >>> hadamard_product(A, B) + HadamardProduct(A, B) + >>> hadamard_product(A, B)[0, 1] + A[0, 1]*B[0, 1] + """ + if not matrices: + raise TypeError("Empty Hadamard product is undefined") + if len(matrices) == 1: + return matrices[0] + return HadamardProduct(*matrices).doit() + + +class HadamardProduct(MatrixExpr): + """ + Elementwise product of matrix expressions + + Examples + ======== + + Hadamard product for matrix symbols: + + >>> from sympy import hadamard_product, HadamardProduct, MatrixSymbol + >>> A = MatrixSymbol('A', 5, 5) + >>> B = MatrixSymbol('B', 5, 5) + >>> isinstance(hadamard_product(A, B), HadamardProduct) + True + + Notes + ===== + + This is a symbolic object that simply stores its argument without + evaluating it. To actually compute the product, use the function + ``hadamard_product()`` or ``HadamardProduct.doit`` + """ + is_HadamardProduct = True + + def __new__(cls, *args, evaluate=False, check=None): + args = list(map(sympify, args)) + if len(args) == 0: + # We currently don't have a way to support one-matrices of generic dimensions: + raise ValueError("HadamardProduct needs at least one argument") + + if not all(isinstance(arg, MatrixExpr) for arg in args): + raise TypeError("Mix of Matrix and Scalar symbols") + + if check is not None: + sympy_deprecation_warning( + "Passing check to HadamardProduct is deprecated and the check argument will be removed in a future version.", + deprecated_since_version="1.11", + active_deprecations_target='remove-check-argument-from-matrix-operations') + + if check is not False: + validate(*args) + + obj = super().__new__(cls, *args) + if evaluate: + obj = obj.doit(deep=False) + return obj + + @property + def shape(self): + return self.args[0].shape + + def _entry(self, i, j, **kwargs): + return Mul(*[arg._entry(i, j, **kwargs) for arg in self.args]) + + def _eval_transpose(self): + from sympy.matrices.expressions.transpose import transpose + return HadamardProduct(*list(map(transpose, self.args))) + + def doit(self, **hints): + expr = self.func(*(i.doit(**hints) for i in self.args)) + # Check for explicit matrices: + from sympy.matrices.matrixbase import MatrixBase + from sympy.matrices.immutable import ImmutableMatrix + + explicit = [i for i in expr.args if isinstance(i, MatrixBase)] + if explicit: + remainder = [i for i in expr.args if i not in explicit] + expl_mat = ImmutableMatrix([ + Mul.fromiter(i) for i in zip(*explicit) + ]).reshape(*self.shape) + expr = HadamardProduct(*([expl_mat] + remainder)) + + return canonicalize(expr) + + def _eval_derivative(self, x): + terms = [] + args = list(self.args) + for i in range(len(args)): + factors = args[:i] + [args[i].diff(x)] + args[i+1:] + terms.append(hadamard_product(*factors)) + return Add.fromiter(terms) + + def _eval_derivative_matrix_lines(self, x): + from sympy.tensor.array.expressions.array_expressions import ArrayDiagonal + from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct + from sympy.matrices.expressions.matexpr import _make_matrix + + with_x_ind = [i for i, arg in enumerate(self.args) if arg.has(x)] + lines = [] + for ind in with_x_ind: + left_args = self.args[:ind] + right_args = self.args[ind+1:] + + d = self.args[ind]._eval_derivative_matrix_lines(x) + hadam = hadamard_product(*(right_args + left_args)) + diagonal = [(0, 2), (3, 4)] + diagonal = [e for j, e in enumerate(diagonal) if self.shape[j] != 1] + for i in d: + l1 = i._lines[i._first_line_index] + l2 = i._lines[i._second_line_index] + subexpr = ExprBuilder( + ArrayDiagonal, + [ + ExprBuilder( + ArrayTensorProduct, + [ + ExprBuilder(_make_matrix, [l1]), + hadam, + ExprBuilder(_make_matrix, [l2]), + ] + ), + *diagonal], + + ) + i._first_pointer_parent = subexpr.args[0].args[0].args + i._first_pointer_index = 0 + i._second_pointer_parent = subexpr.args[0].args[2].args + i._second_pointer_index = 0 + i._lines = [subexpr] + lines.append(i) + + return lines + + +# TODO Implement algorithm for rewriting Hadamard product as diagonal matrix +# if matmul identy matrix is multiplied. +def canonicalize(x): + """Canonicalize the Hadamard product ``x`` with mathematical properties. + + Examples + ======== + + >>> from sympy import MatrixSymbol, HadamardProduct + >>> from sympy import OneMatrix, ZeroMatrix + >>> from sympy.matrices.expressions.hadamard import canonicalize + >>> from sympy import init_printing + >>> init_printing(use_unicode=False) + + >>> A = MatrixSymbol('A', 2, 2) + >>> B = MatrixSymbol('B', 2, 2) + >>> C = MatrixSymbol('C', 2, 2) + + Hadamard product associativity: + + >>> X = HadamardProduct(A, HadamardProduct(B, C)) + >>> X + A.*(B.*C) + >>> canonicalize(X) + A.*B.*C + + Hadamard product commutativity: + + >>> X = HadamardProduct(A, B) + >>> Y = HadamardProduct(B, A) + >>> X + A.*B + >>> Y + B.*A + >>> canonicalize(X) + A.*B + >>> canonicalize(Y) + A.*B + + Hadamard product identity: + + >>> X = HadamardProduct(A, OneMatrix(2, 2)) + >>> X + A.*1 + >>> canonicalize(X) + A + + Absorbing element of Hadamard product: + + >>> X = HadamardProduct(A, ZeroMatrix(2, 2)) + >>> X + A.*0 + >>> canonicalize(X) + 0 + + Rewriting to Hadamard Power + + >>> X = HadamardProduct(A, A, A) + >>> X + A.*A.*A + >>> canonicalize(X) + .3 + A + + Notes + ===== + + As the Hadamard product is associative, nested products can be flattened. + + The Hadamard product is commutative so that factors can be sorted for + canonical form. + + A matrix of only ones is an identity for Hadamard product, + so every matrices of only ones can be removed. + + Any zero matrix will make the whole product a zero matrix. + + Duplicate elements can be collected and rewritten as HadamardPower + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hadamard_product_(matrices) + """ + # Associativity + rule = condition( + lambda x: isinstance(x, HadamardProduct), + flatten + ) + fun = exhaust(rule) + x = fun(x) + + # Identity + fun = condition( + lambda x: isinstance(x, HadamardProduct), + rm_id(lambda x: isinstance(x, OneMatrix)) + ) + x = fun(x) + + # Absorbing by Zero Matrix + def absorb(x): + if any(isinstance(c, ZeroMatrix) for c in x.args): + return ZeroMatrix(*x.shape) + else: + return x + fun = condition( + lambda x: isinstance(x, HadamardProduct), + absorb + ) + x = fun(x) + + # Rewriting with HadamardPower + if isinstance(x, HadamardProduct): + tally = Counter(x.args) + + new_arg = [] + for base, exp in tally.items(): + if exp == 1: + new_arg.append(base) + else: + new_arg.append(HadamardPower(base, exp)) + + x = HadamardProduct(*new_arg) + + # Commutativity + fun = condition( + lambda x: isinstance(x, HadamardProduct), + sort(default_sort_key) + ) + x = fun(x) + + # Unpacking + x = unpack(x) + return x + + +def hadamard_power(base, exp): + base = sympify(base) + exp = sympify(exp) + if exp == 1: + return base + if not base.is_Matrix: + return base**exp + if exp.is_Matrix: + raise ValueError("cannot raise expression to a matrix") + return HadamardPower(base, exp) + + +class HadamardPower(MatrixExpr): + r""" + Elementwise power of matrix expressions + + Parameters + ========== + + base : scalar or matrix + + exp : scalar or matrix + + Notes + ===== + + There are four definitions for the hadamard power which can be used. + Let's consider `A, B` as `(m, n)` matrices, and `a, b` as scalars. + + Matrix raised to a scalar exponent: + + .. math:: + A^{\circ b} = \begin{bmatrix} + A_{0, 0}^b & A_{0, 1}^b & \cdots & A_{0, n-1}^b \\ + A_{1, 0}^b & A_{1, 1}^b & \cdots & A_{1, n-1}^b \\ + \vdots & \vdots & \ddots & \vdots \\ + A_{m-1, 0}^b & A_{m-1, 1}^b & \cdots & A_{m-1, n-1}^b + \end{bmatrix} + + Scalar raised to a matrix exponent: + + .. math:: + a^{\circ B} = \begin{bmatrix} + a^{B_{0, 0}} & a^{B_{0, 1}} & \cdots & a^{B_{0, n-1}} \\ + a^{B_{1, 0}} & a^{B_{1, 1}} & \cdots & a^{B_{1, n-1}} \\ + \vdots & \vdots & \ddots & \vdots \\ + a^{B_{m-1, 0}} & a^{B_{m-1, 1}} & \cdots & a^{B_{m-1, n-1}} + \end{bmatrix} + + Matrix raised to a matrix exponent: + + .. math:: + A^{\circ B} = \begin{bmatrix} + A_{0, 0}^{B_{0, 0}} & A_{0, 1}^{B_{0, 1}} & + \cdots & A_{0, n-1}^{B_{0, n-1}} \\ + A_{1, 0}^{B_{1, 0}} & A_{1, 1}^{B_{1, 1}} & + \cdots & A_{1, n-1}^{B_{1, n-1}} \\ + \vdots & \vdots & + \ddots & \vdots \\ + A_{m-1, 0}^{B_{m-1, 0}} & A_{m-1, 1}^{B_{m-1, 1}} & + \cdots & A_{m-1, n-1}^{B_{m-1, n-1}} + \end{bmatrix} + + Scalar raised to a scalar exponent: + + .. math:: + a^{\circ b} = a^b + """ + + def __new__(cls, base, exp): + base = sympify(base) + exp = sympify(exp) + + if base.is_scalar and exp.is_scalar: + return base ** exp + + if isinstance(base, MatrixExpr) and isinstance(exp, MatrixExpr): + validate(base, exp) + + obj = super().__new__(cls, base, exp) + return obj + + @property + def base(self): + return self._args[0] + + @property + def exp(self): + return self._args[1] + + @property + def shape(self): + if self.base.is_Matrix: + return self.base.shape + return self.exp.shape + + def _entry(self, i, j, **kwargs): + base = self.base + exp = self.exp + + if base.is_Matrix: + a = base._entry(i, j, **kwargs) + elif base.is_scalar: + a = base + else: + raise ValueError( + 'The base {} must be a scalar or a matrix.'.format(base)) + + if exp.is_Matrix: + b = exp._entry(i, j, **kwargs) + elif exp.is_scalar: + b = exp + else: + raise ValueError( + 'The exponent {} must be a scalar or a matrix.'.format(exp)) + + return a ** b + + def _eval_transpose(self): + from sympy.matrices.expressions.transpose import transpose + return HadamardPower(transpose(self.base), self.exp) + + def _eval_derivative(self, x): + dexp = self.exp.diff(x) + logbase = self.base.applyfunc(log) + dlbase = logbase.diff(x) + return hadamard_product( + dexp*logbase + self.exp*dlbase, + self + ) + + def _eval_derivative_matrix_lines(self, x): + from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct + from sympy.tensor.array.expressions.array_expressions import ArrayDiagonal + from sympy.matrices.expressions.matexpr import _make_matrix + + lr = self.base._eval_derivative_matrix_lines(x) + for i in lr: + diagonal = [(1, 2), (3, 4)] + diagonal = [e for j, e in enumerate(diagonal) if self.base.shape[j] != 1] + l1 = i._lines[i._first_line_index] + l2 = i._lines[i._second_line_index] + subexpr = ExprBuilder( + ArrayDiagonal, + [ + ExprBuilder( + ArrayTensorProduct, + [ + ExprBuilder(_make_matrix, [l1]), + self.exp*hadamard_power(self.base, self.exp-1), + ExprBuilder(_make_matrix, [l2]), + ] + ), + *diagonal], + validator=ArrayDiagonal._validate + ) + i._first_pointer_parent = subexpr.args[0].args[0].args + i._first_pointer_index = 0 + i._first_line_index = 0 + i._second_pointer_parent = subexpr.args[0].args[2].args + i._second_pointer_index = 0 + i._second_line_index = 0 + i._lines = [subexpr] + return lr diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/inverse.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/inverse.py new file mode 100644 index 0000000000000000000000000000000000000000..cfc3feccd7126a761f18f23599eed9413c86a9e5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/inverse.py @@ -0,0 +1,112 @@ +from sympy.core.sympify import _sympify +from sympy.core import S, Basic + +from sympy.matrices.exceptions import NonSquareMatrixError +from sympy.matrices.expressions.matpow import MatPow + + +class Inverse(MatPow): + """ + The multiplicative inverse of a matrix expression + + This is a symbolic object that simply stores its argument without + evaluating it. To actually compute the inverse, use the ``.inverse()`` + method of matrices. + + Examples + ======== + + >>> from sympy import MatrixSymbol, Inverse + >>> A = MatrixSymbol('A', 3, 3) + >>> B = MatrixSymbol('B', 3, 3) + >>> Inverse(A) + A**(-1) + >>> A.inverse() == Inverse(A) + True + >>> (A*B).inverse() + B**(-1)*A**(-1) + >>> Inverse(A*B) + (A*B)**(-1) + + """ + is_Inverse = True + exp = S.NegativeOne + + def __new__(cls, mat, exp=S.NegativeOne): + # exp is there to make it consistent with + # inverse.func(*inverse.args) == inverse + mat = _sympify(mat) + exp = _sympify(exp) + if not mat.is_Matrix: + raise TypeError("mat should be a matrix") + if mat.is_square is False: + raise NonSquareMatrixError("Inverse of non-square matrix %s" % mat) + return Basic.__new__(cls, mat, exp) + + @property + def arg(self): + return self.args[0] + + @property + def shape(self): + return self.arg.shape + + def _eval_inverse(self): + return self.arg + + def _eval_transpose(self): + return Inverse(self.arg.transpose()) + + def _eval_adjoint(self): + return Inverse(self.arg.adjoint()) + + def _eval_conjugate(self): + return Inverse(self.arg.conjugate()) + + def _eval_determinant(self): + from sympy.matrices.expressions.determinant import det + return 1/det(self.arg) + + def doit(self, **hints): + if 'inv_expand' in hints and hints['inv_expand'] == False: + return self + + arg = self.arg + if hints.get('deep', True): + arg = arg.doit(**hints) + + return arg.inverse() + + def _eval_derivative_matrix_lines(self, x): + arg = self.args[0] + lines = arg._eval_derivative_matrix_lines(x) + for line in lines: + line.first_pointer *= -self.T + line.second_pointer *= self + return lines + + +from sympy.assumptions.ask import ask, Q +from sympy.assumptions.refine import handlers_dict + + +def refine_Inverse(expr, assumptions): + """ + >>> from sympy import MatrixSymbol, Q, assuming, refine + >>> X = MatrixSymbol('X', 2, 2) + >>> X.I + X**(-1) + >>> with assuming(Q.orthogonal(X)): + ... print(refine(X.I)) + X.T + """ + if ask(Q.orthogonal(expr), assumptions): + return expr.arg.T + elif ask(Q.unitary(expr), assumptions): + return expr.arg.conjugate() + elif ask(Q.singular(expr), assumptions): + raise ValueError("Inverse of singular matrix %s" % expr.arg) + + return expr + +handlers_dict['Inverse'] = refine_Inverse diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/kronecker.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/kronecker.py new file mode 100644 index 0000000000000000000000000000000000000000..1dd175cb0d500af3e786e2d0dbf6b010947840b4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/kronecker.py @@ -0,0 +1,434 @@ +"""Implementation of the Kronecker product""" +from functools import reduce +from math import prod + +from sympy.core import Mul, sympify +from sympy.functions import adjoint +from sympy.matrices.exceptions import ShapeError +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.matrices.expressions.transpose import transpose +from sympy.matrices.expressions.special import Identity +from sympy.matrices.matrixbase import MatrixBase +from sympy.strategies import ( + canon, condition, distribute, do_one, exhaust, flatten, typed, unpack) +from sympy.strategies.traverse import bottom_up +from sympy.utilities import sift + +from .matadd import MatAdd +from .matmul import MatMul +from .matpow import MatPow + + +def kronecker_product(*matrices): + """ + The Kronecker product of two or more arguments. + + This computes the explicit Kronecker product for subclasses of + ``MatrixBase`` i.e. explicit matrices. Otherwise, a symbolic + ``KroneckerProduct`` object is returned. + + + Examples + ======== + + For ``MatrixSymbol`` arguments a ``KroneckerProduct`` object is returned. + Elements of this matrix can be obtained by indexing, or for MatrixSymbols + with known dimension the explicit matrix can be obtained with + ``.as_explicit()`` + + >>> from sympy import kronecker_product, MatrixSymbol + >>> A = MatrixSymbol('A', 2, 2) + >>> B = MatrixSymbol('B', 2, 2) + >>> kronecker_product(A) + A + >>> kronecker_product(A, B) + KroneckerProduct(A, B) + >>> kronecker_product(A, B)[0, 1] + A[0, 0]*B[0, 1] + >>> kronecker_product(A, B).as_explicit() + Matrix([ + [A[0, 0]*B[0, 0], A[0, 0]*B[0, 1], A[0, 1]*B[0, 0], A[0, 1]*B[0, 1]], + [A[0, 0]*B[1, 0], A[0, 0]*B[1, 1], A[0, 1]*B[1, 0], A[0, 1]*B[1, 1]], + [A[1, 0]*B[0, 0], A[1, 0]*B[0, 1], A[1, 1]*B[0, 0], A[1, 1]*B[0, 1]], + [A[1, 0]*B[1, 0], A[1, 0]*B[1, 1], A[1, 1]*B[1, 0], A[1, 1]*B[1, 1]]]) + + For explicit matrices the Kronecker product is returned as a Matrix + + >>> from sympy import Matrix, kronecker_product + >>> sigma_x = Matrix([ + ... [0, 1], + ... [1, 0]]) + ... + >>> Isigma_y = Matrix([ + ... [0, 1], + ... [-1, 0]]) + ... + >>> kronecker_product(sigma_x, Isigma_y) + Matrix([ + [ 0, 0, 0, 1], + [ 0, 0, -1, 0], + [ 0, 1, 0, 0], + [-1, 0, 0, 0]]) + + See Also + ======== + KroneckerProduct + + """ + if not matrices: + raise TypeError("Empty Kronecker product is undefined") + if len(matrices) == 1: + return matrices[0] + else: + return KroneckerProduct(*matrices).doit() + + +class KroneckerProduct(MatrixExpr): + """ + The Kronecker product of two or more arguments. + + The Kronecker product is a non-commutative product of matrices. + Given two matrices of dimension (m, n) and (s, t) it produces a matrix + of dimension (m s, n t). + + This is a symbolic object that simply stores its argument without + evaluating it. To actually compute the product, use the function + ``kronecker_product()`` or call the ``.doit()`` or ``.as_explicit()`` + methods. + + >>> from sympy import KroneckerProduct, MatrixSymbol + >>> A = MatrixSymbol('A', 5, 5) + >>> B = MatrixSymbol('B', 5, 5) + >>> isinstance(KroneckerProduct(A, B), KroneckerProduct) + True + """ + is_KroneckerProduct = True + + def __new__(cls, *args, check=True): + args = list(map(sympify, args)) + if all(a.is_Identity for a in args): + ret = Identity(prod(a.rows for a in args)) + if all(isinstance(a, MatrixBase) for a in args): + return ret.as_explicit() + else: + return ret + + if check: + validate(*args) + return super().__new__(cls, *args) + + @property + def shape(self): + rows, cols = self.args[0].shape + for mat in self.args[1:]: + rows *= mat.rows + cols *= mat.cols + return (rows, cols) + + def _entry(self, i, j, **kwargs): + result = 1 + for mat in reversed(self.args): + i, m = divmod(i, mat.rows) + j, n = divmod(j, mat.cols) + result *= mat[m, n] + return result + + def _eval_adjoint(self): + return KroneckerProduct(*list(map(adjoint, self.args))).doit() + + def _eval_conjugate(self): + return KroneckerProduct(*[a.conjugate() for a in self.args]).doit() + + def _eval_transpose(self): + return KroneckerProduct(*list(map(transpose, self.args))).doit() + + def _eval_trace(self): + from .trace import trace + return Mul(*[trace(a) for a in self.args]) + + def _eval_determinant(self): + from .determinant import det, Determinant + if not all(a.is_square for a in self.args): + return Determinant(self) + + m = self.rows + return Mul(*[det(a)**(m/a.rows) for a in self.args]) + + def _eval_inverse(self): + try: + return KroneckerProduct(*[a.inverse() for a in self.args]) + except ShapeError: + from sympy.matrices.expressions.inverse import Inverse + return Inverse(self) + + def structurally_equal(self, other): + '''Determine whether two matrices have the same Kronecker product structure + + Examples + ======== + + >>> from sympy import KroneckerProduct, MatrixSymbol, symbols + >>> m, n = symbols(r'm, n', integer=True) + >>> A = MatrixSymbol('A', m, m) + >>> B = MatrixSymbol('B', n, n) + >>> C = MatrixSymbol('C', m, m) + >>> D = MatrixSymbol('D', n, n) + >>> KroneckerProduct(A, B).structurally_equal(KroneckerProduct(C, D)) + True + >>> KroneckerProduct(A, B).structurally_equal(KroneckerProduct(D, C)) + False + >>> KroneckerProduct(A, B).structurally_equal(C) + False + ''' + # Inspired by BlockMatrix + return (isinstance(other, KroneckerProduct) + and self.shape == other.shape + and len(self.args) == len(other.args) + and all(a.shape == b.shape for (a, b) in zip(self.args, other.args))) + + def has_matching_shape(self, other): + '''Determine whether two matrices have the appropriate structure to bring matrix + multiplication inside the KroneckerProdut + + Examples + ======== + >>> from sympy import KroneckerProduct, MatrixSymbol, symbols + >>> m, n = symbols(r'm, n', integer=True) + >>> A = MatrixSymbol('A', m, n) + >>> B = MatrixSymbol('B', n, m) + >>> KroneckerProduct(A, B).has_matching_shape(KroneckerProduct(B, A)) + True + >>> KroneckerProduct(A, B).has_matching_shape(KroneckerProduct(A, B)) + False + >>> KroneckerProduct(A, B).has_matching_shape(A) + False + ''' + return (isinstance(other, KroneckerProduct) + and self.cols == other.rows + and len(self.args) == len(other.args) + and all(a.cols == b.rows for (a, b) in zip(self.args, other.args))) + + def _eval_expand_kroneckerproduct(self, **hints): + return flatten(canon(typed({KroneckerProduct: distribute(KroneckerProduct, MatAdd)}))(self)) + + def _kronecker_add(self, other): + if self.structurally_equal(other): + return self.__class__(*[a + b for (a, b) in zip(self.args, other.args)]) + else: + return self + other + + def _kronecker_mul(self, other): + if self.has_matching_shape(other): + return self.__class__(*[a*b for (a, b) in zip(self.args, other.args)]) + else: + return self * other + + def doit(self, **hints): + deep = hints.get('deep', True) + if deep: + args = [arg.doit(**hints) for arg in self.args] + else: + args = self.args + return canonicalize(KroneckerProduct(*args)) + + +def validate(*args): + if not all(arg.is_Matrix for arg in args): + raise TypeError("Mix of Matrix and Scalar symbols") + + +# rules + +def extract_commutative(kron): + c_part = [] + nc_part = [] + for arg in kron.args: + c, nc = arg.args_cnc() + c_part.extend(c) + nc_part.append(Mul._from_args(nc)) + + c_part = Mul(*c_part) + if c_part != 1: + return c_part*KroneckerProduct(*nc_part) + return kron + + +def matrix_kronecker_product(*matrices): + """Compute the Kronecker product of a sequence of SymPy Matrices. + + This is the standard Kronecker product of matrices [1]. + + Parameters + ========== + + matrices : tuple of MatrixBase instances + The matrices to take the Kronecker product of. + + Returns + ======= + + matrix : MatrixBase + The Kronecker product matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.matrices.expressions.kronecker import ( + ... matrix_kronecker_product) + + >>> m1 = Matrix([[1,2],[3,4]]) + >>> m2 = Matrix([[1,0],[0,1]]) + >>> matrix_kronecker_product(m1, m2) + Matrix([ + [1, 0, 2, 0], + [0, 1, 0, 2], + [3, 0, 4, 0], + [0, 3, 0, 4]]) + >>> matrix_kronecker_product(m2, m1) + Matrix([ + [1, 2, 0, 0], + [3, 4, 0, 0], + [0, 0, 1, 2], + [0, 0, 3, 4]]) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Kronecker_product + """ + # Make sure we have a sequence of Matrices + if not all(isinstance(m, MatrixBase) for m in matrices): + raise TypeError( + 'Sequence of Matrices expected, got: %s' % repr(matrices) + ) + + # Pull out the first element in the product. + matrix_expansion = matrices[-1] + # Do the kronecker product working from right to left. + for mat in reversed(matrices[:-1]): + rows = mat.rows + cols = mat.cols + # Go through each row appending kronecker product to. + # running matrix_expansion. + for i in range(rows): + start = matrix_expansion*mat[i*cols] + # Go through each column joining each item + for j in range(cols - 1): + start = start.row_join( + matrix_expansion*mat[i*cols + j + 1] + ) + # If this is the first element, make it the start of the + # new row. + if i == 0: + next = start + else: + next = next.col_join(start) + matrix_expansion = next + + MatrixClass = max(matrices, key=lambda M: M._class_priority).__class__ + if isinstance(matrix_expansion, MatrixClass): + return matrix_expansion + else: + return MatrixClass(matrix_expansion) + + +def explicit_kronecker_product(kron): + # Make sure we have a sequence of Matrices + if not all(isinstance(m, MatrixBase) for m in kron.args): + return kron + + return matrix_kronecker_product(*kron.args) + + +rules = (unpack, + explicit_kronecker_product, + flatten, + extract_commutative) + +canonicalize = exhaust(condition(lambda x: isinstance(x, KroneckerProduct), + do_one(*rules))) + + +def _kronecker_dims_key(expr): + if isinstance(expr, KroneckerProduct): + return tuple(a.shape for a in expr.args) + else: + return (0,) + + +def kronecker_mat_add(expr): + args = sift(expr.args, _kronecker_dims_key) + nonkrons = args.pop((0,), None) + if not args: + return expr + + krons = [reduce(lambda x, y: x._kronecker_add(y), group) + for group in args.values()] + + if not nonkrons: + return MatAdd(*krons) + else: + return MatAdd(*krons) + nonkrons + + +def kronecker_mat_mul(expr): + # modified from block matrix code + factor, matrices = expr.as_coeff_matrices() + + i = 0 + while i < len(matrices) - 1: + A, B = matrices[i:i+2] + if isinstance(A, KroneckerProduct) and isinstance(B, KroneckerProduct): + matrices[i] = A._kronecker_mul(B) + matrices.pop(i+1) + else: + i += 1 + + return factor*MatMul(*matrices) + + +def kronecker_mat_pow(expr): + if isinstance(expr.base, KroneckerProduct) and all(a.is_square for a in expr.base.args): + return KroneckerProduct(*[MatPow(a, expr.exp) for a in expr.base.args]) + else: + return expr + + +def combine_kronecker(expr): + """Combine KronekeckerProduct with expression. + + If possible write operations on KroneckerProducts of compatible shapes + as a single KroneckerProduct. + + Examples + ======== + + >>> from sympy.matrices.expressions import combine_kronecker + >>> from sympy import MatrixSymbol, KroneckerProduct, symbols + >>> m, n = symbols(r'm, n', integer=True) + >>> A = MatrixSymbol('A', m, n) + >>> B = MatrixSymbol('B', n, m) + >>> combine_kronecker(KroneckerProduct(A, B)*KroneckerProduct(B, A)) + KroneckerProduct(A*B, B*A) + >>> combine_kronecker(KroneckerProduct(A, B)+KroneckerProduct(B.T, A.T)) + KroneckerProduct(A + B.T, B + A.T) + >>> C = MatrixSymbol('C', n, n) + >>> D = MatrixSymbol('D', m, m) + >>> combine_kronecker(KroneckerProduct(C, D)**m) + KroneckerProduct(C**m, D**m) + """ + def haskron(expr): + return isinstance(expr, MatrixExpr) and expr.has(KroneckerProduct) + + rule = exhaust( + bottom_up(exhaust(condition(haskron, typed( + {MatAdd: kronecker_mat_add, + MatMul: kronecker_mat_mul, + MatPow: kronecker_mat_pow}))))) + result = rule(expr) + doit = getattr(result, 'doit', None) + if doit is not None: + return doit() + else: + return result diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/matadd.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/matadd.py new file mode 100644 index 0000000000000000000000000000000000000000..cfae1e5010e4077c7210c85c4315ed2404f245d7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/matadd.py @@ -0,0 +1,155 @@ +from functools import reduce +import operator + +from sympy.core import Basic, sympify +from sympy.core.add import add, Add, _could_extract_minus_sign +from sympy.core.sorting import default_sort_key +from sympy.functions import adjoint +from sympy.matrices.matrixbase import MatrixBase +from sympy.matrices.expressions.transpose import transpose +from sympy.strategies import (rm_id, unpack, flatten, sort, condition, + exhaust, do_one, glom) +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.matrices.expressions.special import ZeroMatrix, GenericZeroMatrix +from sympy.matrices.expressions._shape import validate_matadd_integer as validate +from sympy.utilities.iterables import sift +from sympy.utilities.exceptions import sympy_deprecation_warning + +# XXX: MatAdd should perhaps not subclass directly from Add +class MatAdd(MatrixExpr, Add): + """A Sum of Matrix Expressions + + MatAdd inherits from and operates like SymPy Add + + Examples + ======== + + >>> from sympy import MatAdd, MatrixSymbol + >>> A = MatrixSymbol('A', 5, 5) + >>> B = MatrixSymbol('B', 5, 5) + >>> C = MatrixSymbol('C', 5, 5) + >>> MatAdd(A, B, C) + A + B + C + """ + is_MatAdd = True + + identity = GenericZeroMatrix() + + def __new__(cls, *args, evaluate=False, check=None, _sympify=True): + if not args: + return cls.identity + + # This must be removed aggressively in the constructor to avoid + # TypeErrors from GenericZeroMatrix().shape + args = list(filter(lambda i: cls.identity != i, args)) + if _sympify: + args = list(map(sympify, args)) + + if not all(isinstance(arg, MatrixExpr) for arg in args): + raise TypeError("Mix of Matrix and Scalar symbols") + + obj = Basic.__new__(cls, *args) + + if check is not None: + sympy_deprecation_warning( + "Passing check to MatAdd is deprecated and the check argument will be removed in a future version.", + deprecated_since_version="1.11", + active_deprecations_target='remove-check-argument-from-matrix-operations') + + if check is not False: + validate(*args) + + if evaluate: + obj = cls._evaluate(obj) + + return obj + + @classmethod + def _evaluate(cls, expr): + return canonicalize(expr) + + @property + def shape(self): + return self.args[0].shape + + def could_extract_minus_sign(self): + return _could_extract_minus_sign(self) + + def expand(self, **kwargs): + expanded = super(MatAdd, self).expand(**kwargs) + return self._evaluate(expanded) + + def _entry(self, i, j, **kwargs): + return Add(*[arg._entry(i, j, **kwargs) for arg in self.args]) + + def _eval_transpose(self): + return MatAdd(*[transpose(arg) for arg in self.args]).doit() + + def _eval_adjoint(self): + return MatAdd(*[adjoint(arg) for arg in self.args]).doit() + + def _eval_trace(self): + from .trace import trace + return Add(*[trace(arg) for arg in self.args]).doit() + + def doit(self, **hints): + deep = hints.get('deep', True) + if deep: + args = [arg.doit(**hints) for arg in self.args] + else: + args = self.args + return canonicalize(MatAdd(*args)) + + def _eval_derivative_matrix_lines(self, x): + add_lines = [arg._eval_derivative_matrix_lines(x) for arg in self.args] + return [j for i in add_lines for j in i] + +add.register_handlerclass((Add, MatAdd), MatAdd) + + +factor_of = lambda arg: arg.as_coeff_mmul()[0] +matrix_of = lambda arg: unpack(arg.as_coeff_mmul()[1]) +def combine(cnt, mat): + if cnt == 1: + return mat + else: + return cnt * mat + + +def merge_explicit(matadd): + """ Merge explicit MatrixBase arguments + + Examples + ======== + + >>> from sympy import MatrixSymbol, eye, Matrix, MatAdd, pprint + >>> from sympy.matrices.expressions.matadd import merge_explicit + >>> A = MatrixSymbol('A', 2, 2) + >>> B = eye(2) + >>> C = Matrix([[1, 2], [3, 4]]) + >>> X = MatAdd(A, B, C) + >>> pprint(X) + [1 0] [1 2] + A + [ ] + [ ] + [0 1] [3 4] + >>> pprint(merge_explicit(X)) + [2 2] + A + [ ] + [3 5] + """ + groups = sift(matadd.args, lambda arg: isinstance(arg, MatrixBase)) + if len(groups[True]) > 1: + return MatAdd(*(groups[False] + [reduce(operator.add, groups[True])])) + else: + return matadd + + +rules = (rm_id(lambda x: x == 0 or isinstance(x, ZeroMatrix)), + unpack, + flatten, + glom(matrix_of, factor_of, combine), + merge_explicit, + sort(default_sort_key)) + +canonicalize = exhaust(condition(lambda x: isinstance(x, MatAdd), + do_one(*rules))) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/matexpr.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/matexpr.py new file mode 100644 index 0000000000000000000000000000000000000000..a4e99296ccfcbdac5e09a86ecee020adf9831c73 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/matexpr.py @@ -0,0 +1,888 @@ +from __future__ import annotations +from functools import wraps + +from sympy.core import S, Integer, Basic, Mul, Add +from sympy.core.assumptions import check_assumptions +from sympy.core.decorators import call_highest_priority +from sympy.core.expr import Expr, ExprBuilder +from sympy.core.logic import FuzzyBool +from sympy.core.symbol import Str, Dummy, symbols, Symbol +from sympy.core.sympify import SympifyError, _sympify +from sympy.external.gmpy import SYMPY_INTS +from sympy.functions import conjugate, adjoint +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.matrices.exceptions import NonSquareMatrixError +from sympy.matrices.kind import MatrixKind +from sympy.matrices.matrixbase import MatrixBase +from sympy.multipledispatch import dispatch +from sympy.utilities.misc import filldedent + + +def _sympifyit(arg, retval=None): + # This version of _sympifyit sympifies MutableMatrix objects + def deco(func): + @wraps(func) + def __sympifyit_wrapper(a, b): + try: + b = _sympify(b) + return func(a, b) + except SympifyError: + return retval + + return __sympifyit_wrapper + + return deco + + +class MatrixExpr(Expr): + """Superclass for Matrix Expressions + + MatrixExprs represent abstract matrices, linear transformations represented + within a particular basis. + + Examples + ======== + + >>> from sympy import MatrixSymbol + >>> A = MatrixSymbol('A', 3, 3) + >>> y = MatrixSymbol('y', 3, 1) + >>> x = (A.T*A).I * A * y + + See Also + ======== + + MatrixSymbol, MatAdd, MatMul, Transpose, Inverse + """ + __slots__: tuple[str, ...] = () + + # Should not be considered iterable by the + # sympy.utilities.iterables.iterable function. Subclass that actually are + # iterable (i.e., explicit matrices) should set this to True. + _iterable = False + + _op_priority = 11.0 + + is_Matrix: bool = True + is_MatrixExpr: bool = True + is_Identity: FuzzyBool = None + is_Inverse = False + is_Transpose = False + is_ZeroMatrix = False + is_MatAdd = False + is_MatMul = False + + is_commutative = False + is_number = False + is_symbol = False + is_scalar = False + + kind: MatrixKind = MatrixKind() + + def __new__(cls, *args, **kwargs): + args = map(_sympify, args) + return Basic.__new__(cls, *args, **kwargs) + + # The following is adapted from the core Expr object + + @property + def shape(self) -> tuple[Expr, Expr]: + raise NotImplementedError + + @property + def _add_handler(self): + return MatAdd + + @property + def _mul_handler(self): + return MatMul + + def __neg__(self): + return MatMul(S.NegativeOne, self).doit() + + def __abs__(self): + raise NotImplementedError + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__radd__') + def __add__(self, other): + return MatAdd(self, other).doit() + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__add__') + def __radd__(self, other): + return MatAdd(other, self).doit() + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rsub__') + def __sub__(self, other): + return MatAdd(self, -other).doit() + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__sub__') + def __rsub__(self, other): + return MatAdd(other, -self).doit() + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rmul__') + def __mul__(self, other): + return MatMul(self, other).doit() + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rmul__') + def __matmul__(self, other): + return MatMul(self, other).doit() + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__mul__') + def __rmul__(self, other): + return MatMul(other, self).doit() + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__mul__') + def __rmatmul__(self, other): + return MatMul(other, self).doit() + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rpow__') + def __pow__(self, other): + return MatPow(self, other).doit() + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__pow__') + def __rpow__(self, other): + raise NotImplementedError("Matrix Power not defined") + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rtruediv__') + def __truediv__(self, other): + return self * other**S.NegativeOne + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__truediv__') + def __rtruediv__(self, other): + raise NotImplementedError() + #return MatMul(other, Pow(self, S.NegativeOne)) + + @property + def rows(self): + return self.shape[0] + + @property + def cols(self): + return self.shape[1] + + @property + def is_square(self) -> bool | None: + rows, cols = self.shape + if isinstance(rows, Integer) and isinstance(cols, Integer): + return rows == cols + if rows == cols: + return True + return None + + def _eval_conjugate(self): + from sympy.matrices.expressions.adjoint import Adjoint + return Adjoint(Transpose(self)) + + def as_real_imag(self, deep=True, **hints): + return self._eval_as_real_imag() + + def _eval_as_real_imag(self): + real = S.Half * (self + self._eval_conjugate()) + im = (self - self._eval_conjugate())/(2*S.ImaginaryUnit) + return (real, im) + + def _eval_inverse(self): + return Inverse(self) + + def _eval_determinant(self): + return Determinant(self) + + def _eval_transpose(self): + return Transpose(self) + + def _eval_trace(self): + return None + + def _eval_power(self, exp): + """ + Override this in sub-classes to implement simplification of powers. The cases where the exponent + is -1, 0, 1 are already covered in MatPow.doit(), so implementations can exclude these cases. + """ + return MatPow(self, exp) + + def _eval_simplify(self, **kwargs): + if self.is_Atom: + return self + else: + from sympy.simplify import simplify + return self.func(*[simplify(x, **kwargs) for x in self.args]) + + def _eval_adjoint(self): + from sympy.matrices.expressions.adjoint import Adjoint + return Adjoint(self) + + def _eval_derivative_n_times(self, x, n): + return Basic._eval_derivative_n_times(self, x, n) + + def _eval_derivative(self, x): + # `x` is a scalar: + if self.has(x): + # See if there are other methods using it: + return super()._eval_derivative(x) + else: + return ZeroMatrix(*self.shape) + + @classmethod + def _check_dim(cls, dim): + """Helper function to check invalid matrix dimensions""" + ok = not dim.is_Float and check_assumptions( + dim, integer=True, nonnegative=True) + if ok is False: + raise ValueError( + "The dimension specification {} should be " + "a nonnegative integer.".format(dim)) + + + def _entry(self, i, j, **kwargs): + raise NotImplementedError( + "Indexing not implemented for %s" % self.__class__.__name__) + + def adjoint(self): + return adjoint(self) + + def as_coeff_Mul(self, rational=False): + """Efficiently extract the coefficient of a product.""" + return S.One, self + + def conjugate(self): + return conjugate(self) + + def transpose(self): + from sympy.matrices.expressions.transpose import transpose + return transpose(self) + + @property + def T(self): + '''Matrix transposition''' + return self.transpose() + + def inverse(self): + if self.is_square is False: + raise NonSquareMatrixError('Inverse of non-square matrix') + return self._eval_inverse() + + def inv(self): + return self.inverse() + + def det(self): + from sympy.matrices.expressions.determinant import det + return det(self) + + @property + def I(self): + return self.inverse() + + def valid_index(self, i, j): + def is_valid(idx): + return isinstance(idx, (int, Integer, Symbol, Expr)) + return (is_valid(i) and is_valid(j) and + (self.rows is None or + (i >= -self.rows) != False and (i < self.rows) != False) and + (j >= -self.cols) != False and (j < self.cols) != False) + + def __getitem__(self, key): + if not isinstance(key, tuple) and isinstance(key, slice): + from sympy.matrices.expressions.slice import MatrixSlice + return MatrixSlice(self, key, (0, None, 1)) + if isinstance(key, tuple) and len(key) == 2: + i, j = key + if isinstance(i, slice) or isinstance(j, slice): + from sympy.matrices.expressions.slice import MatrixSlice + return MatrixSlice(self, i, j) + i, j = _sympify(i), _sympify(j) + if self.valid_index(i, j) != False: + return self._entry(i, j) + else: + raise IndexError("Invalid indices (%s, %s)" % (i, j)) + elif isinstance(key, (SYMPY_INTS, Integer)): + # row-wise decomposition of matrix + rows, cols = self.shape + # allow single indexing if number of columns is known + if not isinstance(cols, Integer): + raise IndexError(filldedent(''' + Single indexing is only supported when the number + of columns is known.''')) + key = _sympify(key) + i = key // cols + j = key % cols + if self.valid_index(i, j) != False: + return self._entry(i, j) + else: + raise IndexError("Invalid index %s" % key) + elif isinstance(key, (Symbol, Expr)): + raise IndexError(filldedent(''' + Only integers may be used when addressing the matrix + with a single index.''')) + raise IndexError("Invalid index, wanted %s[i,j]" % self) + + def _is_shape_symbolic(self) -> bool: + return (not isinstance(self.rows, (SYMPY_INTS, Integer)) + or not isinstance(self.cols, (SYMPY_INTS, Integer))) + + def as_explicit(self): + """ + Returns a dense Matrix with elements represented explicitly + + Returns an object of type ImmutableDenseMatrix. + + Examples + ======== + + >>> from sympy import Identity + >>> I = Identity(3) + >>> I + I + >>> I.as_explicit() + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + + See Also + ======== + as_mutable: returns mutable Matrix type + + """ + if self._is_shape_symbolic(): + raise ValueError( + 'Matrix with symbolic shape ' + 'cannot be represented explicitly.') + from sympy.matrices.immutable import ImmutableDenseMatrix + return ImmutableDenseMatrix([[self[i, j] + for j in range(self.cols)] + for i in range(self.rows)]) + + def as_mutable(self): + """ + Returns a dense, mutable matrix with elements represented explicitly + + Examples + ======== + + >>> from sympy import Identity + >>> I = Identity(3) + >>> I + I + >>> I.shape + (3, 3) + >>> I.as_mutable() + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + + See Also + ======== + as_explicit: returns ImmutableDenseMatrix + """ + return self.as_explicit().as_mutable() + + def __array__(self, dtype=object, copy=None): + if copy is not None and not copy: + raise TypeError("Cannot implement copy=False when converting Matrix to ndarray") + from numpy import empty + a = empty(self.shape, dtype=object) + for i in range(self.rows): + for j in range(self.cols): + a[i, j] = self[i, j] + return a + + def equals(self, other): + """ + Test elementwise equality between matrices, potentially of different + types + + >>> from sympy import Identity, eye + >>> Identity(3).equals(eye(3)) + True + """ + return self.as_explicit().equals(other) + + def canonicalize(self): + return self + + def as_coeff_mmul(self): + return S.One, MatMul(self) + + @staticmethod + def from_index_summation(expr, first_index=None, last_index=None, dimensions=None): + r""" + Parse expression of matrices with explicitly summed indices into a + matrix expression without indices, if possible. + + This transformation expressed in mathematical notation: + + `\sum_{j=0}^{N-1} A_{i,j} B_{j,k} \Longrightarrow \mathbf{A}\cdot \mathbf{B}` + + Optional parameter ``first_index``: specify which free index to use as + the index starting the expression. + + Examples + ======== + + >>> from sympy import MatrixSymbol, MatrixExpr, Sum + >>> from sympy.abc import i, j, k, l, N + >>> A = MatrixSymbol("A", N, N) + >>> B = MatrixSymbol("B", N, N) + >>> expr = Sum(A[i, j]*B[j, k], (j, 0, N-1)) + >>> MatrixExpr.from_index_summation(expr) + A*B + + Transposition is detected: + + >>> expr = Sum(A[j, i]*B[j, k], (j, 0, N-1)) + >>> MatrixExpr.from_index_summation(expr) + A.T*B + + Detect the trace: + + >>> expr = Sum(A[i, i], (i, 0, N-1)) + >>> MatrixExpr.from_index_summation(expr) + Trace(A) + + More complicated expressions: + + >>> expr = Sum(A[i, j]*B[k, j]*A[l, k], (j, 0, N-1), (k, 0, N-1)) + >>> MatrixExpr.from_index_summation(expr) + A*B.T*A.T + """ + from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array + from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix + first_indices = [] + if first_index is not None: + first_indices.append(first_index) + if last_index is not None: + first_indices.append(last_index) + arr = convert_indexed_to_array(expr, first_indices=first_indices) + return convert_array_to_matrix(arr) + + def applyfunc(self, func): + from .applyfunc import ElementwiseApplyFunction + return ElementwiseApplyFunction(func, self) + + +@dispatch(MatrixExpr, Expr) +def _eval_is_eq(lhs, rhs): # noqa:F811 + return False + +@dispatch(MatrixExpr, MatrixExpr) # type: ignore +def _eval_is_eq(lhs, rhs): # noqa:F811 + if lhs.shape != rhs.shape: + return False + if (lhs - rhs).is_ZeroMatrix: + return True + +def get_postprocessor(cls): + def _postprocessor(expr): + # To avoid circular imports, we can't have MatMul/MatAdd on the top level + mat_class = {Mul: MatMul, Add: MatAdd}[cls] + nonmatrices = [] + matrices = [] + for term in expr.args: + if isinstance(term, MatrixExpr): + matrices.append(term) + else: + nonmatrices.append(term) + + if not matrices: + return cls._from_args(nonmatrices) + + if nonmatrices: + if cls == Mul: + for i in range(len(matrices)): + if not matrices[i].is_MatrixExpr: + # If one of the matrices explicit, absorb the scalar into it + # (doit will combine all explicit matrices into one, so it + # doesn't matter which) + matrices[i] = matrices[i].__mul__(cls._from_args(nonmatrices)) + nonmatrices = [] + break + + else: + # Maintain the ability to create Add(scalar, matrix) without + # raising an exception. That way different algorithms can + # replace matrix expressions with non-commutative symbols to + # manipulate them like non-commutative scalars. + return cls._from_args(nonmatrices + [mat_class(*matrices).doit(deep=False)]) + + if mat_class == MatAdd: + return mat_class(*matrices).doit(deep=False) + return mat_class(cls._from_args(nonmatrices), *matrices).doit(deep=False) + return _postprocessor + + +Basic._constructor_postprocessor_mapping[MatrixExpr] = { + "Mul": [get_postprocessor(Mul)], + "Add": [get_postprocessor(Add)], +} + + +def _matrix_derivative(expr, x, old_algorithm=False): + + if isinstance(expr, MatrixBase) or isinstance(x, MatrixBase): + # Do not use array expressions for explicit matrices: + old_algorithm = True + + if old_algorithm: + return _matrix_derivative_old_algorithm(expr, x) + + from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + from sympy.tensor.array.expressions.arrayexpr_derivatives import array_derive + from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix + + array_expr = convert_matrix_to_array(expr) + diff_array_expr = array_derive(array_expr, x) + diff_matrix_expr = convert_array_to_matrix(diff_array_expr) + return diff_matrix_expr + + +def _matrix_derivative_old_algorithm(expr, x): + from sympy.tensor.array.array_derivatives import ArrayDerivative + lines = expr._eval_derivative_matrix_lines(x) + + parts = [i.build() for i in lines] + + from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix + + parts = [[convert_array_to_matrix(j) for j in i] for i in parts] + + def _get_shape(elem): + if isinstance(elem, MatrixExpr): + return elem.shape + return 1, 1 + + def get_rank(parts): + return sum(j not in (1, None) for i in parts for j in _get_shape(i)) + + ranks = [get_rank(i) for i in parts] + rank = ranks[0] + + def contract_one_dims(parts): + if len(parts) == 1: + return parts[0] + else: + p1, p2 = parts[:2] + if p2.is_Matrix: + p2 = p2.T + if p1 == Identity(1): + pbase = p2 + elif p2 == Identity(1): + pbase = p1 + else: + pbase = p1*p2 + if len(parts) == 2: + return pbase + else: # len(parts) > 2 + if pbase.is_Matrix: + raise ValueError("") + return pbase*Mul.fromiter(parts[2:]) + + if rank <= 2: + return Add.fromiter([contract_one_dims(i) for i in parts]) + + return ArrayDerivative(expr, x) + + +class MatrixElement(Expr): + parent = property(lambda self: self.args[0]) + i = property(lambda self: self.args[1]) + j = property(lambda self: self.args[2]) + _diff_wrt = True + is_symbol = True + is_commutative = True + + def __new__(cls, name, n, m): + n, m = map(_sympify, (n, m)) + if isinstance(name, str): + name = Symbol(name) + else: + if isinstance(name, MatrixBase): + if n.is_Integer and m.is_Integer: + return name[n, m] + name = _sympify(name) # change mutable into immutable + else: + name = _sympify(name) + if not isinstance(name.kind, MatrixKind): + raise TypeError("First argument of MatrixElement should be a matrix") + if not getattr(name, 'valid_index', lambda n, m: True)(n, m): + raise IndexError('indices out of range') + obj = Expr.__new__(cls, name, n, m) + return obj + + @property + def symbol(self): + return self.args[0] + + def doit(self, **hints): + deep = hints.get('deep', True) + if deep: + args = [arg.doit(**hints) for arg in self.args] + else: + args = self.args + return args[0][args[1], args[2]] + + @property + def indices(self): + return self.args[1:] + + def _eval_derivative(self, v): + + if not isinstance(v, MatrixElement): + return self.parent.diff(v)[self.i, self.j] + + M = self.args[0] + + m, n = self.parent.shape + + if M == v.args[0]: + return KroneckerDelta(self.args[1], v.args[1], (0, m-1)) * \ + KroneckerDelta(self.args[2], v.args[2], (0, n-1)) + + if isinstance(M, Inverse): + from sympy.concrete.summations import Sum + i, j = self.args[1:] + i1, i2 = symbols("z1, z2", cls=Dummy) + Y = M.args[0] + r1, r2 = Y.shape + return -Sum(M[i, i1]*Y[i1, i2].diff(v)*M[i2, j], (i1, 0, r1-1), (i2, 0, r2-1)) + + if self.has(v.args[0]): + return None + + return S.Zero + + +class MatrixSymbol(MatrixExpr): + """Symbolic representation of a Matrix object + + Creates a SymPy Symbol to represent a Matrix. This matrix has a shape and + can be included in Matrix Expressions + + Examples + ======== + + >>> from sympy import MatrixSymbol, Identity + >>> A = MatrixSymbol('A', 3, 4) # A 3 by 4 Matrix + >>> B = MatrixSymbol('B', 4, 3) # A 4 by 3 Matrix + >>> A.shape + (3, 4) + >>> 2*A*B + Identity(3) + I + 2*A*B + """ + is_commutative = False + is_symbol = True + _diff_wrt = True + + def __new__(cls, name, n, m): + n, m = _sympify(n), _sympify(m) + + cls._check_dim(m) + cls._check_dim(n) + + if isinstance(name, str): + name = Str(name) + obj = Basic.__new__(cls, name, n, m) + return obj + + @property + def shape(self): + return self.args[1], self.args[2] + + @property + def name(self): + return self.args[0].name + + def _entry(self, i, j, **kwargs): + return MatrixElement(self, i, j) + + @property + def free_symbols(self): + return {self} + + def _eval_simplify(self, **kwargs): + return self + + def _eval_derivative(self, x): + # x is a scalar: + return ZeroMatrix(self.shape[0], self.shape[1]) + + def _eval_derivative_matrix_lines(self, x): + if self != x: + first = ZeroMatrix(x.shape[0], self.shape[0]) if self.shape[0] != 1 else S.Zero + second = ZeroMatrix(x.shape[1], self.shape[1]) if self.shape[1] != 1 else S.Zero + return [_LeftRightArgs( + [first, second], + )] + else: + first = Identity(self.shape[0]) if self.shape[0] != 1 else S.One + second = Identity(self.shape[1]) if self.shape[1] != 1 else S.One + return [_LeftRightArgs( + [first, second], + )] + + +def matrix_symbols(expr): + return [sym for sym in expr.free_symbols if sym.is_Matrix] + + +class _LeftRightArgs: + r""" + Helper class to compute matrix derivatives. + + The logic: when an expression is derived by a matrix `X_{mn}`, two lines of + matrix multiplications are created: the one contracted to `m` (first line), + and the one contracted to `n` (second line). + + Transposition flips the side by which new matrices are connected to the + lines. + + The trace connects the end of the two lines. + """ + + def __init__(self, lines, higher=S.One): + self._lines = list(lines) + self._first_pointer_parent = self._lines + self._first_pointer_index = 0 + self._first_line_index = 0 + self._second_pointer_parent = self._lines + self._second_pointer_index = 1 + self._second_line_index = 1 + self.higher = higher + + @property + def first_pointer(self): + return self._first_pointer_parent[self._first_pointer_index] + + @first_pointer.setter + def first_pointer(self, value): + self._first_pointer_parent[self._first_pointer_index] = value + + @property + def second_pointer(self): + return self._second_pointer_parent[self._second_pointer_index] + + @second_pointer.setter + def second_pointer(self, value): + self._second_pointer_parent[self._second_pointer_index] = value + + def __repr__(self): + built = [self._build(i) for i in self._lines] + return "_LeftRightArgs(lines=%s, higher=%s)" % ( + built, + self.higher, + ) + + def transpose(self): + self._first_pointer_parent, self._second_pointer_parent = self._second_pointer_parent, self._first_pointer_parent + self._first_pointer_index, self._second_pointer_index = self._second_pointer_index, self._first_pointer_index + self._first_line_index, self._second_line_index = self._second_line_index, self._first_line_index + return self + + @staticmethod + def _build(expr): + if isinstance(expr, ExprBuilder): + return expr.build() + if isinstance(expr, list): + if len(expr) == 1: + return expr[0] + else: + return expr[0](*[_LeftRightArgs._build(i) for i in expr[1]]) + else: + return expr + + def build(self): + data = [self._build(i) for i in self._lines] + if self.higher != 1: + data += [self._build(self.higher)] + data = list(data) + return data + + def matrix_form(self): + if self.first != 1 and self.higher != 1: + raise ValueError("higher dimensional array cannot be represented") + + def _get_shape(elem): + if isinstance(elem, MatrixExpr): + return elem.shape + return (None, None) + + if _get_shape(self.first)[1] != _get_shape(self.second)[1]: + # Remove one-dimensional identity matrices: + # (this is needed by `a.diff(a)` where `a` is a vector) + if _get_shape(self.second) == (1, 1): + return self.first*self.second[0, 0] + if _get_shape(self.first) == (1, 1): + return self.first[1, 1]*self.second.T + raise ValueError("incompatible shapes") + if self.first != 1: + return self.first*self.second.T + else: + return self.higher + + def rank(self): + """ + Number of dimensions different from trivial (warning: not related to + matrix rank). + """ + rank = 0 + if self.first != 1: + rank += sum(i != 1 for i in self.first.shape) + if self.second != 1: + rank += sum(i != 1 for i in self.second.shape) + if self.higher != 1: + rank += 2 + return rank + + def _multiply_pointer(self, pointer, other): + from ...tensor.array.expressions.array_expressions import ArrayTensorProduct + from ...tensor.array.expressions.array_expressions import ArrayContraction + + subexpr = ExprBuilder( + ArrayContraction, + [ + ExprBuilder( + ArrayTensorProduct, + [ + pointer, + other + ] + ), + (1, 2) + ], + validator=ArrayContraction._validate + ) + + return subexpr + + def append_first(self, other): + self.first_pointer *= other + + def append_second(self, other): + self.second_pointer *= other + + +def _make_matrix(x): + from sympy.matrices.immutable import ImmutableDenseMatrix + if isinstance(x, MatrixExpr): + return x + return ImmutableDenseMatrix([[x]]) + + +from .matmul import MatMul +from .matadd import MatAdd +from .matpow import MatPow +from .transpose import Transpose +from .inverse import Inverse +from .special import ZeroMatrix, Identity +from .determinant import Determinant diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/matmul.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..1c46f7ff5251d89793423f92ea02d7243601de3f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/matmul.py @@ -0,0 +1,496 @@ +from sympy.assumptions.ask import ask, Q +from sympy.assumptions.refine import handlers_dict +from sympy.core import Basic, sympify, S +from sympy.core.mul import mul, Mul +from sympy.core.numbers import Number, Integer +from sympy.core.symbol import Dummy +from sympy.functions import adjoint +from sympy.strategies import (rm_id, unpack, typed, flatten, exhaust, + do_one, new) +from sympy.matrices.exceptions import NonInvertibleMatrixError +from sympy.matrices.matrixbase import MatrixBase +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.matrices.expressions._shape import validate_matmul_integer as validate + +from .inverse import Inverse +from .matexpr import MatrixExpr +from .matpow import MatPow +from .transpose import transpose +from .permutation import PermutationMatrix +from .special import ZeroMatrix, Identity, GenericIdentity, OneMatrix + + +# XXX: MatMul should perhaps not subclass directly from Mul +class MatMul(MatrixExpr, Mul): + """ + A product of matrix expressions + + Examples + ======== + + >>> from sympy import MatMul, MatrixSymbol + >>> A = MatrixSymbol('A', 5, 4) + >>> B = MatrixSymbol('B', 4, 3) + >>> C = MatrixSymbol('C', 3, 6) + >>> MatMul(A, B, C) + A*B*C + """ + is_MatMul = True + + identity = GenericIdentity() + + def __new__(cls, *args, evaluate=False, check=None, _sympify=True): + if not args: + return cls.identity + + # This must be removed aggressively in the constructor to avoid + # TypeErrors from GenericIdentity().shape + args = list(filter(lambda i: cls.identity != i, args)) + if _sympify: + args = list(map(sympify, args)) + obj = Basic.__new__(cls, *args) + factor, matrices = obj.as_coeff_matrices() + + if check is not None: + sympy_deprecation_warning( + "Passing check to MatMul is deprecated and the check argument will be removed in a future version.", + deprecated_since_version="1.11", + active_deprecations_target='remove-check-argument-from-matrix-operations') + + if check is not False: + validate(*matrices) + + if not matrices: + # Should it be + # + # return Basic.__neq__(cls, factor, GenericIdentity()) ? + return factor + + if evaluate: + return cls._evaluate(obj) + + return obj + + @classmethod + def _evaluate(cls, expr): + return canonicalize(expr) + + @property + def shape(self): + matrices = [arg for arg in self.args if arg.is_Matrix] + return (matrices[0].rows, matrices[-1].cols) + + def _entry(self, i, j, expand=True, **kwargs): + # Avoid cyclic imports + from sympy.concrete.summations import Sum + from sympy.matrices.immutable import ImmutableMatrix + + coeff, matrices = self.as_coeff_matrices() + + if len(matrices) == 1: # situation like 2*X, matmul is just X + return coeff * matrices[0][i, j] + + indices = [None]*(len(matrices) + 1) + ind_ranges = [None]*(len(matrices) - 1) + indices[0] = i + indices[-1] = j + + def f(): + counter = 1 + while True: + yield Dummy("i_%i" % counter) + counter += 1 + + dummy_generator = kwargs.get("dummy_generator", f()) + + for i in range(1, len(matrices)): + indices[i] = next(dummy_generator) + + for i, arg in enumerate(matrices[:-1]): + ind_ranges[i] = arg.shape[1] - 1 + matrices = [arg._entry(indices[i], indices[i+1], dummy_generator=dummy_generator) for i, arg in enumerate(matrices)] + expr_in_sum = Mul.fromiter(matrices) + if any(v.has(ImmutableMatrix) for v in matrices): + expand = True + result = coeff*Sum( + expr_in_sum, + *zip(indices[1:-1], [0]*len(ind_ranges), ind_ranges) + ) + + # Don't waste time in result.doit() if the sum bounds are symbolic + if not any(isinstance(v, (Integer, int)) for v in ind_ranges): + expand = False + return result.doit() if expand else result + + def as_coeff_matrices(self): + scalars = [x for x in self.args if not x.is_Matrix] + matrices = [x for x in self.args if x.is_Matrix] + coeff = Mul(*scalars) + if coeff.is_commutative is False: + raise NotImplementedError("noncommutative scalars in MatMul are not supported.") + + return coeff, matrices + + def as_coeff_mmul(self): + coeff, matrices = self.as_coeff_matrices() + return coeff, MatMul(*matrices) + + def expand(self, **kwargs): + expanded = super(MatMul, self).expand(**kwargs) + return self._evaluate(expanded) + + def _eval_transpose(self): + """Transposition of matrix multiplication. + + Notes + ===== + + The following rules are applied. + + Transposition for matrix multiplied with another matrix: + `\\left(A B\\right)^{T} = B^{T} A^{T}` + + Transposition for matrix multiplied with scalar: + `\\left(c A\\right)^{T} = c A^{T}` + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Transpose + """ + coeff, matrices = self.as_coeff_matrices() + return MatMul( + coeff, *[transpose(arg) for arg in matrices[::-1]]).doit() + + def _eval_adjoint(self): + return MatMul(*[adjoint(arg) for arg in self.args[::-1]]).doit() + + def _eval_trace(self): + factor, mmul = self.as_coeff_mmul() + if factor != 1: + from .trace import trace + return factor * trace(mmul.doit()) + + def _eval_determinant(self): + from sympy.matrices.expressions.determinant import Determinant + factor, matrices = self.as_coeff_matrices() + square_matrices = only_squares(*matrices) + return factor**self.rows * Mul(*list(map(Determinant, square_matrices))) + + def _eval_inverse(self): + if all(arg.is_square for arg in self.args if isinstance(arg, MatrixExpr)): + return MatMul(*( + arg.inverse() if isinstance(arg, MatrixExpr) else arg**-1 + for arg in self.args[::-1] + ) + ).doit() + return Inverse(self) + + def doit(self, **hints): + deep = hints.get('deep', True) + if deep: + args = tuple(arg.doit(**hints) for arg in self.args) + else: + args = self.args + + # treat scalar*MatrixSymbol or scalar*MatPow separately + expr = canonicalize(MatMul(*args)) + return expr + + # Needed for partial compatibility with Mul + def args_cnc(self, cset=False, warn=True, **kwargs): + coeff_c = [x for x in self.args if x.is_commutative] + coeff_nc = [x for x in self.args if not x.is_commutative] + if cset: + clen = len(coeff_c) + coeff_c = set(coeff_c) + if clen and warn and len(coeff_c) != clen: + raise ValueError('repeated commutative arguments: %s' % + [ci for ci in coeff_c if list(self.args).count(ci) > 1]) + return [coeff_c, coeff_nc] + + def _eval_derivative_matrix_lines(self, x): + from .transpose import Transpose + with_x_ind = [i for i, arg in enumerate(self.args) if arg.has(x)] + lines = [] + for ind in with_x_ind: + left_args = self.args[:ind] + right_args = self.args[ind+1:] + + if right_args: + right_mat = MatMul.fromiter(right_args) + else: + right_mat = Identity(self.shape[1]) + if left_args: + left_rev = MatMul.fromiter([Transpose(i).doit() if i.is_Matrix else i for i in reversed(left_args)]) + else: + left_rev = Identity(self.shape[0]) + + d = self.args[ind]._eval_derivative_matrix_lines(x) + for i in d: + i.append_first(left_rev) + i.append_second(right_mat) + lines.append(i) + + return lines + +mul.register_handlerclass((Mul, MatMul), MatMul) + + +# Rules +def newmul(*args): + if args[0] == 1: + args = args[1:] + return new(MatMul, *args) + +def any_zeros(mul): + if any(arg.is_zero or (arg.is_Matrix and arg.is_ZeroMatrix) + for arg in mul.args): + matrices = [arg for arg in mul.args if arg.is_Matrix] + return ZeroMatrix(matrices[0].rows, matrices[-1].cols) + return mul + +def merge_explicit(matmul): + """ Merge explicit MatrixBase arguments + + >>> from sympy import MatrixSymbol, Matrix, MatMul, pprint + >>> from sympy.matrices.expressions.matmul import merge_explicit + >>> A = MatrixSymbol('A', 2, 2) + >>> B = Matrix([[1, 1], [1, 1]]) + >>> C = Matrix([[1, 2], [3, 4]]) + >>> X = MatMul(A, B, C) + >>> pprint(X) + [1 1] [1 2] + A*[ ]*[ ] + [1 1] [3 4] + >>> pprint(merge_explicit(X)) + [4 6] + A*[ ] + [4 6] + + >>> X = MatMul(B, A, C) + >>> pprint(X) + [1 1] [1 2] + [ ]*A*[ ] + [1 1] [3 4] + >>> pprint(merge_explicit(X)) + [1 1] [1 2] + [ ]*A*[ ] + [1 1] [3 4] + """ + if not any(isinstance(arg, MatrixBase) for arg in matmul.args): + return matmul + newargs = [] + last = matmul.args[0] + for arg in matmul.args[1:]: + if isinstance(arg, (MatrixBase, Number)) and isinstance(last, (MatrixBase, Number)): + last = last * arg + else: + newargs.append(last) + last = arg + newargs.append(last) + + return MatMul(*newargs) + +def remove_ids(mul): + """ Remove Identities from a MatMul + + This is a modified version of sympy.strategies.rm_id. + This is necessary because MatMul may contain both MatrixExprs and Exprs + as args. + + See Also + ======== + + sympy.strategies.rm_id + """ + # Separate Exprs from MatrixExprs in args + factor, mmul = mul.as_coeff_mmul() + # Apply standard rm_id for MatMuls + result = rm_id(lambda x: x.is_Identity is True)(mmul) + if result != mmul: + return newmul(factor, *result.args) # Recombine and return + else: + return mul + +def factor_in_front(mul): + factor, matrices = mul.as_coeff_matrices() + if factor != 1: + return newmul(factor, *matrices) + return mul + +def combine_powers(mul): + r"""Combine consecutive powers with the same base into one, e.g. + $$A \times A^2 \Rightarrow A^3$$ + + This also cancels out the possible matrix inverses using the + knowledgebase of :class:`~.Inverse`, e.g., + $$ Y \times X \times X^{-1} \Rightarrow Y $$ + """ + factor, args = mul.as_coeff_matrices() + new_args = [args[0]] + + for i in range(1, len(args)): + A = new_args[-1] + B = args[i] + + if isinstance(B, Inverse) and isinstance(B.arg, MatMul): + Bargs = B.arg.args + l = len(Bargs) + if list(Bargs) == new_args[-l:]: + new_args = new_args[:-l] + [Identity(B.shape[0])] + continue + + if isinstance(A, Inverse) and isinstance(A.arg, MatMul): + Aargs = A.arg.args + l = len(Aargs) + if list(Aargs) == args[i:i+l]: + identity = Identity(A.shape[0]) + new_args[-1] = identity + for j in range(i, i+l): + args[j] = identity + continue + + if A.is_square == False or B.is_square == False: + new_args.append(B) + continue + + if isinstance(A, MatPow): + A_base, A_exp = A.args + else: + A_base, A_exp = A, S.One + + if isinstance(B, MatPow): + B_base, B_exp = B.args + else: + B_base, B_exp = B, S.One + + if A_base == B_base: + new_exp = A_exp + B_exp + new_args[-1] = MatPow(A_base, new_exp).doit(deep=False) + continue + elif not isinstance(B_base, MatrixBase): + try: + B_base_inv = B_base.inverse() + except NonInvertibleMatrixError: + B_base_inv = None + if B_base_inv is not None and A_base == B_base_inv: + new_exp = A_exp - B_exp + new_args[-1] = MatPow(A_base, new_exp).doit(deep=False) + continue + new_args.append(B) + + return newmul(factor, *new_args) + +def combine_permutations(mul): + """Refine products of permutation matrices as the products of cycles. + """ + args = mul.args + l = len(args) + if l < 2: + return mul + + result = [args[0]] + for i in range(1, l): + A = result[-1] + B = args[i] + if isinstance(A, PermutationMatrix) and \ + isinstance(B, PermutationMatrix): + cycle_1 = A.args[0] + cycle_2 = B.args[0] + result[-1] = PermutationMatrix(cycle_1 * cycle_2) + else: + result.append(B) + + return MatMul(*result) + +def combine_one_matrices(mul): + """ + Combine products of OneMatrix + + e.g. OneMatrix(2, 3) * OneMatrix(3, 4) -> 3 * OneMatrix(2, 4) + """ + factor, args = mul.as_coeff_matrices() + new_args = [args[0]] + + for B in args[1:]: + A = new_args[-1] + if not isinstance(A, OneMatrix) or not isinstance(B, OneMatrix): + new_args.append(B) + continue + new_args.pop() + new_args.append(OneMatrix(A.shape[0], B.shape[1])) + factor *= A.shape[1] + + return newmul(factor, *new_args) + +def distribute_monom(mul): + """ + Simplify MatMul expressions but distributing + rational term to MatMul. + + e.g. 2*(A+B) -> 2*A + 2*B + """ + args = mul.args + if len(args) == 2: + from .matadd import MatAdd + if args[0].is_MatAdd and args[1].is_Rational: + return MatAdd(*[MatMul(mat, args[1]).doit() for mat in args[0].args]) + if args[1].is_MatAdd and args[0].is_Rational: + return MatAdd(*[MatMul(args[0], mat).doit() for mat in args[1].args]) + return mul + +rules = ( + distribute_monom, any_zeros, remove_ids, combine_one_matrices, combine_powers, unpack, rm_id(lambda x: x == 1), + merge_explicit, factor_in_front, flatten, combine_permutations) + +canonicalize = exhaust(typed({MatMul: do_one(*rules)})) + +def only_squares(*matrices): + """factor matrices only if they are square""" + if matrices[0].rows != matrices[-1].cols: + raise RuntimeError("Invalid matrices being multiplied") + out = [] + start = 0 + for i, M in enumerate(matrices): + if M.cols == matrices[start].rows: + out.append(MatMul(*matrices[start:i+1]).doit()) + start = i+1 + return out + + +def refine_MatMul(expr, assumptions): + """ + >>> from sympy import MatrixSymbol, Q, assuming, refine + >>> X = MatrixSymbol('X', 2, 2) + >>> expr = X * X.T + >>> print(expr) + X*X.T + >>> with assuming(Q.orthogonal(X)): + ... print(refine(expr)) + I + """ + newargs = [] + exprargs = [] + + for args in expr.args: + if args.is_Matrix: + exprargs.append(args) + else: + newargs.append(args) + + last = exprargs[0] + for arg in exprargs[1:]: + if arg == last.T and ask(Q.orthogonal(arg), assumptions): + last = Identity(arg.shape[0]) + elif arg == last.conjugate() and ask(Q.unitary(arg), assumptions): + last = Identity(arg.shape[0]) + else: + newargs.append(last) + last = arg + newargs.append(last) + + return MatMul(*newargs) + + +handlers_dict['MatMul'] = refine_MatMul diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/matpow.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/matpow.py new file mode 100644 index 0000000000000000000000000000000000000000..b6472995e134e9e5ebfd28a901480665d1531275 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/matpow.py @@ -0,0 +1,150 @@ +from .matexpr import MatrixExpr +from .special import Identity +from sympy.core import S +from sympy.core.expr import ExprBuilder +from sympy.core.cache import cacheit +from sympy.core.power import Pow +from sympy.core.sympify import _sympify +from sympy.matrices import MatrixBase +from sympy.matrices.exceptions import NonSquareMatrixError + + +class MatPow(MatrixExpr): + def __new__(cls, base, exp, evaluate=False, **options): + base = _sympify(base) + if not base.is_Matrix: + raise TypeError("MatPow base should be a matrix") + + if base.is_square is False: + raise NonSquareMatrixError("Power of non-square matrix %s" % base) + + exp = _sympify(exp) + obj = super().__new__(cls, base, exp) + + if evaluate: + obj = obj.doit(deep=False) + + return obj + + @property + def base(self): + return self.args[0] + + @property + def exp(self): + return self.args[1] + + @property + def shape(self): + return self.base.shape + + @cacheit + def _get_explicit_matrix(self): + return self.base.as_explicit()**self.exp + + def _entry(self, i, j, **kwargs): + from sympy.matrices.expressions import MatMul + A = self.doit() + if isinstance(A, MatPow): + # We still have a MatPow, make an explicit MatMul out of it. + if A.exp.is_Integer and A.exp.is_positive: + A = MatMul(*[A.base for k in range(A.exp)]) + elif not self._is_shape_symbolic(): + return A._get_explicit_matrix()[i, j] + else: + # Leave the expression unevaluated: + from sympy.matrices.expressions.matexpr import MatrixElement + return MatrixElement(self, i, j) + return A[i, j] + + def doit(self, **hints): + if hints.get('deep', True): + base, exp = (arg.doit(**hints) for arg in self.args) + else: + base, exp = self.args + + # combine all powers, e.g. (A ** 2) ** 3 -> A ** 6 + while isinstance(base, MatPow): + exp *= base.args[1] + base = base.args[0] + + if isinstance(base, MatrixBase): + # Delegate + return base ** exp + + # Handle simple cases so that _eval_power() in MatrixExpr sub-classes can ignore them + if exp == S.One: + return base + if exp == S.Zero: + return Identity(base.rows) + if exp == S.NegativeOne: + from sympy.matrices.expressions import Inverse + return Inverse(base).doit(**hints) + + eval_power = getattr(base, '_eval_power', None) + if eval_power is not None: + return eval_power(exp) + + return MatPow(base, exp) + + def _eval_transpose(self): + base, exp = self.args + return MatPow(base.transpose(), exp) + + def _eval_adjoint(self): + base, exp = self.args + return MatPow(base.adjoint(), exp) + + def _eval_conjugate(self): + base, exp = self.args + return MatPow(base.conjugate(), exp) + + def _eval_derivative(self, x): + return Pow._eval_derivative(self, x) + + def _eval_derivative_matrix_lines(self, x): + from sympy.tensor.array.expressions.array_expressions import ArrayContraction + from ...tensor.array.expressions.array_expressions import ArrayTensorProduct + from .matmul import MatMul + from .inverse import Inverse + exp = self.exp + if self.base.shape == (1, 1) and not exp.has(x): + lr = self.base._eval_derivative_matrix_lines(x) + for i in lr: + subexpr = ExprBuilder( + ArrayContraction, + [ + ExprBuilder( + ArrayTensorProduct, + [ + Identity(1), + i._lines[0], + exp*self.base**(exp-1), + i._lines[1], + Identity(1), + ] + ), + (0, 3, 4), (5, 7, 8) + ], + validator=ArrayContraction._validate + ) + i._first_pointer_parent = subexpr.args[0].args + i._first_pointer_index = 0 + i._second_pointer_parent = subexpr.args[0].args + i._second_pointer_index = 4 + i._lines = [subexpr] + return lr + if (exp > 0) == True: + newexpr = MatMul.fromiter([self.base for i in range(exp)]) + elif (exp == -1) == True: + return Inverse(self.base)._eval_derivative_matrix_lines(x) + elif (exp < 0) == True: + newexpr = MatMul.fromiter([Inverse(self.base) for i in range(-exp)]) + elif (exp == 0) == True: + return self.doit()._eval_derivative_matrix_lines(x) + else: + raise NotImplementedError("cannot evaluate %s derived by %s" % (self, x)) + return newexpr._eval_derivative_matrix_lines(x) + + def _eval_inverse(self): + return MatPow(self.base, -self.exp) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/permutation.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/permutation.py new file mode 100644 index 0000000000000000000000000000000000000000..5634fa941a53d8890583fe61bb29bc34f4e6000d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/permutation.py @@ -0,0 +1,303 @@ +from sympy.core import S +from sympy.core.sympify import _sympify +from sympy.functions import KroneckerDelta + +from .matexpr import MatrixExpr +from .special import ZeroMatrix, Identity, OneMatrix + + +class PermutationMatrix(MatrixExpr): + """A Permutation Matrix + + Parameters + ========== + + perm : Permutation + The permutation the matrix uses. + + The size of the permutation determines the matrix size. + + See the documentation of + :class:`sympy.combinatorics.permutations.Permutation` for + the further information of how to create a permutation object. + + Examples + ======== + + >>> from sympy import Matrix, PermutationMatrix + >>> from sympy.combinatorics import Permutation + + Creating a permutation matrix: + + >>> p = Permutation(1, 2, 0) + >>> P = PermutationMatrix(p) + >>> P = P.as_explicit() + >>> P + Matrix([ + [0, 1, 0], + [0, 0, 1], + [1, 0, 0]]) + + Permuting a matrix row and column: + + >>> M = Matrix([0, 1, 2]) + >>> Matrix(P*M) + Matrix([ + [1], + [2], + [0]]) + + >>> Matrix(M.T*P) + Matrix([[2, 0, 1]]) + + See Also + ======== + + sympy.combinatorics.permutations.Permutation + """ + + def __new__(cls, perm): + from sympy.combinatorics.permutations import Permutation + + perm = _sympify(perm) + if not isinstance(perm, Permutation): + raise ValueError( + "{} must be a SymPy Permutation instance.".format(perm)) + + return super().__new__(cls, perm) + + @property + def shape(self): + size = self.args[0].size + return (size, size) + + @property + def is_Identity(self): + return self.args[0].is_Identity + + def doit(self, **hints): + if self.is_Identity: + return Identity(self.rows) + return self + + def _entry(self, i, j, **kwargs): + perm = self.args[0] + return KroneckerDelta(perm.apply(i), j) + + def _eval_power(self, exp): + return PermutationMatrix(self.args[0] ** exp).doit() + + def _eval_inverse(self): + return PermutationMatrix(self.args[0] ** -1) + + _eval_transpose = _eval_adjoint = _eval_inverse + + def _eval_determinant(self): + sign = self.args[0].signature() + if sign == 1: + return S.One + elif sign == -1: + return S.NegativeOne + raise NotImplementedError + + def _eval_rewrite_as_BlockDiagMatrix(self, *args, **kwargs): + from sympy.combinatorics.permutations import Permutation + from .blockmatrix import BlockDiagMatrix + + perm = self.args[0] + full_cyclic_form = perm.full_cyclic_form + + cycles_picks = [] + + # Stage 1. Decompose the cycles into the blockable form. + a, b, c = 0, 0, 0 + flag = False + for cycle in full_cyclic_form: + l = len(cycle) + m = max(cycle) + + if not flag: + if m + 1 > a + l: + flag = True + temp = [cycle] + b = m + c = l + else: + cycles_picks.append([cycle]) + a += l + + else: + if m > b: + if m + 1 == a + c + l: + temp.append(cycle) + cycles_picks.append(temp) + flag = False + a = m+1 + else: + b = m + temp.append(cycle) + c += l + else: + if b + 1 == a + c + l: + temp.append(cycle) + cycles_picks.append(temp) + flag = False + a = b+1 + else: + temp.append(cycle) + c += l + + # Stage 2. Normalize each decomposed cycles and build matrix. + p = 0 + args = [] + for pick in cycles_picks: + new_cycles = [] + l = 0 + for cycle in pick: + new_cycle = [i - p for i in cycle] + new_cycles.append(new_cycle) + l += len(cycle) + p += l + perm = Permutation(new_cycles) + mat = PermutationMatrix(perm) + args.append(mat) + + return BlockDiagMatrix(*args) + + +class MatrixPermute(MatrixExpr): + r"""Symbolic representation for permuting matrix rows or columns. + + Parameters + ========== + + perm : Permutation, PermutationMatrix + The permutation to use for permuting the matrix. + The permutation can be resized to the suitable one, + + axis : 0 or 1 + The axis to permute alongside. + If `0`, it will permute the matrix rows. + If `1`, it will permute the matrix columns. + + Notes + ===== + + This follows the same notation used in + :meth:`sympy.matrices.matrixbase.MatrixBase.permute`. + + Examples + ======== + + >>> from sympy import Matrix, MatrixPermute + >>> from sympy.combinatorics import Permutation + + Permuting the matrix rows: + + >>> p = Permutation(1, 2, 0) + >>> A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> B = MatrixPermute(A, p, axis=0) + >>> B.as_explicit() + Matrix([ + [4, 5, 6], + [7, 8, 9], + [1, 2, 3]]) + + Permuting the matrix columns: + + >>> B = MatrixPermute(A, p, axis=1) + >>> B.as_explicit() + Matrix([ + [2, 3, 1], + [5, 6, 4], + [8, 9, 7]]) + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.permute + """ + def __new__(cls, mat, perm, axis=S.Zero): + from sympy.combinatorics.permutations import Permutation + + mat = _sympify(mat) + if not mat.is_Matrix: + raise ValueError( + "{} must be a SymPy matrix instance.".format(perm)) + + perm = _sympify(perm) + if isinstance(perm, PermutationMatrix): + perm = perm.args[0] + + if not isinstance(perm, Permutation): + raise ValueError( + "{} must be a SymPy Permutation or a PermutationMatrix " \ + "instance".format(perm)) + + axis = _sympify(axis) + if axis not in (0, 1): + raise ValueError("The axis must be 0 or 1.") + + mat_size = mat.shape[axis] + if mat_size != perm.size: + try: + perm = perm.resize(mat_size) + except ValueError: + raise ValueError( + "Size does not match between the permutation {} " + "and the matrix {} threaded over the axis {} " + "and cannot be converted." + .format(perm, mat, axis)) + + return super().__new__(cls, mat, perm, axis) + + def doit(self, deep=True, **hints): + mat, perm, axis = self.args + + if deep: + mat = mat.doit(deep=deep, **hints) + perm = perm.doit(deep=deep, **hints) + + if perm.is_Identity: + return mat + + if mat.is_Identity: + if axis is S.Zero: + return PermutationMatrix(perm) + elif axis is S.One: + return PermutationMatrix(perm**-1) + + if isinstance(mat, (ZeroMatrix, OneMatrix)): + return mat + + if isinstance(mat, MatrixPermute) and mat.args[2] == axis: + return MatrixPermute(mat.args[0], perm * mat.args[1], axis) + + return self + + @property + def shape(self): + return self.args[0].shape + + def _entry(self, i, j, **kwargs): + mat, perm, axis = self.args + + if axis == 0: + return mat[perm.apply(i), j] + elif axis == 1: + return mat[i, perm.apply(j)] + + def _eval_rewrite_as_MatMul(self, *args, **kwargs): + from .matmul import MatMul + + mat, perm, axis = self.args + + deep = kwargs.get("deep", True) + + if deep: + mat = mat.rewrite(MatMul) + + if axis == 0: + return MatMul(PermutationMatrix(perm), mat) + elif axis == 1: + return MatMul(mat, PermutationMatrix(perm**-1)) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/sets.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/sets.py new file mode 100644 index 0000000000000000000000000000000000000000..ab4930ea8f1b058977a8dd1abdc62f1f5e2195c1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/sets.py @@ -0,0 +1,68 @@ +from sympy.core.assumptions import check_assumptions +from sympy.core.logic import fuzzy_and +from sympy.core.sympify import _sympify +from sympy.matrices.kind import MatrixKind +from sympy.sets.sets import Set, SetKind +from sympy.core.kind import NumberKind +from .matexpr import MatrixExpr + + +class MatrixSet(Set): + """ + MatrixSet represents the set of matrices with ``shape = (n, m)`` over the + given set. + + Examples + ======== + + >>> from sympy.matrices import MatrixSet + >>> from sympy import S, I, Matrix + >>> M = MatrixSet(2, 2, set=S.Reals) + >>> X = Matrix([[1, 2], [3, 4]]) + >>> X in M + True + >>> X = Matrix([[1, 2], [I, 4]]) + >>> X in M + False + + """ + is_empty = False + + def __new__(cls, n, m, set): + n, m, set = _sympify(n), _sympify(m), _sympify(set) + cls._check_dim(n) + cls._check_dim(m) + if not isinstance(set, Set): + raise TypeError("{} should be an instance of Set.".format(set)) + return Set.__new__(cls, n, m, set) + + @property + def shape(self): + return self.args[:2] + + @property + def set(self): + return self.args[2] + + def _contains(self, other): + if not isinstance(other, MatrixExpr): + raise TypeError("{} should be an instance of MatrixExpr.".format(other)) + if other.shape != self.shape: + are_symbolic = any(_sympify(x).is_Symbol for x in other.shape + self.shape) + if are_symbolic: + return None + return False + return fuzzy_and(self.set.contains(x) for x in other) + + @classmethod + def _check_dim(cls, dim): + """Helper function to check invalid matrix dimensions""" + ok = not dim.is_Float and check_assumptions( + dim, integer=True, nonnegative=True) + if ok is False: + raise ValueError( + "The dimension specification {} should be " + "a nonnegative integer.".format(dim)) + + def _kind(self): + return SetKind(MatrixKind(NumberKind)) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/slice.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/slice.py new file mode 100644 index 0000000000000000000000000000000000000000..1904b49f29c503fb4c0c909532f8342fb0f4b135 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/slice.py @@ -0,0 +1,114 @@ +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.functions.elementary.integers import floor + +def normalize(i, parentsize): + if isinstance(i, slice): + i = (i.start, i.stop, i.step) + if not isinstance(i, (tuple, list, Tuple)): + if (i < 0) == True: + i += parentsize + i = (i, i+1, 1) + i = list(i) + if len(i) == 2: + i.append(1) + start, stop, step = i + start = start or 0 + if stop is None: + stop = parentsize + if (start < 0) == True: + start += parentsize + if (stop < 0) == True: + stop += parentsize + step = step or 1 + + if ((stop - start) * step < 1) == True: + raise IndexError() + + return (start, stop, step) + +class MatrixSlice(MatrixExpr): + """ A MatrixSlice of a Matrix Expression + + Examples + ======== + + >>> from sympy import MatrixSlice, ImmutableMatrix + >>> M = ImmutableMatrix(4, 4, range(16)) + >>> M + Matrix([ + [ 0, 1, 2, 3], + [ 4, 5, 6, 7], + [ 8, 9, 10, 11], + [12, 13, 14, 15]]) + + >>> B = MatrixSlice(M, (0, 2), (2, 4)) + >>> ImmutableMatrix(B) + Matrix([ + [2, 3], + [6, 7]]) + """ + parent = property(lambda self: self.args[0]) + rowslice = property(lambda self: self.args[1]) + colslice = property(lambda self: self.args[2]) + + def __new__(cls, parent, rowslice, colslice): + rowslice = normalize(rowslice, parent.shape[0]) + colslice = normalize(colslice, parent.shape[1]) + if not (len(rowslice) == len(colslice) == 3): + raise IndexError() + if ((0 > rowslice[0]) == True or + (parent.shape[0] < rowslice[1]) == True or + (0 > colslice[0]) == True or + (parent.shape[1] < colslice[1]) == True): + raise IndexError() + if isinstance(parent, MatrixSlice): + return mat_slice_of_slice(parent, rowslice, colslice) + return Basic.__new__(cls, parent, Tuple(*rowslice), Tuple(*colslice)) + + @property + def shape(self): + rows = self.rowslice[1] - self.rowslice[0] + rows = rows if self.rowslice[2] == 1 else floor(rows/self.rowslice[2]) + cols = self.colslice[1] - self.colslice[0] + cols = cols if self.colslice[2] == 1 else floor(cols/self.colslice[2]) + return rows, cols + + def _entry(self, i, j, **kwargs): + return self.parent._entry(i*self.rowslice[2] + self.rowslice[0], + j*self.colslice[2] + self.colslice[0], + **kwargs) + + @property + def on_diag(self): + return self.rowslice == self.colslice + + +def slice_of_slice(s, t): + start1, stop1, step1 = s + start2, stop2, step2 = t + + start = start1 + start2*step1 + step = step1 * step2 + stop = start1 + step1*stop2 + + if stop > stop1: + raise IndexError() + + return start, stop, step + + +def mat_slice_of_slice(parent, rowslice, colslice): + """ Collapse nested matrix slices + + >>> from sympy import MatrixSymbol + >>> X = MatrixSymbol('X', 10, 10) + >>> X[:, 1:5][5:8, :] + X[5:8, 1:5] + >>> X[1:9:2, 2:6][1:3, 2] + X[3:7:2, 4:5] + """ + row = slice_of_slice(parent.rowslice, rowslice) + col = slice_of_slice(parent.colslice, colslice) + return MatrixSlice(parent.parent, row, col) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/special.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/special.py new file mode 100644 index 0000000000000000000000000000000000000000..d1e426f16ada0e4245b644867974b41b6f86b5cc --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/special.py @@ -0,0 +1,299 @@ +from sympy.assumptions.ask import ask, Q +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.sympify import _sympify +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.matrices.exceptions import NonInvertibleMatrixError +from .matexpr import MatrixExpr + + +class ZeroMatrix(MatrixExpr): + """The Matrix Zero 0 - additive identity + + Examples + ======== + + >>> from sympy import MatrixSymbol, ZeroMatrix + >>> A = MatrixSymbol('A', 3, 5) + >>> Z = ZeroMatrix(3, 5) + >>> A + Z + A + >>> Z*A.T + 0 + """ + is_ZeroMatrix = True + + def __new__(cls, m, n): + m, n = _sympify(m), _sympify(n) + cls._check_dim(m) + cls._check_dim(n) + + return super().__new__(cls, m, n) + + @property + def shape(self): + return (self.args[0], self.args[1]) + + def _eval_power(self, exp): + # exp = -1, 0, 1 are already handled at this stage + if (exp < 0) == True: + raise NonInvertibleMatrixError("Matrix det == 0; not invertible") + return self + + def _eval_transpose(self): + return ZeroMatrix(self.cols, self.rows) + + def _eval_adjoint(self): + return ZeroMatrix(self.cols, self.rows) + + def _eval_trace(self): + return S.Zero + + def _eval_determinant(self): + return S.Zero + + def _eval_inverse(self): + raise NonInvertibleMatrixError("Matrix det == 0; not invertible.") + + def _eval_as_real_imag(self): + return (self, self) + + def _eval_conjugate(self): + return self + + def _entry(self, i, j, **kwargs): + return S.Zero + + +class GenericZeroMatrix(ZeroMatrix): + """ + A zero matrix without a specified shape + + This exists primarily so MatAdd() with no arguments can return something + meaningful. + """ + def __new__(cls): + # super(ZeroMatrix, cls) instead of super(GenericZeroMatrix, cls) + # because ZeroMatrix.__new__ doesn't have the same signature + return super(ZeroMatrix, cls).__new__(cls) + + @property + def rows(self): + raise TypeError("GenericZeroMatrix does not have a specified shape") + + @property + def cols(self): + raise TypeError("GenericZeroMatrix does not have a specified shape") + + @property + def shape(self): + raise TypeError("GenericZeroMatrix does not have a specified shape") + + # Avoid Matrix.__eq__ which might call .shape + def __eq__(self, other): + return isinstance(other, GenericZeroMatrix) + + def __ne__(self, other): + return not (self == other) + + def __hash__(self): + return super().__hash__() + + + +class Identity(MatrixExpr): + """The Matrix Identity I - multiplicative identity + + Examples + ======== + + >>> from sympy import Identity, MatrixSymbol + >>> A = MatrixSymbol('A', 3, 5) + >>> I = Identity(3) + >>> I*A + A + """ + + is_Identity = True + + def __new__(cls, n): + n = _sympify(n) + cls._check_dim(n) + + return super().__new__(cls, n) + + @property + def rows(self): + return self.args[0] + + @property + def cols(self): + return self.args[0] + + @property + def shape(self): + return (self.args[0], self.args[0]) + + @property + def is_square(self): + return True + + def _eval_transpose(self): + return self + + def _eval_trace(self): + return self.rows + + def _eval_inverse(self): + return self + + def _eval_as_real_imag(self): + return (self, ZeroMatrix(*self.shape)) + + def _eval_conjugate(self): + return self + + def _eval_adjoint(self): + return self + + def _entry(self, i, j, **kwargs): + eq = Eq(i, j) + if eq is S.true: + return S.One + elif eq is S.false: + return S.Zero + return KroneckerDelta(i, j, (0, self.cols-1)) + + def _eval_determinant(self): + return S.One + + def _eval_power(self, exp): + return self + + +class GenericIdentity(Identity): + """ + An identity matrix without a specified shape + + This exists primarily so MatMul() with no arguments can return something + meaningful. + """ + def __new__(cls): + # super(Identity, cls) instead of super(GenericIdentity, cls) because + # Identity.__new__ doesn't have the same signature + return super(Identity, cls).__new__(cls) + + @property + def rows(self): + raise TypeError("GenericIdentity does not have a specified shape") + + @property + def cols(self): + raise TypeError("GenericIdentity does not have a specified shape") + + @property + def shape(self): + raise TypeError("GenericIdentity does not have a specified shape") + + @property + def is_square(self): + return True + + # Avoid Matrix.__eq__ which might call .shape + def __eq__(self, other): + return isinstance(other, GenericIdentity) + + def __ne__(self, other): + return not (self == other) + + def __hash__(self): + return super().__hash__() + + +class OneMatrix(MatrixExpr): + """ + Matrix whose all entries are ones. + """ + def __new__(cls, m, n, evaluate=False): + m, n = _sympify(m), _sympify(n) + cls._check_dim(m) + cls._check_dim(n) + + if evaluate: + condition = Eq(m, 1) & Eq(n, 1) + if condition == True: + return Identity(1) + + obj = super().__new__(cls, m, n) + return obj + + @property + def shape(self): + return self._args + + @property + def is_Identity(self): + return self._is_1x1() == True + + def as_explicit(self): + from sympy.matrices.immutable import ImmutableDenseMatrix + return ImmutableDenseMatrix.ones(*self.shape) + + def doit(self, **hints): + args = self.args + if hints.get('deep', True): + args = [a.doit(**hints) for a in args] + return self.func(*args, evaluate=True) + + def _eval_power(self, exp): + # exp = -1, 0, 1 are already handled at this stage + if self._is_1x1() == True: + return Identity(1) + if (exp < 0) == True: + raise NonInvertibleMatrixError("Matrix det == 0; not invertible") + if ask(Q.integer(exp)): + return self.shape[0] ** (exp - 1) * OneMatrix(*self.shape) + return super()._eval_power(exp) + + def _eval_transpose(self): + return OneMatrix(self.cols, self.rows) + + def _eval_adjoint(self): + return OneMatrix(self.cols, self.rows) + + def _eval_trace(self): + return S.One*self.rows + + def _is_1x1(self): + """Returns true if the matrix is known to be 1x1""" + shape = self.shape + return Eq(shape[0], 1) & Eq(shape[1], 1) + + def _eval_determinant(self): + condition = self._is_1x1() + if condition == True: + return S.One + elif condition == False: + return S.Zero + else: + from sympy.matrices.expressions.determinant import Determinant + return Determinant(self) + + def _eval_inverse(self): + condition = self._is_1x1() + if condition == True: + return Identity(1) + elif condition == False: + raise NonInvertibleMatrixError("Matrix det == 0; not invertible.") + else: + from .inverse import Inverse + return Inverse(self) + + def _eval_as_real_imag(self): + return (self, ZeroMatrix(*self.shape)) + + def _eval_conjugate(self): + return self + + def _entry(self, i, j, **kwargs): + return S.One diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_adjoint.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_adjoint.py new file mode 100644 index 0000000000000000000000000000000000000000..7106b5740b1dc7c32f2c6f5ecb9d286b5e1dd222 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_adjoint.py @@ -0,0 +1,34 @@ +from sympy.core import symbols, S +from sympy.functions import adjoint, conjugate, transpose +from sympy.matrices.expressions import MatrixSymbol, Adjoint, trace, Transpose +from sympy.matrices import eye, Matrix + +n, m, l, k, p = symbols('n m l k p', integer=True) +A = MatrixSymbol('A', n, m) +B = MatrixSymbol('B', m, l) +C = MatrixSymbol('C', n, n) + + +def test_adjoint(): + Sq = MatrixSymbol('Sq', n, n) + + assert Adjoint(A).shape == (m, n) + assert Adjoint(A*B).shape == (l, n) + assert adjoint(Adjoint(A)) == A + assert isinstance(Adjoint(Adjoint(A)), Adjoint) + + assert conjugate(Adjoint(A)) == Transpose(A) + assert transpose(Adjoint(A)) == Adjoint(Transpose(A)) + + assert Adjoint(eye(3)).doit() == eye(3) + + assert Adjoint(S(5)).doit() == S(5) + + assert Adjoint(Matrix([[1, 2], [3, 4]])).doit() == Matrix([[1, 3], [2, 4]]) + + assert adjoint(trace(Sq)) == conjugate(trace(Sq)) + assert trace(adjoint(Sq)) == conjugate(trace(Sq)) + + assert Adjoint(Sq)[0, 1] == conjugate(Sq[1, 0]) + + assert Adjoint(A*B).doit() == Adjoint(B) * Adjoint(A) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_applyfunc.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_applyfunc.py new file mode 100644 index 0000000000000000000000000000000000000000..d98732e2751e53938d96d7ea56c916e6fee4578e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_applyfunc.py @@ -0,0 +1,118 @@ +from sympy.core.symbol import symbols, Dummy +from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction +from sympy.core.function import Lambda +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.trigonometric import sin +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.matmul import MatMul +from sympy.simplify.simplify import simplify + + +X = MatrixSymbol("X", 3, 3) +Y = MatrixSymbol("Y", 3, 3) + +k = symbols("k") +Xk = MatrixSymbol("X", k, k) + +Xd = X.as_explicit() + +x, y, z, t = symbols("x y z t") + + +def test_applyfunc_matrix(): + x = Dummy('x') + double = Lambda(x, x**2) + + expr = ElementwiseApplyFunction(double, Xd) + assert isinstance(expr, ElementwiseApplyFunction) + assert expr.doit() == Xd.applyfunc(lambda x: x**2) + assert expr.shape == (3, 3) + assert expr.func(*expr.args) == expr + assert simplify(expr) == expr + assert expr[0, 0] == double(Xd[0, 0]) + + expr = ElementwiseApplyFunction(double, X) + assert isinstance(expr, ElementwiseApplyFunction) + assert isinstance(expr.doit(), ElementwiseApplyFunction) + assert expr == X.applyfunc(double) + assert expr.func(*expr.args) == expr + + expr = ElementwiseApplyFunction(exp, X*Y) + assert expr.expr == X*Y + assert expr.function.dummy_eq(Lambda(x, exp(x))) + assert expr.dummy_eq((X*Y).applyfunc(exp)) + assert expr.func(*expr.args) == expr + + assert isinstance(X*expr, MatMul) + assert (X*expr).shape == (3, 3) + Z = MatrixSymbol("Z", 2, 3) + assert (Z*expr).shape == (2, 3) + + expr = ElementwiseApplyFunction(exp, Z.T)*ElementwiseApplyFunction(exp, Z) + assert expr.shape == (3, 3) + expr = ElementwiseApplyFunction(exp, Z)*ElementwiseApplyFunction(exp, Z.T) + assert expr.shape == (2, 2) + + M = Matrix([[x, y], [z, t]]) + expr = ElementwiseApplyFunction(sin, M) + assert isinstance(expr, ElementwiseApplyFunction) + assert expr.function.dummy_eq(Lambda(x, sin(x))) + assert expr.expr == M + assert expr.doit() == M.applyfunc(sin) + assert expr.doit() == Matrix([[sin(x), sin(y)], [sin(z), sin(t)]]) + assert expr.func(*expr.args) == expr + + expr = ElementwiseApplyFunction(double, Xk) + assert expr.doit() == expr + assert expr.subs(k, 2).shape == (2, 2) + assert (expr*expr).shape == (k, k) + M = MatrixSymbol("M", k, t) + expr2 = M.T*expr*M + assert isinstance(expr2, MatMul) + assert expr2.args[1] == expr + assert expr2.shape == (t, t) + expr3 = expr*M + assert expr3.shape == (k, t) + + expr1 = ElementwiseApplyFunction(lambda x: x+1, Xk) + expr2 = ElementwiseApplyFunction(lambda x: x, Xk) + assert expr1 != expr2 + + +def test_applyfunc_entry(): + + af = X.applyfunc(sin) + assert af[0, 0] == sin(X[0, 0]) + + af = Xd.applyfunc(sin) + assert af[0, 0] == sin(X[0, 0]) + + +def test_applyfunc_as_explicit(): + + af = X.applyfunc(sin) + assert af.as_explicit() == Matrix([ + [sin(X[0, 0]), sin(X[0, 1]), sin(X[0, 2])], + [sin(X[1, 0]), sin(X[1, 1]), sin(X[1, 2])], + [sin(X[2, 0]), sin(X[2, 1]), sin(X[2, 2])], + ]) + + +def test_applyfunc_transpose(): + + af = Xk.applyfunc(sin) + assert af.T.dummy_eq(Xk.T.applyfunc(sin)) + + +def test_applyfunc_shape_11_matrices(): + M = MatrixSymbol("M", 1, 1) + + double = Lambda(x, x*2) + + expr = M.applyfunc(sin) + assert isinstance(expr, ElementwiseApplyFunction) + + expr = M.applyfunc(double) + assert isinstance(expr, MatMul) + assert expr == 2*M diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_blockmatrix.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_blockmatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..1d4893cd9a4b3e47dd8e84db33031f7f6f3201fd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_blockmatrix.py @@ -0,0 +1,469 @@ +from sympy.matrices.expressions.trace import Trace +from sympy.testing.pytest import raises, slow +from sympy.matrices.expressions.blockmatrix import ( + block_collapse, bc_matmul, bc_block_plus_ident, BlockDiagMatrix, + BlockMatrix, bc_dist, bc_matadd, bc_transpose, bc_inverse, + blockcut, reblock_2x2, deblock) +from sympy.matrices.expressions import ( + MatrixSymbol, Identity, trace, det, ZeroMatrix, OneMatrix) +from sympy.matrices.expressions.inverse import Inverse +from sympy.matrices.expressions.matpow import MatPow +from sympy.matrices.expressions.transpose import Transpose +from sympy.matrices.exceptions import NonInvertibleMatrixError +from sympy.matrices import ( + Matrix, ImmutableMatrix, ImmutableSparseMatrix, zeros) +from sympy.core import Tuple, Expr, S, Function +from sympy.core.symbol import Symbol, symbols +from sympy.functions import transpose, im, re + +i, j, k, l, m, n, p = symbols('i:n, p', integer=True) +A = MatrixSymbol('A', n, n) +B = MatrixSymbol('B', n, n) +C = MatrixSymbol('C', n, n) +D = MatrixSymbol('D', n, n) +G = MatrixSymbol('G', n, n) +H = MatrixSymbol('H', n, n) +b1 = BlockMatrix([[G, H]]) +b2 = BlockMatrix([[G], [H]]) + +def test_bc_matmul(): + assert bc_matmul(H*b1*b2*G) == BlockMatrix([[(H*G*G + H*H*H)*G]]) + +def test_bc_matadd(): + assert bc_matadd(BlockMatrix([[G, H]]) + BlockMatrix([[H, H]])) == \ + BlockMatrix([[G+H, H+H]]) + +def test_bc_transpose(): + assert bc_transpose(Transpose(BlockMatrix([[A, B], [C, D]]))) == \ + BlockMatrix([[A.T, C.T], [B.T, D.T]]) + +def test_bc_dist_diag(): + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', m, m) + C = MatrixSymbol('C', l, l) + X = BlockDiagMatrix(A, B, C) + + assert bc_dist(X+X).equals(BlockDiagMatrix(2*A, 2*B, 2*C)) + +def test_block_plus_ident(): + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', n, m) + C = MatrixSymbol('C', m, n) + D = MatrixSymbol('D', m, m) + X = BlockMatrix([[A, B], [C, D]]) + Z = MatrixSymbol('Z', n + m, n + m) + assert bc_block_plus_ident(X + Identity(m + n) + Z) == \ + BlockDiagMatrix(Identity(n), Identity(m)) + X + Z + +def test_BlockMatrix(): + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', n, k) + C = MatrixSymbol('C', l, m) + D = MatrixSymbol('D', l, k) + M = MatrixSymbol('M', m + k, p) + N = MatrixSymbol('N', l + n, k + m) + X = BlockMatrix(Matrix([[A, B], [C, D]])) + + assert X.__class__(*X.args) == X + + # block_collapse does nothing on normal inputs + E = MatrixSymbol('E', n, m) + assert block_collapse(A + 2*E) == A + 2*E + F = MatrixSymbol('F', m, m) + assert block_collapse(E.T*A*F) == E.T*A*F + + assert X.shape == (l + n, k + m) + assert X.blockshape == (2, 2) + assert transpose(X) == BlockMatrix(Matrix([[A.T, C.T], [B.T, D.T]])) + assert transpose(X).shape == X.shape[::-1] + + # Test that BlockMatrices and MatrixSymbols can still mix + assert (X*M).is_MatMul + assert X._blockmul(M).is_MatMul + assert (X*M).shape == (n + l, p) + assert (X + N).is_MatAdd + assert X._blockadd(N).is_MatAdd + assert (X + N).shape == X.shape + + E = MatrixSymbol('E', m, 1) + F = MatrixSymbol('F', k, 1) + + Y = BlockMatrix(Matrix([[E], [F]])) + + assert (X*Y).shape == (l + n, 1) + assert block_collapse(X*Y).blocks[0, 0] == A*E + B*F + assert block_collapse(X*Y).blocks[1, 0] == C*E + D*F + + # block_collapse passes down into container objects, transposes, and inverse + assert block_collapse(transpose(X*Y)) == transpose(block_collapse(X*Y)) + assert block_collapse(Tuple(X*Y, 2*X)) == ( + block_collapse(X*Y), block_collapse(2*X)) + + # Make sure that MatrixSymbols will enter 1x1 BlockMatrix if it simplifies + Ab = BlockMatrix([[A]]) + Z = MatrixSymbol('Z', *A.shape) + assert block_collapse(Ab + Z) == A + Z + +def test_block_collapse_explicit_matrices(): + A = Matrix([[1, 2], [3, 4]]) + assert block_collapse(BlockMatrix([[A]])) == A + + A = ImmutableSparseMatrix([[1, 2], [3, 4]]) + assert block_collapse(BlockMatrix([[A]])) == A + +def test_issue_17624(): + a = MatrixSymbol("a", 2, 2) + z = ZeroMatrix(2, 2) + b = BlockMatrix([[a, z], [z, z]]) + assert block_collapse(b * b) == BlockMatrix([[a**2, z], [z, z]]) + assert block_collapse(b * b * b) == BlockMatrix([[a**3, z], [z, z]]) + +def test_issue_18618(): + A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + assert A == Matrix(BlockDiagMatrix(A)) + +def test_BlockMatrix_trace(): + A, B, C, D = [MatrixSymbol(s, 3, 3) for s in 'ABCD'] + X = BlockMatrix([[A, B], [C, D]]) + assert trace(X) == trace(A) + trace(D) + assert trace(BlockMatrix([ZeroMatrix(n, n)])) == 0 + +def test_BlockMatrix_Determinant(): + A, B, C, D = [MatrixSymbol(s, 3, 3) for s in 'ABCD'] + X = BlockMatrix([[A, B], [C, D]]) + from sympy.assumptions.ask import Q + from sympy.assumptions.assume import assuming + with assuming(Q.invertible(A)): + assert det(X) == det(A) * det(X.schur('A')) + + assert isinstance(det(X), Expr) + assert det(BlockMatrix([A])) == det(A) + assert det(BlockMatrix([ZeroMatrix(n, n)])) == 0 + +def test_squareBlockMatrix(): + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', n, m) + C = MatrixSymbol('C', m, n) + D = MatrixSymbol('D', m, m) + X = BlockMatrix([[A, B], [C, D]]) + Y = BlockMatrix([[A]]) + + assert X.is_square + + Q = X + Identity(m + n) + assert (block_collapse(Q) == + BlockMatrix([[A + Identity(n), B], [C, D + Identity(m)]])) + + assert (X + MatrixSymbol('Q', n + m, n + m)).is_MatAdd + assert (X * MatrixSymbol('Q', n + m, n + m)).is_MatMul + + assert block_collapse(Y.I) == A.I + + assert isinstance(X.inverse(), Inverse) + + assert not X.is_Identity + + Z = BlockMatrix([[Identity(n), B], [C, D]]) + assert not Z.is_Identity + + +def test_BlockMatrix_2x2_inverse_symbolic(): + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', n, k - m) + C = MatrixSymbol('C', k - n, m) + D = MatrixSymbol('D', k - n, k - m) + X = BlockMatrix([[A, B], [C, D]]) + assert X.is_square and X.shape == (k, k) + assert isinstance(block_collapse(X.I), Inverse) # Can't invert when none of the blocks is square + + # test code path where only A is invertible + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', n, m) + C = MatrixSymbol('C', m, n) + D = ZeroMatrix(m, m) + X = BlockMatrix([[A, B], [C, D]]) + assert block_collapse(X.inverse()) == BlockMatrix([ + [A.I + A.I * B * X.schur('A').I * C * A.I, -A.I * B * X.schur('A').I], + [-X.schur('A').I * C * A.I, X.schur('A').I], + ]) + + # test code path where only B is invertible + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', n, n) + C = ZeroMatrix(m, m) + D = MatrixSymbol('D', m, n) + X = BlockMatrix([[A, B], [C, D]]) + assert block_collapse(X.inverse()) == BlockMatrix([ + [-X.schur('B').I * D * B.I, X.schur('B').I], + [B.I + B.I * A * X.schur('B').I * D * B.I, -B.I * A * X.schur('B').I], + ]) + + # test code path where only C is invertible + A = MatrixSymbol('A', n, m) + B = ZeroMatrix(n, n) + C = MatrixSymbol('C', m, m) + D = MatrixSymbol('D', m, n) + X = BlockMatrix([[A, B], [C, D]]) + assert block_collapse(X.inverse()) == BlockMatrix([ + [-C.I * D * X.schur('C').I, C.I + C.I * D * X.schur('C').I * A * C.I], + [X.schur('C').I, -X.schur('C').I * A * C.I], + ]) + + # test code path where only D is invertible + A = ZeroMatrix(n, n) + B = MatrixSymbol('B', n, m) + C = MatrixSymbol('C', m, n) + D = MatrixSymbol('D', m, m) + X = BlockMatrix([[A, B], [C, D]]) + assert block_collapse(X.inverse()) == BlockMatrix([ + [X.schur('D').I, -X.schur('D').I * B * D.I], + [-D.I * C * X.schur('D').I, D.I + D.I * C * X.schur('D').I * B * D.I], + ]) + + +def test_BlockMatrix_2x2_inverse_numeric(): + """Test 2x2 block matrix inversion numerically for all 4 formulas""" + M = Matrix([[1, 2], [3, 4]]) + # rank deficient matrices that have full rank when two of them combined + D1 = Matrix([[1, 2], [2, 4]]) + D2 = Matrix([[1, 3], [3, 9]]) + D3 = Matrix([[1, 4], [4, 16]]) + assert D1.rank() == D2.rank() == D3.rank() == 1 + assert (D1 + D2).rank() == (D2 + D3).rank() == (D3 + D1).rank() == 2 + + # Only A is invertible + K = BlockMatrix([[M, D1], [D2, D3]]) + assert block_collapse(K.inv()).as_explicit() == K.as_explicit().inv() + # Only B is invertible + K = BlockMatrix([[D1, M], [D2, D3]]) + assert block_collapse(K.inv()).as_explicit() == K.as_explicit().inv() + # Only C is invertible + K = BlockMatrix([[D1, D2], [M, D3]]) + assert block_collapse(K.inv()).as_explicit() == K.as_explicit().inv() + # Only D is invertible + K = BlockMatrix([[D1, D2], [D3, M]]) + assert block_collapse(K.inv()).as_explicit() == K.as_explicit().inv() + + +@slow +def test_BlockMatrix_3x3_symbolic(): + # Only test one of these, instead of all permutations, because it's slow + rowblocksizes = (n, m, k) + colblocksizes = (m, k, n) + K = BlockMatrix([ + [MatrixSymbol('M%s%s' % (rows, cols), rows, cols) for cols in colblocksizes] + for rows in rowblocksizes + ]) + collapse = block_collapse(K.I) + assert isinstance(collapse, BlockMatrix) + + +def test_BlockDiagMatrix(): + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', m, m) + C = MatrixSymbol('C', l, l) + M = MatrixSymbol('M', n + m + l, n + m + l) + + X = BlockDiagMatrix(A, B, C) + Y = BlockDiagMatrix(A, 2*B, 3*C) + + assert X.blocks[1, 1] == B + assert X.shape == (n + m + l, n + m + l) + assert all(X.blocks[i, j].is_ZeroMatrix if i != j else X.blocks[i, j] in [A, B, C] + for i in range(3) for j in range(3)) + assert X.__class__(*X.args) == X + assert X.get_diag_blocks() == (A, B, C) + + assert isinstance(block_collapse(X.I * X), Identity) + + assert bc_matmul(X*X) == BlockDiagMatrix(A*A, B*B, C*C) + assert block_collapse(X*X) == BlockDiagMatrix(A*A, B*B, C*C) + #XXX: should be == ?? + assert block_collapse(X + X).equals(BlockDiagMatrix(2*A, 2*B, 2*C)) + assert block_collapse(X*Y) == BlockDiagMatrix(A*A, 2*B*B, 3*C*C) + assert block_collapse(X + Y) == BlockDiagMatrix(2*A, 3*B, 4*C) + + # Ensure that BlockDiagMatrices can still interact with normal MatrixExprs + assert (X*(2*M)).is_MatMul + assert (X + (2*M)).is_MatAdd + + assert (X._blockmul(M)).is_MatMul + assert (X._blockadd(M)).is_MatAdd + +def test_BlockDiagMatrix_nonsquare(): + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', k, l) + X = BlockDiagMatrix(A, B) + assert X.shape == (n + k, m + l) + assert X.shape == (n + k, m + l) + assert X.rowblocksizes == [n, k] + assert X.colblocksizes == [m, l] + C = MatrixSymbol('C', n, m) + D = MatrixSymbol('D', k, l) + Y = BlockDiagMatrix(C, D) + assert block_collapse(X + Y) == BlockDiagMatrix(A + C, B + D) + assert block_collapse(X * Y.T) == BlockDiagMatrix(A * C.T, B * D.T) + raises(NonInvertibleMatrixError, lambda: BlockDiagMatrix(A, C.T).inverse()) + +def test_BlockDiagMatrix_determinant(): + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', m, m) + assert det(BlockDiagMatrix()) == 1 + assert det(BlockDiagMatrix(A)) == det(A) + assert det(BlockDiagMatrix(A, B)) == det(A) * det(B) + + # non-square blocks + C = MatrixSymbol('C', m, n) + D = MatrixSymbol('D', n, m) + assert det(BlockDiagMatrix(C, D)) == 0 + +def test_BlockDiagMatrix_trace(): + assert trace(BlockDiagMatrix()) == 0 + assert trace(BlockDiagMatrix(ZeroMatrix(n, n))) == 0 + A = MatrixSymbol('A', n, n) + assert trace(BlockDiagMatrix(A)) == trace(A) + B = MatrixSymbol('B', m, m) + assert trace(BlockDiagMatrix(A, B)) == trace(A) + trace(B) + + # non-square blocks + C = MatrixSymbol('C', m, n) + D = MatrixSymbol('D', n, m) + assert isinstance(trace(BlockDiagMatrix(C, D)), Trace) + +def test_BlockDiagMatrix_transpose(): + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', k, l) + assert transpose(BlockDiagMatrix()) == BlockDiagMatrix() + assert transpose(BlockDiagMatrix(A)) == BlockDiagMatrix(A.T) + assert transpose(BlockDiagMatrix(A, B)) == BlockDiagMatrix(A.T, B.T) + +def test_issue_2460(): + bdm1 = BlockDiagMatrix(Matrix([i]), Matrix([j])) + bdm2 = BlockDiagMatrix(Matrix([k]), Matrix([l])) + assert block_collapse(bdm1 + bdm2) == BlockDiagMatrix(Matrix([i + k]), Matrix([j + l])) + +def test_blockcut(): + A = MatrixSymbol('A', n, m) + B = blockcut(A, (n/2, n/2), (m/2, m/2)) + assert B == BlockMatrix([[A[:n/2, :m/2], A[:n/2, m/2:]], + [A[n/2:, :m/2], A[n/2:, m/2:]]]) + + M = ImmutableMatrix(4, 4, range(16)) + B = blockcut(M, (2, 2), (2, 2)) + assert M == ImmutableMatrix(B) + + B = blockcut(M, (1, 3), (2, 2)) + assert ImmutableMatrix(B.blocks[0, 1]) == ImmutableMatrix([[2, 3]]) + +def test_reblock_2x2(): + B = BlockMatrix([[MatrixSymbol('A_%d%d'%(i,j), 2, 2) + for j in range(3)] + for i in range(3)]) + assert B.blocks.shape == (3, 3) + + BB = reblock_2x2(B) + assert BB.blocks.shape == (2, 2) + + assert B.shape == BB.shape + assert B.as_explicit() == BB.as_explicit() + +def test_deblock(): + B = BlockMatrix([[MatrixSymbol('A_%d%d'%(i,j), n, n) + for j in range(4)] + for i in range(4)]) + + assert deblock(reblock_2x2(B)) == B + +def test_block_collapse_type(): + bm1 = BlockDiagMatrix(ImmutableMatrix([1]), ImmutableMatrix([2])) + bm2 = BlockDiagMatrix(ImmutableMatrix([3]), ImmutableMatrix([4])) + + assert bm1.T.__class__ == BlockDiagMatrix + assert block_collapse(bm1 - bm2).__class__ == BlockDiagMatrix + assert block_collapse(Inverse(bm1)).__class__ == BlockDiagMatrix + assert block_collapse(Transpose(bm1)).__class__ == BlockDiagMatrix + assert bc_transpose(Transpose(bm1)).__class__ == BlockDiagMatrix + assert bc_inverse(Inverse(bm1)).__class__ == BlockDiagMatrix + +def test_invalid_block_matrix(): + raises(ValueError, lambda: BlockMatrix([ + [Identity(2), Identity(5)], + ])) + raises(ValueError, lambda: BlockMatrix([ + [Identity(n), Identity(m)], + ])) + raises(ValueError, lambda: BlockMatrix([ + [ZeroMatrix(n, n), ZeroMatrix(n, n)], + [ZeroMatrix(n, n - 1), ZeroMatrix(n, n + 1)], + ])) + raises(ValueError, lambda: BlockMatrix([ + [ZeroMatrix(n - 1, n), ZeroMatrix(n, n)], + [ZeroMatrix(n + 1, n), ZeroMatrix(n, n)], + ])) + +def test_block_lu_decomposition(): + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', n, m) + C = MatrixSymbol('C', m, n) + D = MatrixSymbol('D', m, m) + X = BlockMatrix([[A, B], [C, D]]) + + #LDU decomposition + L, D, U = X.LDUdecomposition() + assert block_collapse(L*D*U) == X + + #UDL decomposition + U, D, L = X.UDLdecomposition() + assert block_collapse(U*D*L) == X + + #LU decomposition + L, U = X.LUdecomposition() + assert block_collapse(L*U) == X + +def test_issue_21866(): + n = 10 + I = Identity(n) + O = ZeroMatrix(n, n) + A = BlockMatrix([[ I, O, O, O ], + [ O, I, O, O ], + [ O, O, I, O ], + [ I, O, O, I ]]) + Ainv = block_collapse(A.inv()) + AinvT = BlockMatrix([[ I, O, O, O ], + [ O, I, O, O ], + [ O, O, I, O ], + [ -I, O, O, I ]]) + assert Ainv == AinvT + + +def test_adjoint_and_special_matrices(): + A = Identity(3) + B = OneMatrix(3, 2) + C = ZeroMatrix(2, 3) + D = Identity(2) + X = BlockMatrix([[A, B], [C, D]]) + X2 = BlockMatrix([[A, S.ImaginaryUnit*B], [C, D]]) + assert X.adjoint() == BlockMatrix([[A, ZeroMatrix(3, 2)], [OneMatrix(2, 3), D]]) + assert re(X) == X + assert X2.adjoint() == BlockMatrix([[A, ZeroMatrix(3, 2)], [-S.ImaginaryUnit*OneMatrix(2, 3), D]]) + assert im(X2) == BlockMatrix([[ZeroMatrix(3, 3), OneMatrix(3, 2)], [ZeroMatrix(2, 3), ZeroMatrix(2, 2)]]) + + +def test_block_matrix_derivative(): + x = symbols('x') + A = Matrix(3, 3, [Function(f'a{i}')(x) for i in range(9)]) + bc = BlockMatrix([[A[:2, :2], A[:2, 2]], [A[2, :2], A[2:, 2]]]) + assert Matrix(bc.diff(x)) - A.diff(x) == zeros(3, 3) + + +def test_transpose_inverse_commute(): + n = Symbol('n') + I = Identity(n) + Z = ZeroMatrix(n, n) + A = BlockMatrix([[I, Z], [Z, I]]) + + assert block_collapse(A.transpose().inverse()) == A + assert block_collapse(A.inverse().transpose()) == A + + assert block_collapse(MatPow(A.transpose(), -2)) == MatPow(A, -2) + assert block_collapse(MatPow(A, -2).transpose()) == MatPow(A, -2) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_companion.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_companion.py new file mode 100644 index 0000000000000000000000000000000000000000..edc592c29098eddce0c6352806aa73d5d889e999 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_companion.py @@ -0,0 +1,48 @@ +from sympy.core.expr import unchanged +from sympy.core.symbol import Symbol, symbols +from sympy.matrices.immutable import ImmutableDenseMatrix +from sympy.matrices.expressions.companion import CompanionMatrix +from sympy.polys.polytools import Poly +from sympy.testing.pytest import raises + + +def test_creation(): + x = Symbol('x') + y = Symbol('y') + raises(ValueError, lambda: CompanionMatrix(1)) + raises(ValueError, lambda: CompanionMatrix(Poly([1], x))) + raises(ValueError, lambda: CompanionMatrix(Poly([2, 1], x))) + raises(ValueError, lambda: CompanionMatrix(Poly(x*y, [x, y]))) + assert unchanged(CompanionMatrix, Poly([1, 2, 3], x)) + + +def test_shape(): + c0, c1, c2 = symbols('c0:3') + x = Symbol('x') + assert CompanionMatrix(Poly([1, c0], x)).shape == (1, 1) + assert CompanionMatrix(Poly([1, c1, c0], x)).shape == (2, 2) + assert CompanionMatrix(Poly([1, c2, c1, c0], x)).shape == (3, 3) + + +def test_entry(): + c0, c1, c2 = symbols('c0:3') + x = Symbol('x') + A = CompanionMatrix(Poly([1, c2, c1, c0], x)) + assert A[0, 0] == 0 + assert A[1, 0] == 1 + assert A[1, 1] == 0 + assert A[2, 1] == 1 + assert A[0, 2] == -c0 + assert A[1, 2] == -c1 + assert A[2, 2] == -c2 + + +def test_as_explicit(): + c0, c1, c2 = symbols('c0:3') + x = Symbol('x') + assert CompanionMatrix(Poly([1, c0], x)).as_explicit() == \ + ImmutableDenseMatrix([-c0]) + assert CompanionMatrix(Poly([1, c1, c0], x)).as_explicit() == \ + ImmutableDenseMatrix([[0, -c0], [1, -c1]]) + assert CompanionMatrix(Poly([1, c2, c1, c0], x)).as_explicit() == \ + ImmutableDenseMatrix([[0, 0, -c0], [1, 0, -c1], [0, 1, -c2]]) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_derivatives.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_derivatives.py new file mode 100644 index 0000000000000000000000000000000000000000..77484c994dda62eea9771a76afd8b3caeadacb93 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_derivatives.py @@ -0,0 +1,477 @@ +""" +Some examples have been taken from: + +http://www.math.uwaterloo.ca/~hwolkowi//matrixcookbook.pdf +""" +from sympy import KroneckerProduct +from sympy.combinatorics import Permutation +from sympy.concrete.summations import Sum +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin, tan) +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.matrices.expressions.determinant import Determinant +from sympy.matrices.expressions.diagonal import DiagMatrix +from sympy.matrices.expressions.hadamard import (HadamardPower, HadamardProduct, hadamard_product) +from sympy.matrices.expressions.inverse import Inverse +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import OneMatrix +from sympy.matrices.expressions.trace import Trace +from sympy.matrices.expressions.matadd import MatAdd +from sympy.matrices.expressions.matmul import MatMul +from sympy.matrices.expressions.special import (Identity, ZeroMatrix) +from sympy.tensor.array.array_derivatives import ArrayDerivative +from sympy.matrices.expressions import hadamard_power +from sympy.tensor.array.expressions.array_expressions import ArrayAdd, ArrayTensorProduct, PermuteDims + +i, j, k = symbols("i j k") +m, n = symbols("m n") + +X = MatrixSymbol("X", k, k) +x = MatrixSymbol("x", k, 1) +y = MatrixSymbol("y", k, 1) + +A = MatrixSymbol("A", k, k) +B = MatrixSymbol("B", k, k) +C = MatrixSymbol("C", k, k) +D = MatrixSymbol("D", k, k) + +a = MatrixSymbol("a", k, 1) +b = MatrixSymbol("b", k, 1) +c = MatrixSymbol("c", k, 1) +d = MatrixSymbol("d", k, 1) + + +KDelta = lambda i, j: KroneckerDelta(i, j, (0, k-1)) + + +def _check_derivative_with_explicit_matrix(expr, x, diffexpr, dim=2): + # TODO: this is commented because it slows down the tests. + return + + expr = expr.xreplace({k: dim}) + x = x.xreplace({k: dim}) + diffexpr = diffexpr.xreplace({k: dim}) + + expr = expr.as_explicit() + x = x.as_explicit() + diffexpr = diffexpr.as_explicit() + + assert expr.diff(x).reshape(*diffexpr.shape).tomatrix() == diffexpr + + +def test_matrix_derivative_by_scalar(): + assert A.diff(i) == ZeroMatrix(k, k) + assert (A*(X + B)*c).diff(i) == ZeroMatrix(k, 1) + assert x.diff(i) == ZeroMatrix(k, 1) + assert (x.T*y).diff(i) == ZeroMatrix(1, 1) + assert (x*x.T).diff(i) == ZeroMatrix(k, k) + assert (x + y).diff(i) == ZeroMatrix(k, 1) + assert hadamard_power(x, 2).diff(i) == ZeroMatrix(k, 1) + assert hadamard_power(x, i).diff(i).dummy_eq( + HadamardProduct(x.applyfunc(log), HadamardPower(x, i))) + assert hadamard_product(x, y).diff(i) == ZeroMatrix(k, 1) + assert hadamard_product(i*OneMatrix(k, 1), x, y).diff(i) == hadamard_product(x, y) + assert (i*x).diff(i) == x + assert (sin(i)*A*B*x).diff(i) == cos(i)*A*B*x + assert x.applyfunc(sin).diff(i) == ZeroMatrix(k, 1) + assert Trace(i**2*X).diff(i) == 2*i*Trace(X) + + mu = symbols("mu") + expr = (2*mu*x) + assert expr.diff(x) == 2*mu*Identity(k) + + +def test_one_matrix(): + assert MatMul(x.T, OneMatrix(k, 1)).diff(x) == OneMatrix(k, 1) + + +def test_matrix_derivative_non_matrix_result(): + # This is a 4-dimensional array: + I = Identity(k) + AdA = PermuteDims(ArrayTensorProduct(I, I), Permutation(3)(1, 2)) + assert A.diff(A) == AdA + assert A.T.diff(A) == PermuteDims(ArrayTensorProduct(I, I), Permutation(3)(1, 2, 3)) + assert (2*A).diff(A) == PermuteDims(ArrayTensorProduct(2*I, I), Permutation(3)(1, 2)) + assert MatAdd(A, A).diff(A) == ArrayAdd(AdA, AdA) + assert (A + B).diff(A) == AdA + + +def test_matrix_derivative_trivial_cases(): + # Cookbook example 33: + # TODO: find a way to represent a four-dimensional zero-array: + assert X.diff(A) == ArrayDerivative(X, A) + + +def test_matrix_derivative_with_inverse(): + + # Cookbook example 61: + expr = a.T*Inverse(X)*b + assert expr.diff(X) == -Inverse(X).T*a*b.T*Inverse(X).T + + # Cookbook example 62: + expr = Determinant(Inverse(X)) + # Not implemented yet: + # assert expr.diff(X) == -Determinant(X.inv())*(X.inv()).T + + # Cookbook example 63: + expr = Trace(A*Inverse(X)*B) + assert expr.diff(X) == -(X**(-1)*B*A*X**(-1)).T + + # Cookbook example 64: + expr = Trace(Inverse(X + A)) + assert expr.diff(X) == -(Inverse(X + A)).T**2 + + +def test_matrix_derivative_vectors_and_scalars(): + + assert x.diff(x) == Identity(k) + assert x[i, 0].diff(x[m, 0]).doit() == KDelta(m, i) + + assert x.T.diff(x) == Identity(k) + + # Cookbook example 69: + expr = x.T*a + assert expr.diff(x) == a + assert expr[0, 0].diff(x[m, 0]).doit() == a[m, 0] + expr = a.T*x + assert expr.diff(x) == a + + # Cookbook example 70: + expr = a.T*X*b + assert expr.diff(X) == a*b.T + + # Cookbook example 71: + expr = a.T*X.T*b + assert expr.diff(X) == b*a.T + + # Cookbook example 72: + expr = a.T*X*a + assert expr.diff(X) == a*a.T + expr = a.T*X.T*a + assert expr.diff(X) == a*a.T + + # Cookbook example 77: + expr = b.T*X.T*X*c + assert expr.diff(X) == X*b*c.T + X*c*b.T + + # Cookbook example 78: + expr = (B*x + b).T*C*(D*x + d) + assert expr.diff(x) == B.T*C*(D*x + d) + D.T*C.T*(B*x + b) + + # Cookbook example 81: + expr = x.T*B*x + assert expr.diff(x) == B*x + B.T*x + + # Cookbook example 82: + expr = b.T*X.T*D*X*c + assert expr.diff(X) == D.T*X*b*c.T + D*X*c*b.T + + # Cookbook example 83: + expr = (X*b + c).T*D*(X*b + c) + assert expr.diff(X) == D*(X*b + c)*b.T + D.T*(X*b + c)*b.T + assert str(expr[0, 0].diff(X[m, n]).doit()) == \ + 'b[n, 0]*Sum((c[_i_1, 0] + Sum(X[_i_1, _i_3]*b[_i_3, 0], (_i_3, 0, k - 1)))*D[_i_1, m], (_i_1, 0, k - 1)) + Sum((c[_i_2, 0] + Sum(X[_i_2, _i_4]*b[_i_4, 0], (_i_4, 0, k - 1)))*D[m, _i_2]*b[n, 0], (_i_2, 0, k - 1))' + + # See https://github.com/sympy/sympy/issues/16504#issuecomment-1018339957 + expr = x*x.T*x + I = Identity(k) + assert expr.diff(x) == KroneckerProduct(I, x.T*x) + 2*x*x.T + + +def test_matrix_derivatives_of_traces(): + + expr = Trace(A)*A + I = Identity(k) + assert expr.diff(A) == ArrayAdd(ArrayTensorProduct(I, A), PermuteDims(ArrayTensorProduct(Trace(A)*I, I), Permutation(3)(1, 2))) + assert expr[i, j].diff(A[m, n]).doit() == ( + KDelta(i, m)*KDelta(j, n)*Trace(A) + + KDelta(m, n)*A[i, j] + ) + + ## First order: + + # Cookbook example 99: + expr = Trace(X) + assert expr.diff(X) == Identity(k) + assert expr.rewrite(Sum).diff(X[m, n]).doit() == KDelta(m, n) + + # Cookbook example 100: + expr = Trace(X*A) + assert expr.diff(X) == A.T + assert expr.rewrite(Sum).diff(X[m, n]).doit() == A[n, m] + + # Cookbook example 101: + expr = Trace(A*X*B) + assert expr.diff(X) == A.T*B.T + assert expr.rewrite(Sum).diff(X[m, n]).doit().dummy_eq((A.T*B.T)[m, n]) + + # Cookbook example 102: + expr = Trace(A*X.T*B) + assert expr.diff(X) == B*A + + # Cookbook example 103: + expr = Trace(X.T*A) + assert expr.diff(X) == A + + # Cookbook example 104: + expr = Trace(A*X.T) + assert expr.diff(X) == A + + # Cookbook example 105: + # TODO: TensorProduct is not supported + #expr = Trace(TensorProduct(A, X)) + #assert expr.diff(X) == Trace(A)*Identity(k) + + ## Second order: + + # Cookbook example 106: + expr = Trace(X**2) + assert expr.diff(X) == 2*X.T + + # Cookbook example 107: + expr = Trace(X**2*B) + assert expr.diff(X) == (X*B + B*X).T + expr = Trace(MatMul(X, X, B)) + assert expr.diff(X) == (X*B + B*X).T + + # Cookbook example 108: + expr = Trace(X.T*B*X) + assert expr.diff(X) == B*X + B.T*X + + # Cookbook example 109: + expr = Trace(B*X*X.T) + assert expr.diff(X) == B*X + B.T*X + + # Cookbook example 110: + expr = Trace(X*X.T*B) + assert expr.diff(X) == B*X + B.T*X + + # Cookbook example 111: + expr = Trace(X*B*X.T) + assert expr.diff(X) == X*B.T + X*B + + # Cookbook example 112: + expr = Trace(B*X.T*X) + assert expr.diff(X) == X*B.T + X*B + + # Cookbook example 113: + expr = Trace(X.T*X*B) + assert expr.diff(X) == X*B.T + X*B + + # Cookbook example 114: + expr = Trace(A*X*B*X) + assert expr.diff(X) == A.T*X.T*B.T + B.T*X.T*A.T + + # Cookbook example 115: + expr = Trace(X.T*X) + assert expr.diff(X) == 2*X + expr = Trace(X*X.T) + assert expr.diff(X) == 2*X + + # Cookbook example 116: + expr = Trace(B.T*X.T*C*X*B) + assert expr.diff(X) == C.T*X*B*B.T + C*X*B*B.T + + # Cookbook example 117: + expr = Trace(X.T*B*X*C) + assert expr.diff(X) == B*X*C + B.T*X*C.T + + # Cookbook example 118: + expr = Trace(A*X*B*X.T*C) + assert expr.diff(X) == A.T*C.T*X*B.T + C*A*X*B + + # Cookbook example 119: + expr = Trace((A*X*B + C)*(A*X*B + C).T) + assert expr.diff(X) == 2*A.T*(A*X*B + C)*B.T + + # Cookbook example 120: + # TODO: no support for TensorProduct. + # expr = Trace(TensorProduct(X, X)) + # expr = Trace(X)*Trace(X) + # expr.diff(X) == 2*Trace(X)*Identity(k) + + # Higher Order + + # Cookbook example 121: + expr = Trace(X**k) + #assert expr.diff(X) == k*(X**(k-1)).T + + # Cookbook example 122: + expr = Trace(A*X**k) + #assert expr.diff(X) == # Needs indices + + # Cookbook example 123: + expr = Trace(B.T*X.T*C*X*X.T*C*X*B) + assert expr.diff(X) == C*X*X.T*C*X*B*B.T + C.T*X*B*B.T*X.T*C.T*X + C*X*B*B.T*X.T*C*X + C.T*X*X.T*C.T*X*B*B.T + + # Other + + # Cookbook example 124: + expr = Trace(A*X**(-1)*B) + assert expr.diff(X) == -Inverse(X).T*A.T*B.T*Inverse(X).T + + # Cookbook example 125: + expr = Trace(Inverse(X.T*C*X)*A) + # Warning: result in the cookbook is equivalent if B and C are symmetric: + assert expr.diff(X) == - X.inv().T*A.T*X.inv()*C.inv().T*X.inv().T - X.inv().T*A*X.inv()*C.inv()*X.inv().T + + # Cookbook example 126: + expr = Trace((X.T*C*X).inv()*(X.T*B*X)) + assert expr.diff(X) == -2*C*X*(X.T*C*X).inv()*X.T*B*X*(X.T*C*X).inv() + 2*B*X*(X.T*C*X).inv() + + # Cookbook example 127: + expr = Trace((A + X.T*C*X).inv()*(X.T*B*X)) + # Warning: result in the cookbook is equivalent if B and C are symmetric: + assert expr.diff(X) == B*X*Inverse(A + X.T*C*X) - C*X*Inverse(A + X.T*C*X)*X.T*B*X*Inverse(A + X.T*C*X) - C.T*X*Inverse(A.T + (C*X).T*X)*X.T*B.T*X*Inverse(A.T + (C*X).T*X) + B.T*X*Inverse(A.T + (C*X).T*X) + + +def test_derivatives_of_complicated_matrix_expr(): + expr = a.T*(A*X*(X.T*B + X*A) + B.T*X.T*(a*b.T*(X*D*X.T + X*(X.T*B + A*X)*D*B - X.T*C.T*A)*B + B*(X*D.T + B*A*X*A.T - 3*X*D))*B + 42*X*B*X.T*A.T*(X + X.T))*b + result = (B*(B*A*X*A.T - 3*X*D + X*D.T) + a*b.T*(X*(A*X + X.T*B)*D*B + X*D*X.T - X.T*C.T*A)*B)*B*b*a.T*B.T + B**2*b*a.T*B.T*X.T*a*b.T*X*D + 42*A*X*B.T*X.T*a*b.T + B*D*B**3*b*a.T*B.T*X.T*a*b.T*X + B*b*a.T*A*X + a*b.T*(42*X + 42*X.T)*A*X*B.T + b*a.T*X*B*a*b.T*B.T**2*X*D.T + b*a.T*X*B*a*b.T*B.T**3*D.T*(B.T*X + X.T*A.T) + 42*b*a.T*X*B*X.T*A.T + A.T*(42*X + 42*X.T)*b*a.T*X*B + A.T*B.T**2*X*B*a*b.T*B.T*A + A.T*a*b.T*(A.T*X.T + B.T*X) + A.T*X.T*b*a.T*X*B*a*b.T*B.T**3*D.T + B.T*X*B*a*b.T*B.T*D - 3*B.T*X*B*a*b.T*B.T*D.T - C.T*A*B**2*b*a.T*B.T*X.T*a*b.T + X.T*A.T*a*b.T*A.T + assert expr.diff(X) == result + + +def test_mixed_deriv_mixed_expressions(): + + expr = 3*Trace(A) + assert expr.diff(A) == 3*Identity(k) + + expr = k + deriv = expr.diff(A) + assert isinstance(deriv, ZeroMatrix) + assert deriv == ZeroMatrix(k, k) + + expr = Trace(A)**2 + assert expr.diff(A) == (2*Trace(A))*Identity(k) + + expr = Trace(A)*A + I = Identity(k) + assert expr.diff(A) == ArrayAdd(ArrayTensorProduct(I, A), PermuteDims(ArrayTensorProduct(Trace(A)*I, I), Permutation(3)(1, 2))) + + expr = Trace(Trace(A)*A) + assert expr.diff(A) == (2*Trace(A))*Identity(k) + + expr = Trace(Trace(Trace(A)*A)*A) + assert expr.diff(A) == (3*Trace(A)**2)*Identity(k) + + +def test_derivatives_matrix_norms(): + + expr = x.T*y + assert expr.diff(x) == y + assert expr[0, 0].diff(x[m, 0]).doit() == y[m, 0] + + expr = (x.T*y)**S.Half + assert expr.diff(x) == y/(2*sqrt(x.T*y)) + + expr = (x.T*x)**S.Half + assert expr.diff(x) == x*(x.T*x)**Rational(-1, 2) + + expr = (c.T*a*x.T*b)**S.Half + assert expr.diff(x) == b*a.T*c/sqrt(c.T*a*x.T*b)/2 + + expr = (c.T*a*x.T*b)**Rational(1, 3) + assert expr.diff(x) == b*a.T*c*(c.T*a*x.T*b)**Rational(-2, 3)/3 + + expr = (a.T*X*b)**S.Half + assert expr.diff(X) == a/(2*sqrt(a.T*X*b))*b.T + + expr = d.T*x*(a.T*X*b)**S.Half*y.T*c + assert expr.diff(X) == a/(2*sqrt(a.T*X*b))*x.T*d*y.T*c*b.T + + +def test_derivatives_elementwise_applyfunc(): + + expr = x.applyfunc(tan) + assert expr.diff(x).dummy_eq( + DiagMatrix(x.applyfunc(lambda x: tan(x)**2 + 1))) + assert expr[i, 0].diff(x[m, 0]).doit() == (tan(x[i, 0])**2 + 1)*KDelta(i, m) + _check_derivative_with_explicit_matrix(expr, x, expr.diff(x)) + + expr = (i**2*x).applyfunc(sin) + assert expr.diff(i).dummy_eq( + HadamardProduct((2*i)*x, (i**2*x).applyfunc(cos))) + assert expr[i, 0].diff(i).doit() == 2*i*x[i, 0]*cos(i**2*x[i, 0]) + _check_derivative_with_explicit_matrix(expr, i, expr.diff(i)) + + expr = (log(i)*A*B).applyfunc(sin) + assert expr.diff(i).dummy_eq( + HadamardProduct(A*B/i, (log(i)*A*B).applyfunc(cos))) + _check_derivative_with_explicit_matrix(expr, i, expr.diff(i)) + + expr = A*x.applyfunc(exp) + # TODO: restore this result (currently returning the transpose): + # assert expr.diff(x).dummy_eq(DiagMatrix(x.applyfunc(exp))*A.T) + _check_derivative_with_explicit_matrix(expr, x, expr.diff(x)) + + expr = x.T*A*x + k*y.applyfunc(sin).T*x + assert expr.diff(x).dummy_eq(A.T*x + A*x + k*y.applyfunc(sin)) + _check_derivative_with_explicit_matrix(expr, x, expr.diff(x)) + + expr = x.applyfunc(sin).T*y + # TODO: restore (currently returning the transpose): + # assert expr.diff(x).dummy_eq(DiagMatrix(x.applyfunc(cos))*y) + _check_derivative_with_explicit_matrix(expr, x, expr.diff(x)) + + expr = (a.T * X * b).applyfunc(sin) + assert expr.diff(X).dummy_eq(a*(a.T*X*b).applyfunc(cos)*b.T) + _check_derivative_with_explicit_matrix(expr, X, expr.diff(X)) + + expr = a.T * X.applyfunc(sin) * b + assert expr.diff(X).dummy_eq( + DiagMatrix(a)*X.applyfunc(cos)*DiagMatrix(b)) + _check_derivative_with_explicit_matrix(expr, X, expr.diff(X)) + + expr = a.T * (A*X*B).applyfunc(sin) * b + assert expr.diff(X).dummy_eq( + A.T*DiagMatrix(a)*(A*X*B).applyfunc(cos)*DiagMatrix(b)*B.T) + _check_derivative_with_explicit_matrix(expr, X, expr.diff(X)) + + expr = a.T * (A*X*b).applyfunc(sin) * b.T + # TODO: not implemented + #assert expr.diff(X) == ... + #_check_derivative_with_explicit_matrix(expr, X, expr.diff(X)) + + expr = a.T*A*X.applyfunc(sin)*B*b + assert expr.diff(X).dummy_eq( + HadamardProduct(A.T * a * b.T * B.T, X.applyfunc(cos))) + + expr = a.T * (A*X.applyfunc(sin)*B).applyfunc(log) * b + # TODO: wrong + # assert expr.diff(X) == A.T*DiagMatrix(a)*(A*X.applyfunc(sin)*B).applyfunc(Lambda(k, 1/k))*DiagMatrix(b)*B.T + + expr = a.T * (X.applyfunc(sin)).applyfunc(log) * b + # TODO: wrong + # assert expr.diff(X) == DiagMatrix(a)*X.applyfunc(sin).applyfunc(Lambda(k, 1/k))*DiagMatrix(b) + + +def test_derivatives_of_hadamard_expressions(): + + # Hadamard Product + + expr = hadamard_product(a, x, b) + assert expr.diff(x) == DiagMatrix(hadamard_product(b, a)) + + expr = a.T*hadamard_product(A, X, B)*b + assert expr.diff(X) == HadamardProduct(a*b.T, A, B) + + # Hadamard Power + + expr = hadamard_power(x, 2) + assert expr.diff(x).doit() == 2*DiagMatrix(x) + + expr = hadamard_power(x.T, 2) + assert expr.diff(x).doit() == 2*DiagMatrix(x) + + expr = hadamard_power(x, S.Half) + assert expr.diff(x) == S.Half*DiagMatrix(hadamard_power(x, Rational(-1, 2))) + + expr = hadamard_power(a.T*X*b, 2) + assert expr.diff(X) == 2*a*a.T*X*b*b.T + + expr = hadamard_power(a.T*X*b, S.Half) + assert expr.diff(X) == a/(2*sqrt(a.T*X*b))*b.T diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_determinant.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_determinant.py new file mode 100644 index 0000000000000000000000000000000000000000..d1a66c728f076f8c769d2519ee47c8a9cc90a90e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_determinant.py @@ -0,0 +1,65 @@ +from sympy.core import S, symbols +from sympy.matrices import eye, ones, Matrix, ShapeError +from sympy.matrices.expressions import ( + Identity, MatrixExpr, MatrixSymbol, Determinant, + det, per, ZeroMatrix, Transpose, + Permanent, MatMul +) +from sympy.matrices.expressions.special import OneMatrix +from sympy.testing.pytest import raises +from sympy.assumptions.ask import Q +from sympy.assumptions.refine import refine + +n = symbols('n', integer=True) +A = MatrixSymbol('A', n, n) +B = MatrixSymbol('B', n, n) +C = MatrixSymbol('C', 3, 4) + + +def test_det(): + assert isinstance(Determinant(A), Determinant) + assert not isinstance(Determinant(A), MatrixExpr) + raises(ShapeError, lambda: Determinant(C)) + assert det(eye(3)) == 1 + assert det(Matrix(3, 3, [1, 3, 2, 4, 1, 3, 2, 5, 2])) == 17 + _ = A / det(A) # Make sure this is possible + + raises(TypeError, lambda: Determinant(S.One)) + + assert Determinant(A).arg is A + + +def test_eval_determinant(): + assert det(Identity(n)) == 1 + assert det(ZeroMatrix(n, n)) == 0 + assert det(OneMatrix(n, n)) == Determinant(OneMatrix(n, n)) + assert det(OneMatrix(1, 1)) == 1 + assert det(OneMatrix(2, 2)) == 0 + assert det(Transpose(A)) == det(A) + assert Determinant(MatMul(eye(2), eye(2))).doit(deep=True) == 1 + + +def test_refine(): + assert refine(det(A), Q.orthogonal(A)) == 1 + assert refine(det(A), Q.singular(A)) == 0 + assert refine(det(A), Q.unit_triangular(A)) == 1 + assert refine(det(A), Q.normal(A)) == det(A) + + +def test_commutative(): + det_a = Determinant(A) + det_b = Determinant(B) + assert det_a.is_commutative + assert det_b.is_commutative + assert det_a * det_b == det_b * det_a + + +def test_permanent(): + assert isinstance(Permanent(A), Permanent) + assert not isinstance(Permanent(A), MatrixExpr) + assert isinstance(Permanent(C), Permanent) + assert Permanent(ones(3, 3)).doit() == 6 + _ = C / per(C) + assert per(Matrix(3, 3, [1, 3, 2, 4, 1, 3, 2, 5, 2])) == 103 + raises(TypeError, lambda: Permanent(S.One)) + assert Permanent(A).arg is A diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_diagonal.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_diagonal.py new file mode 100644 index 0000000000000000000000000000000000000000..3e4f7ea4c178121c33eeb26c09675403d274c1e8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_diagonal.py @@ -0,0 +1,156 @@ +from sympy.matrices.expressions import MatrixSymbol +from sympy.matrices.expressions.diagonal import DiagonalMatrix, DiagonalOf, DiagMatrix, diagonalize_vector +from sympy.assumptions.ask import (Q, ask) +from sympy.core.symbol import Symbol +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matmul import MatMul +from sympy.matrices.expressions.special import Identity +from sympy.testing.pytest import raises + + +n = Symbol('n') +m = Symbol('m') + + +def test_DiagonalMatrix(): + x = MatrixSymbol('x', n, m) + D = DiagonalMatrix(x) + assert D.diagonal_length is None + assert D.shape == (n, m) + + x = MatrixSymbol('x', n, n) + D = DiagonalMatrix(x) + assert D.diagonal_length == n + assert D.shape == (n, n) + assert D[1, 2] == 0 + assert D[1, 1] == x[1, 1] + i = Symbol('i') + j = Symbol('j') + x = MatrixSymbol('x', 3, 3) + ij = DiagonalMatrix(x)[i, j] + assert ij != 0 + assert ij.subs({i:0, j:0}) == x[0, 0] + assert ij.subs({i:0, j:1}) == 0 + assert ij.subs({i:1, j:1}) == x[1, 1] + assert ask(Q.diagonal(D)) # affirm that D is diagonal + + x = MatrixSymbol('x', n, 3) + D = DiagonalMatrix(x) + assert D.diagonal_length == 3 + assert D.shape == (n, 3) + assert D[2, m] == KroneckerDelta(2, m)*x[2, m] + assert D[3, m] == 0 + raises(IndexError, lambda: D[m, 3]) + + x = MatrixSymbol('x', 3, n) + D = DiagonalMatrix(x) + assert D.diagonal_length == 3 + assert D.shape == (3, n) + assert D[m, 2] == KroneckerDelta(m, 2)*x[m, 2] + assert D[m, 3] == 0 + raises(IndexError, lambda: D[3, m]) + + x = MatrixSymbol('x', n, m) + D = DiagonalMatrix(x) + assert D.diagonal_length is None + assert D.shape == (n, m) + assert D[m, 4] != 0 + + x = MatrixSymbol('x', 3, 4) + assert [DiagonalMatrix(x)[i] for i in range(12)] == [ + x[0, 0], 0, 0, 0, 0, x[1, 1], 0, 0, 0, 0, x[2, 2], 0] + + # shape is retained, issue 12427 + assert ( + DiagonalMatrix(MatrixSymbol('x', 3, 4))* + DiagonalMatrix(MatrixSymbol('x', 4, 2))).shape == (3, 2) + + +def test_DiagonalOf(): + x = MatrixSymbol('x', n, n) + d = DiagonalOf(x) + assert d.shape == (n, 1) + assert d.diagonal_length == n + assert d[2, 0] == d[2] == x[2, 2] + + x = MatrixSymbol('x', n, m) + d = DiagonalOf(x) + assert d.shape == (None, 1) + assert d.diagonal_length is None + assert d[2, 0] == d[2] == x[2, 2] + + d = DiagonalOf(MatrixSymbol('x', 4, 3)) + assert d.shape == (3, 1) + d = DiagonalOf(MatrixSymbol('x', n, 3)) + assert d.shape == (3, 1) + d = DiagonalOf(MatrixSymbol('x', 3, n)) + assert d.shape == (3, 1) + x = MatrixSymbol('x', n, m) + assert [DiagonalOf(x)[i] for i in range(4)] ==[ + x[0, 0], x[1, 1], x[2, 2], x[3, 3]] + + +def test_DiagMatrix(): + x = MatrixSymbol('x', n, 1) + d = DiagMatrix(x) + assert d.shape == (n, n) + assert d[0, 1] == 0 + assert d[0, 0] == x[0, 0] + + a = MatrixSymbol('a', 1, 1) + d = diagonalize_vector(a) + assert isinstance(d, MatrixSymbol) + assert a == d + assert diagonalize_vector(Identity(3)) == Identity(3) + assert DiagMatrix(Identity(3)).doit() == Identity(3) + assert isinstance(DiagMatrix(Identity(3)), DiagMatrix) + + # A diagonal matrix is equal to its transpose: + assert DiagMatrix(x).T == DiagMatrix(x) + assert diagonalize_vector(x.T) == DiagMatrix(x) + + dx = DiagMatrix(x) + assert dx[0, 0] == x[0, 0] + assert dx[1, 1] == x[1, 0] + assert dx[0, 1] == 0 + assert dx[0, m] == x[0, 0]*KroneckerDelta(0, m) + + z = MatrixSymbol('z', 1, n) + dz = DiagMatrix(z) + assert dz[0, 0] == z[0, 0] + assert dz[1, 1] == z[0, 1] + assert dz[0, 1] == 0 + assert dz[0, m] == z[0, m]*KroneckerDelta(0, m) + + v = MatrixSymbol('v', 3, 1) + dv = DiagMatrix(v) + assert dv.as_explicit() == Matrix([ + [v[0, 0], 0, 0], + [0, v[1, 0], 0], + [0, 0, v[2, 0]], + ]) + + v = MatrixSymbol('v', 1, 3) + dv = DiagMatrix(v) + assert dv.as_explicit() == Matrix([ + [v[0, 0], 0, 0], + [0, v[0, 1], 0], + [0, 0, v[0, 2]], + ]) + + dv = DiagMatrix(3*v) + assert dv.args == (3*v,) + assert dv.doit() == 3*DiagMatrix(v) + assert isinstance(dv.doit(), MatMul) + + a = MatrixSymbol("a", 3, 1).as_explicit() + expr = DiagMatrix(a) + result = Matrix([ + [a[0, 0], 0, 0], + [0, a[1, 0], 0], + [0, 0, a[2, 0]], + ]) + assert expr.doit() == result + expr = DiagMatrix(a.T) + assert expr.doit() == result diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_dotproduct.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_dotproduct.py new file mode 100644 index 0000000000000000000000000000000000000000..abf8ab8e935cbd3039f25f83d3603ac444e5a7bb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_dotproduct.py @@ -0,0 +1,35 @@ +from sympy.core.expr import unchanged +from sympy.core.mul import Mul +from sympy.matrices import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.dotproduct import DotProduct +from sympy.testing.pytest import raises + + +A = Matrix(3, 1, [1, 2, 3]) +B = Matrix(3, 1, [1, 3, 5]) +C = Matrix(4, 1, [1, 2, 4, 5]) +D = Matrix(2, 2, [1, 2, 3, 4]) + +def test_docproduct(): + assert DotProduct(A, B).doit() == 22 + assert DotProduct(A.T, B).doit() == 22 + assert DotProduct(A, B.T).doit() == 22 + assert DotProduct(A.T, B.T).doit() == 22 + + raises(TypeError, lambda: DotProduct(1, A)) + raises(TypeError, lambda: DotProduct(A, 1)) + raises(TypeError, lambda: DotProduct(A, D)) + raises(TypeError, lambda: DotProduct(D, A)) + + raises(TypeError, lambda: DotProduct(B, C).doit()) + +def test_dotproduct_symbolic(): + A = MatrixSymbol('A', 3, 1) + B = MatrixSymbol('B', 3, 1) + + dot = DotProduct(A, B) + assert dot.is_scalar == True + assert unchanged(Mul, 2, dot) + # XXX Fix forced evaluation for arithmetics with matrix expressions + assert dot * A == (A[0, 0]*B[0, 0] + A[1, 0]*B[1, 0] + A[2, 0]*B[2, 0])*A diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_factorizations.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_factorizations.py new file mode 100644 index 0000000000000000000000000000000000000000..a0319acabbb7409dfa5c24ceca39e25ff0240618 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_factorizations.py @@ -0,0 +1,29 @@ +from sympy.matrices.expressions.factorizations import lu, LofCholesky, qr, svd +from sympy.assumptions.ask import (Q, ask) +from sympy.core.symbol import Symbol +from sympy.matrices.expressions.matexpr import MatrixSymbol + +n = Symbol('n') +X = MatrixSymbol('X', n, n) + +def test_LU(): + L, U = lu(X) + assert L.shape == U.shape == X.shape + assert ask(Q.lower_triangular(L)) + assert ask(Q.upper_triangular(U)) + +def test_Cholesky(): + LofCholesky(X) + +def test_QR(): + Q_, R = qr(X) + assert Q_.shape == R.shape == X.shape + assert ask(Q.orthogonal(Q_)) + assert ask(Q.upper_triangular(R)) + +def test_svd(): + U, S, V = svd(X) + assert U.shape == S.shape == V.shape == X.shape + assert ask(Q.orthogonal(U)) + assert ask(Q.orthogonal(V)) + assert ask(Q.diagonal(S)) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_fourier.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_fourier.py new file mode 100644 index 0000000000000000000000000000000000000000..0230c8a0957ed28fb0a5cc1e9ee77ecae797265b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_fourier.py @@ -0,0 +1,44 @@ +from sympy.assumptions.ask import (Q, ask) +from sympy.core.numbers import (I, Rational) +from sympy.core.singleton import S +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.simplify.simplify import simplify +from sympy.core.symbol import symbols +from sympy.matrices.expressions.fourier import DFT, IDFT +from sympy.matrices import det, Matrix, Identity +from sympy.testing.pytest import raises + + +def test_dft_creation(): + assert DFT(2) + assert DFT(0) + raises(ValueError, lambda: DFT(-1)) + raises(ValueError, lambda: DFT(2.0)) + raises(ValueError, lambda: DFT(2 + 1j)) + + n = symbols('n') + assert DFT(n) + n = symbols('n', integer=False) + raises(ValueError, lambda: DFT(n)) + n = symbols('n', negative=True) + raises(ValueError, lambda: DFT(n)) + + +def test_dft(): + n, i, j = symbols('n i j') + assert DFT(4).shape == (4, 4) + assert ask(Q.unitary(DFT(4))) + assert Abs(simplify(det(Matrix(DFT(4))))) == 1 + assert DFT(n)*IDFT(n) == Identity(n) + assert DFT(n)[i, j] == exp(-2*S.Pi*I/n)**(i*j) / sqrt(n) + + +def test_dft2(): + assert DFT(1).as_explicit() == Matrix([[1]]) + assert DFT(2).as_explicit() == 1/sqrt(2)*Matrix([[1,1],[1,-1]]) + assert DFT(4).as_explicit() == Matrix([[S.Half, S.Half, S.Half, S.Half], + [S.Half, -I/2, Rational(-1,2), I/2], + [S.Half, Rational(-1,2), S.Half, Rational(-1,2)], + [S.Half, I/2, Rational(-1,2), -I/2]]) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_funcmatrix.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_funcmatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..e4850fe5c739b9390fac6afa10757b5babf821c6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_funcmatrix.py @@ -0,0 +1,54 @@ +from sympy.core import symbols, Lambda +from sympy.core.sympify import SympifyError +from sympy.functions import KroneckerDelta +from sympy.matrices import Matrix +from sympy.matrices.expressions import FunctionMatrix, MatrixExpr, Identity +from sympy.testing.pytest import raises + + +def test_funcmatrix_creation(): + i, j, k = symbols('i j k') + assert FunctionMatrix(2, 2, Lambda((i, j), 0)) + assert FunctionMatrix(0, 0, Lambda((i, j), 0)) + + raises(ValueError, lambda: FunctionMatrix(-1, 0, Lambda((i, j), 0))) + raises(ValueError, lambda: FunctionMatrix(2.0, 0, Lambda((i, j), 0))) + raises(ValueError, lambda: FunctionMatrix(2j, 0, Lambda((i, j), 0))) + raises(ValueError, lambda: FunctionMatrix(0, -1, Lambda((i, j), 0))) + raises(ValueError, lambda: FunctionMatrix(0, 2.0, Lambda((i, j), 0))) + raises(ValueError, lambda: FunctionMatrix(0, 2j, Lambda((i, j), 0))) + + raises(ValueError, lambda: FunctionMatrix(2, 2, Lambda(i, 0))) + raises(SympifyError, lambda: FunctionMatrix(2, 2, lambda i, j: 0)) + raises(ValueError, lambda: FunctionMatrix(2, 2, Lambda((i,), 0))) + raises(ValueError, lambda: FunctionMatrix(2, 2, Lambda((i, j, k), 0))) + raises(ValueError, lambda: FunctionMatrix(2, 2, i+j)) + assert FunctionMatrix(2, 2, "lambda i, j: 0") == \ + FunctionMatrix(2, 2, Lambda((i, j), 0)) + + m = FunctionMatrix(2, 2, KroneckerDelta) + assert m.as_explicit() == Identity(2).as_explicit() + assert m.args[2].dummy_eq(Lambda((i, j), KroneckerDelta(i, j))) + + n = symbols('n') + assert FunctionMatrix(n, n, Lambda((i, j), 0)) + n = symbols('n', integer=False) + raises(ValueError, lambda: FunctionMatrix(n, n, Lambda((i, j), 0))) + n = symbols('n', negative=True) + raises(ValueError, lambda: FunctionMatrix(n, n, Lambda((i, j), 0))) + + +def test_funcmatrix(): + i, j = symbols('i,j') + X = FunctionMatrix(3, 3, Lambda((i, j), i - j)) + assert X[1, 1] == 0 + assert X[1, 2] == -1 + assert X.shape == (3, 3) + assert X.rows == X.cols == 3 + assert Matrix(X) == Matrix(3, 3, lambda i, j: i - j) + assert isinstance(X*X + X, MatrixExpr) + + +def test_replace_issue(): + X = FunctionMatrix(3, 3, KroneckerDelta) + assert X.replace(lambda x: True, lambda x: x) == X diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_hadamard.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_hadamard.py new file mode 100644 index 0000000000000000000000000000000000000000..800fa830a9b089103d69b372db93ebcea541d02b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_hadamard.py @@ -0,0 +1,141 @@ +from sympy.matrices.dense import Matrix, eye +from sympy.matrices.exceptions import ShapeError +from sympy.matrices.expressions.matadd import MatAdd +from sympy.matrices.expressions.special import Identity, OneMatrix, ZeroMatrix +from sympy.core import symbols +from sympy.testing.pytest import raises, warns_deprecated_sympy + +from sympy.matrices import MatrixSymbol +from sympy.matrices.expressions import (HadamardProduct, hadamard_product, HadamardPower, hadamard_power) + +n, m, k = symbols('n,m,k') +Z = MatrixSymbol('Z', n, n) +A = MatrixSymbol('A', n, m) +B = MatrixSymbol('B', n, m) +C = MatrixSymbol('C', m, k) + + +def test_HadamardProduct(): + assert HadamardProduct(A, B, A).shape == A.shape + + raises(TypeError, lambda: HadamardProduct(A, n)) + raises(TypeError, lambda: HadamardProduct(A, 1)) + + assert HadamardProduct(A, 2*B, -A)[1, 1] == \ + -2 * A[1, 1] * B[1, 1] * A[1, 1] + + mix = HadamardProduct(Z*A, B)*C + assert mix.shape == (n, k) + + assert set(HadamardProduct(A, B, A).T.args) == {A.T, A.T, B.T} + + +def test_HadamardProduct_isnt_commutative(): + assert HadamardProduct(A, B) != HadamardProduct(B, A) + + +def test_mixed_indexing(): + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + Z = MatrixSymbol('Z', 2, 2) + + assert (X*HadamardProduct(Y, Z))[0, 0] == \ + X[0, 0]*Y[0, 0]*Z[0, 0] + X[0, 1]*Y[1, 0]*Z[1, 0] + + +def test_canonicalize(): + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + with warns_deprecated_sympy(): + expr = HadamardProduct(X, check=False) + assert isinstance(expr, HadamardProduct) + expr2 = expr.doit() # unpack is called + assert isinstance(expr2, MatrixSymbol) + Z = ZeroMatrix(2, 2) + U = OneMatrix(2, 2) + assert HadamardProduct(Z, X).doit() == Z + assert HadamardProduct(U, X, X, U).doit() == HadamardPower(X, 2) + assert HadamardProduct(X, U, Y).doit() == HadamardProduct(X, Y) + assert HadamardProduct(X, Z, U, Y).doit() == Z + + +def test_hadamard(): + m, n, p = symbols('m, n, p', integer=True) + A = MatrixSymbol('A', m, n) + B = MatrixSymbol('B', m, n) + X = MatrixSymbol('X', m, m) + I = Identity(m) + + raises(TypeError, lambda: hadamard_product()) + assert hadamard_product(A) == A + assert isinstance(hadamard_product(A, B), HadamardProduct) + assert hadamard_product(A, B).doit() == hadamard_product(A, B) + assert hadamard_product(X, I) == HadamardProduct(I, X) + assert isinstance(hadamard_product(X, I), HadamardProduct) + + a = MatrixSymbol("a", k, 1) + expr = MatAdd(ZeroMatrix(k, 1), OneMatrix(k, 1)) + expr = HadamardProduct(expr, a) + assert expr.doit() == a + + raises(ValueError, lambda: HadamardProduct()) + + +def test_hadamard_product_with_explicit_mat(): + A = MatrixSymbol("A", 3, 3).as_explicit() + B = MatrixSymbol("B", 3, 3).as_explicit() + X = MatrixSymbol("X", 3, 3) + expr = hadamard_product(A, B) + ret = Matrix([i*j for i, j in zip(A, B)]).reshape(3, 3) + assert expr == ret + expr = hadamard_product(A, X, B) + assert expr == HadamardProduct(ret, X) + expr = hadamard_product(eye(3), A) + assert expr == Matrix([[A[0, 0], 0, 0], [0, A[1, 1], 0], [0, 0, A[2, 2]]]) + expr = hadamard_product(eye(3), eye(3)) + assert expr == eye(3) + + +def test_hadamard_power(): + m, n, p = symbols('m, n, p', integer=True) + A = MatrixSymbol('A', m, n) + + assert hadamard_power(A, 1) == A + assert isinstance(hadamard_power(A, 2), HadamardPower) + assert hadamard_power(A, n).T == hadamard_power(A.T, n) + assert hadamard_power(A, n)[0, 0] == A[0, 0]**n + assert hadamard_power(m, n) == m**n + raises(ValueError, lambda: hadamard_power(A, A)) + + +def test_hadamard_power_explicit(): + A = MatrixSymbol('A', 2, 2) + B = MatrixSymbol('B', 2, 2) + a, b = symbols('a b') + + assert HadamardPower(a, b) == a**b + + assert HadamardPower(a, B).as_explicit() == \ + Matrix([ + [a**B[0, 0], a**B[0, 1]], + [a**B[1, 0], a**B[1, 1]]]) + + assert HadamardPower(A, b).as_explicit() == \ + Matrix([ + [A[0, 0]**b, A[0, 1]**b], + [A[1, 0]**b, A[1, 1]**b]]) + + assert HadamardPower(A, B).as_explicit() == \ + Matrix([ + [A[0, 0]**B[0, 0], A[0, 1]**B[0, 1]], + [A[1, 0]**B[1, 0], A[1, 1]**B[1, 1]]]) + + +def test_shape_error(): + A = MatrixSymbol('A', 2, 3) + B = MatrixSymbol('B', 3, 3) + raises(ShapeError, lambda: HadamardProduct(A, B)) + raises(ShapeError, lambda: HadamardPower(A, B)) + A = MatrixSymbol('A', 3, 2) + raises(ShapeError, lambda: HadamardProduct(A, B)) + raises(ShapeError, lambda: HadamardPower(A, B)) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_indexing.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_indexing.py new file mode 100644 index 0000000000000000000000000000000000000000..500761f248eef5f627c2a7344a6817aca0b8a802 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_indexing.py @@ -0,0 +1,299 @@ +from sympy.concrete.summations import Sum +from sympy.core.symbol import symbols, Symbol, Dummy +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.matrices.dense import eye +from sympy.matrices.expressions.blockmatrix import BlockMatrix +from sympy.matrices.expressions.hadamard import HadamardPower +from sympy.matrices.expressions.matexpr import (MatrixSymbol, + MatrixExpr, MatrixElement) +from sympy.matrices.expressions.matpow import MatPow +from sympy.matrices.expressions.special import (ZeroMatrix, Identity, + OneMatrix) +from sympy.matrices.expressions.trace import Trace, trace +from sympy.matrices.immutable import ImmutableMatrix +from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct +from sympy.testing.pytest import XFAIL, raises + +k, l, m, n = symbols('k l m n', integer=True) +i, j = symbols('i j', integer=True) + +W = MatrixSymbol('W', k, l) +X = MatrixSymbol('X', l, m) +Y = MatrixSymbol('Y', l, m) +Z = MatrixSymbol('Z', m, n) + +X1 = MatrixSymbol('X1', m, m) +X2 = MatrixSymbol('X2', m, m) +X3 = MatrixSymbol('X3', m, m) +X4 = MatrixSymbol('X4', m, m) + +A = MatrixSymbol('A', 2, 2) +B = MatrixSymbol('B', 2, 2) +x = MatrixSymbol('x', 1, 2) +y = MatrixSymbol('x', 2, 1) + + +def test_symbolic_indexing(): + x12 = X[1, 2] + assert all(s in str(x12) for s in ['1', '2', X.name]) + # We don't care about the exact form of this. We do want to make sure + # that all of these features are present + + +def test_add_index(): + assert (X + Y)[i, j] == X[i, j] + Y[i, j] + + +def test_mul_index(): + assert (A*y)[0, 0] == A[0, 0]*y[0, 0] + A[0, 1]*y[1, 0] + assert (A*B).as_mutable() == (A.as_mutable() * B.as_mutable()) + X = MatrixSymbol('X', n, m) + Y = MatrixSymbol('Y', m, k) + + result = (X*Y)[4,2] + expected = Sum(X[4, i]*Y[i, 2], (i, 0, m - 1)) + assert result.args[0].dummy_eq(expected.args[0], i) + assert result.args[1][1:] == expected.args[1][1:] + + +def test_pow_index(): + Q = MatPow(A, 2) + assert Q[0, 0] == A[0, 0]**2 + A[0, 1]*A[1, 0] + n = symbols("n") + Q2 = A**n + assert Q2[0, 0] == 2*( + -sqrt((A[0, 0] + A[1, 1])**2 - 4*A[0, 0]*A[1, 1] + + 4*A[0, 1]*A[1, 0])/2 + A[0, 0]/2 + A[1, 1]/2 + )**n * \ + A[0, 1]*A[1, 0]/( + (sqrt(A[0, 0]**2 - 2*A[0, 0]*A[1, 1] + 4*A[0, 1]*A[1, 0] + + A[1, 1]**2) + A[0, 0] - A[1, 1])* + sqrt(A[0, 0]**2 - 2*A[0, 0]*A[1, 1] + 4*A[0, 1]*A[1, 0] + A[1, 1]**2) + ) - 2*( + sqrt((A[0, 0] + A[1, 1])**2 - 4*A[0, 0]*A[1, 1] + + 4*A[0, 1]*A[1, 0])/2 + A[0, 0]/2 + A[1, 1]/2 + )**n * A[0, 1]*A[1, 0]/( + (-sqrt(A[0, 0]**2 - 2*A[0, 0]*A[1, 1] + 4*A[0, 1]*A[1, 0] + + A[1, 1]**2) + A[0, 0] - A[1, 1])* + sqrt(A[0, 0]**2 - 2*A[0, 0]*A[1, 1] + 4*A[0, 1]*A[1, 0] + A[1, 1]**2) + ) + + +def test_transpose_index(): + assert X.T[i, j] == X[j, i] + + +def test_Identity_index(): + I = Identity(3) + assert I[0, 0] == I[1, 1] == I[2, 2] == 1 + assert I[1, 0] == I[0, 1] == I[2, 1] == 0 + assert I[i, 0].delta_range == (0, 2) + raises(IndexError, lambda: I[3, 3]) + + +def test_block_index(): + I = Identity(3) + Z = ZeroMatrix(3, 3) + B = BlockMatrix([[I, I], [I, I]]) + e3 = ImmutableMatrix(eye(3)) + BB = BlockMatrix([[e3, e3], [e3, e3]]) + assert B[0, 0] == B[3, 0] == B[0, 3] == B[3, 3] == 1 + assert B[4, 3] == B[5, 1] == 0 + + BB = BlockMatrix([[e3, e3], [e3, e3]]) + assert B.as_explicit() == BB.as_explicit() + + BI = BlockMatrix([[I, Z], [Z, I]]) + + assert BI.as_explicit().equals(eye(6)) + + +def test_block_index_symbolic(): + # Note that these matrices may be zero-sized and indices may be negative, which causes + # all naive simplifications given in the comments to be invalid + A1 = MatrixSymbol('A1', n, k) + A2 = MatrixSymbol('A2', n, l) + A3 = MatrixSymbol('A3', m, k) + A4 = MatrixSymbol('A4', m, l) + A = BlockMatrix([[A1, A2], [A3, A4]]) + assert A[0, 0] == MatrixElement(A, 0, 0) # Cannot be A1[0, 0] + assert A[n - 1, k - 1] == A1[n - 1, k - 1] + assert A[n, k] == A4[0, 0] + assert A[n + m - 1, 0] == MatrixElement(A, n + m - 1, 0) # Cannot be A3[m - 1, 0] + assert A[0, k + l - 1] == MatrixElement(A, 0, k + l - 1) # Cannot be A2[0, l - 1] + assert A[n + m - 1, k + l - 1] == MatrixElement(A, n + m - 1, k + l - 1) # Cannot be A4[m - 1, l - 1] + assert A[i, j] == MatrixElement(A, i, j) + assert A[n + i, k + j] == MatrixElement(A, n + i, k + j) # Cannot be A4[i, j] + assert A[n - i - 1, k - j - 1] == MatrixElement(A, n - i - 1, k - j - 1) # Cannot be A1[n - i - 1, k - j - 1] + + +def test_block_index_symbolic_nonzero(): + # All invalid simplifications from test_block_index_symbolic() that become valid if all + # matrices have nonzero size and all indices are nonnegative + k, l, m, n = symbols('k l m n', integer=True, positive=True) + i, j = symbols('i j', integer=True, nonnegative=True) + A1 = MatrixSymbol('A1', n, k) + A2 = MatrixSymbol('A2', n, l) + A3 = MatrixSymbol('A3', m, k) + A4 = MatrixSymbol('A4', m, l) + A = BlockMatrix([[A1, A2], [A3, A4]]) + assert A[0, 0] == A1[0, 0] + assert A[n + m - 1, 0] == A3[m - 1, 0] + assert A[0, k + l - 1] == A2[0, l - 1] + assert A[n + m - 1, k + l - 1] == A4[m - 1, l - 1] + assert A[i, j] == MatrixElement(A, i, j) + assert A[n + i, k + j] == A4[i, j] + assert A[n - i - 1, k - j - 1] == A1[n - i - 1, k - j - 1] + assert A[2 * n, 2 * k] == A4[n, k] + + +def test_block_index_large(): + n, m, k = symbols('n m k', integer=True, positive=True) + i = symbols('i', integer=True, nonnegative=True) + A1 = MatrixSymbol('A1', n, n) + A2 = MatrixSymbol('A2', n, m) + A3 = MatrixSymbol('A3', n, k) + A4 = MatrixSymbol('A4', m, n) + A5 = MatrixSymbol('A5', m, m) + A6 = MatrixSymbol('A6', m, k) + A7 = MatrixSymbol('A7', k, n) + A8 = MatrixSymbol('A8', k, m) + A9 = MatrixSymbol('A9', k, k) + A = BlockMatrix([[A1, A2, A3], [A4, A5, A6], [A7, A8, A9]]) + assert A[n + i, n + i] == MatrixElement(A, n + i, n + i) + + +@XFAIL +def test_block_index_symbolic_fail(): + # To make this work, symbolic matrix dimensions would need to be somehow assumed nonnegative + # even if the symbols aren't specified as such. Then 2 * n < n would correctly evaluate to + # False in BlockMatrix._entry() + A1 = MatrixSymbol('A1', n, 1) + A2 = MatrixSymbol('A2', m, 1) + A = BlockMatrix([[A1], [A2]]) + assert A[2 * n, 0] == A2[n, 0] + + +def test_slicing(): + A.as_explicit()[0, :] # does not raise an error + + +def test_errors(): + raises(IndexError, lambda: Identity(2)[1, 2, 3, 4, 5]) + raises(IndexError, lambda: Identity(2)[[1, 2, 3, 4, 5]]) + + +def test_matrix_expression_to_indices(): + i, j = symbols("i, j") + i1, i2, i3 = symbols("i_1:4") + + def replace_dummies(expr): + repl = {i: Symbol(i.name) for i in expr.atoms(Dummy)} + return expr.xreplace(repl) + + expr = W*X*Z + assert replace_dummies(expr._entry(i, j)) == \ + Sum(W[i, i1]*X[i1, i2]*Z[i2, j], (i1, 0, l-1), (i2, 0, m-1)) + assert MatrixExpr.from_index_summation(expr._entry(i, j)) == expr + + expr = Z.T*X.T*W.T + assert replace_dummies(expr._entry(i, j)) == \ + Sum(W[j, i2]*X[i2, i1]*Z[i1, i], (i1, 0, m-1), (i2, 0, l-1)) + assert MatrixExpr.from_index_summation(expr._entry(i, j), i) == expr + + expr = W*X*Z + W*Y*Z + assert replace_dummies(expr._entry(i, j)) == \ + Sum(W[i, i1]*X[i1, i2]*Z[i2, j], (i1, 0, l-1), (i2, 0, m-1)) +\ + Sum(W[i, i1]*Y[i1, i2]*Z[i2, j], (i1, 0, l-1), (i2, 0, m-1)) + assert MatrixExpr.from_index_summation(expr._entry(i, j)) == expr + + expr = 2*W*X*Z + 3*W*Y*Z + assert replace_dummies(expr._entry(i, j)) == \ + 2*Sum(W[i, i1]*X[i1, i2]*Z[i2, j], (i1, 0, l-1), (i2, 0, m-1)) +\ + 3*Sum(W[i, i1]*Y[i1, i2]*Z[i2, j], (i1, 0, l-1), (i2, 0, m-1)) + assert MatrixExpr.from_index_summation(expr._entry(i, j)) == expr + + expr = W*(X + Y)*Z + assert replace_dummies(expr._entry(i, j)) == \ + Sum(W[i, i1]*(X[i1, i2] + Y[i1, i2])*Z[i2, j], (i1, 0, l-1), (i2, 0, m-1)) + assert MatrixExpr.from_index_summation(expr._entry(i, j)) == expr + + expr = A*B**2*A + #assert replace_dummies(expr._entry(i, j)) == \ + # Sum(A[i, i1]*B[i1, i2]*B[i2, i3]*A[i3, j], (i1, 0, 1), (i2, 0, 1), (i3, 0, 1)) + + # Check that different dummies are used in sub-multiplications: + expr = (X1*X2 + X2*X1)*X3 + assert replace_dummies(expr._entry(i, j)) == \ + Sum((Sum(X1[i, i2] * X2[i2, i1], (i2, 0, m - 1)) + Sum(X1[i3, i1] * X2[i, i3], (i3, 0, m - 1))) * X3[ + i1, j], (i1, 0, m - 1)) + + +def test_matrix_expression_from_index_summation(): + from sympy.abc import a,b,c,d + A = MatrixSymbol("A", k, k) + B = MatrixSymbol("B", k, k) + C = MatrixSymbol("C", k, k) + w1 = MatrixSymbol("w1", k, 1) + + i0, i1, i2, i3, i4 = symbols("i0:5", cls=Dummy) + + expr = Sum(W[a,b]*X[b,c]*Z[c,d], (b, 0, l-1), (c, 0, m-1)) + assert MatrixExpr.from_index_summation(expr, a) == W*X*Z + expr = Sum(W.T[b,a]*X[b,c]*Z[c,d], (b, 0, l-1), (c, 0, m-1)) + assert MatrixExpr.from_index_summation(expr, a) == W*X*Z + expr = Sum(A[b, a]*B[b, c]*C[c, d], (b, 0, k-1), (c, 0, k-1)) + assert MatrixSymbol.from_index_summation(expr, a) == A.T*B*C + expr = Sum(A[b, a]*B[c, b]*C[c, d], (b, 0, k-1), (c, 0, k-1)) + assert MatrixSymbol.from_index_summation(expr, a) == A.T*B.T*C + expr = Sum(C[c, d]*A[b, a]*B[c, b], (b, 0, k-1), (c, 0, k-1)) + assert MatrixSymbol.from_index_summation(expr, a) == A.T*B.T*C + expr = Sum(A[a, b] + B[a, b], (a, 0, k-1), (b, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, a) == OneMatrix(1, k)*A*OneMatrix(k, 1) + OneMatrix(1, k)*B*OneMatrix(k, 1) + expr = Sum(A[a, b]**2, (a, 0, k - 1), (b, 0, k - 1)) + assert MatrixExpr.from_index_summation(expr, a) == Trace(A * A.T) + expr = Sum(A[a, b]**3, (a, 0, k - 1), (b, 0, k - 1)) + assert MatrixExpr.from_index_summation(expr, a) == Trace(HadamardPower(A.T, 2) * A) + expr = Sum((A[a, b] + B[a, b])*C[b, c], (b, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, a) == (A+B)*C + expr = Sum((A[a, b] + B[b, a])*C[b, c], (b, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, a) == (A+B.T)*C + expr = Sum(A[a, b]*A[b, c]*A[c, d], (b, 0, k-1), (c, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, a) == A**3 + expr = Sum(A[a, b]*A[b, c]*B[c, d], (b, 0, k-1), (c, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, a) == A**2*B + + # Parse the trace of a matrix: + + expr = Sum(A[a, a], (a, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, None) == trace(A) + expr = Sum(A[a, a]*B[b, c]*C[c, d], (a, 0, k-1), (c, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, b) == trace(A)*B*C + + # Check wrong sum ranges (should raise an exception): + + ## Case 1: 0 to m instead of 0 to m-1 + expr = Sum(W[a,b]*X[b,c]*Z[c,d], (b, 0, l-1), (c, 0, m)) + raises(ValueError, lambda: MatrixExpr.from_index_summation(expr, a)) + ## Case 2: 1 to m-1 instead of 0 to m-1 + expr = Sum(W[a,b]*X[b,c]*Z[c,d], (b, 0, l-1), (c, 1, m-1)) + raises(ValueError, lambda: MatrixExpr.from_index_summation(expr, a)) + + # Parse nested sums: + expr = Sum(A[a, b]*Sum(B[b, c]*C[c, d], (c, 0, k-1)), (b, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, a) == A*B*C + + # Test Kronecker delta: + expr = Sum(A[a, b]*KroneckerDelta(b, c)*B[c, d], (b, 0, k-1), (c, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, a) == A*B + + expr = Sum(KroneckerDelta(i1, m)*KroneckerDelta(i2, n)*A[i, i1]*A[j, i2], (i1, 0, k-1), (i2, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, m) == ArrayTensorProduct(A.T, A) + + # Test numbered indices: + expr = Sum(A[i1, i2]*w1[i2, 0], (i2, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, i1) == MatrixElement(A*w1, i1, 0) + + expr = Sum(A[i1, i2]*B[i2, 0], (i2, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, i1) == MatrixElement(A*B, i1, 0) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_inverse.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_inverse.py new file mode 100644 index 0000000000000000000000000000000000000000..4bcc7d4de2b2bee4c4922bda8bc48a52aa205961 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_inverse.py @@ -0,0 +1,69 @@ +from sympy.core import symbols, S +from sympy.matrices.expressions import MatrixSymbol, Inverse, MatPow, ZeroMatrix, OneMatrix +from sympy.matrices.exceptions import NonInvertibleMatrixError, NonSquareMatrixError +from sympy.matrices import eye, Identity +from sympy.testing.pytest import raises +from sympy.assumptions.ask import Q +from sympy.assumptions.refine import refine + +n, m, l = symbols('n m l', integer=True) +A = MatrixSymbol('A', n, m) +B = MatrixSymbol('B', m, l) +C = MatrixSymbol('C', n, n) +D = MatrixSymbol('D', n, n) +E = MatrixSymbol('E', m, n) + + +def test_inverse(): + assert Inverse(C).args == (C, S.NegativeOne) + assert Inverse(C).shape == (n, n) + assert Inverse(A*E).shape == (n, n) + assert Inverse(E*A).shape == (m, m) + assert Inverse(C).inverse() == C + assert Inverse(Inverse(C)).doit() == C + assert isinstance(Inverse(Inverse(C)), Inverse) + + assert Inverse(*Inverse(E*A).args) == Inverse(E*A) + + assert C.inverse().inverse() == C + + assert C.inverse()*C == Identity(C.rows) + + assert Identity(n).inverse() == Identity(n) + assert (3*Identity(n)).inverse() == Identity(n)/3 + + # Simplifies Muls if possible (i.e. submatrices are square) + assert (C*D).inverse() == D.I*C.I + # But still works when not possible + assert isinstance((A*E).inverse(), Inverse) + assert Inverse(C*D).doit(inv_expand=False) == Inverse(C*D) + + assert Inverse(eye(3)).doit() == eye(3) + assert Inverse(eye(3)).doit(deep=False) == eye(3) + + assert OneMatrix(1, 1).I == Identity(1) + assert isinstance(OneMatrix(n, n).I, Inverse) + +def test_inverse_non_invertible(): + raises(NonInvertibleMatrixError, lambda: ZeroMatrix(n, n).I) + raises(NonInvertibleMatrixError, lambda: OneMatrix(2, 2).I) + +def test_refine(): + assert refine(C.I, Q.orthogonal(C)) == C.T + + +def test_inverse_matpow_canonicalization(): + A = MatrixSymbol('A', 3, 3) + assert Inverse(MatPow(A, 3)).doit() == MatPow(Inverse(A), 3).doit() + + +def test_nonsquare_error(): + A = MatrixSymbol('A', 3, 4) + raises(NonSquareMatrixError, lambda: Inverse(A)) + + +def test_adjoint_trnaspose_conjugate(): + A = MatrixSymbol('A', n, n) + assert A.transpose().inverse() == A.inverse().transpose() + assert A.conjugate().inverse() == A.inverse().conjugate() + assert A.adjoint().inverse() == A.inverse().adjoint() diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_kronecker.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_kronecker.py new file mode 100644 index 0000000000000000000000000000000000000000..b4444716a76a52e3638dd7a36238a9f459179083 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_kronecker.py @@ -0,0 +1,150 @@ +from sympy.core.mod import Mod +from sympy.core.numbers import I +from sympy.core.symbol import symbols +from sympy.functions.elementary.integers import floor +from sympy.matrices.dense import (Matrix, eye) +from sympy.matrices import MatrixSymbol, Identity +from sympy.matrices.expressions import det, trace + +from sympy.matrices.expressions.kronecker import (KroneckerProduct, + kronecker_product, + combine_kronecker) + + +mat1 = Matrix([[1, 2 * I], [1 + I, 3]]) +mat2 = Matrix([[2 * I, 3], [4 * I, 2]]) + +i, j, k, n, m, o, p, x = symbols('i,j,k,n,m,o,p,x') +Z = MatrixSymbol('Z', n, n) +W = MatrixSymbol('W', m, m) +A = MatrixSymbol('A', n, m) +B = MatrixSymbol('B', n, m) +C = MatrixSymbol('C', m, k) + + +def test_KroneckerProduct(): + assert isinstance(KroneckerProduct(A, B), KroneckerProduct) + assert KroneckerProduct(A, B).subs(A, C) == KroneckerProduct(C, B) + assert KroneckerProduct(A, C).shape == (n*m, m*k) + assert (KroneckerProduct(A, C) + KroneckerProduct(-A, C)).is_ZeroMatrix + assert (KroneckerProduct(W, Z) * KroneckerProduct(W.I, Z.I)).is_Identity + + +def test_KroneckerProduct_identity(): + assert KroneckerProduct(Identity(m), Identity(n)) == Identity(m*n) + assert KroneckerProduct(eye(2), eye(3)) == eye(6) + + +def test_KroneckerProduct_explicit(): + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + kp = KroneckerProduct(X, Y) + assert kp.shape == (4, 4) + assert kp.as_explicit() == Matrix( + [ + [X[0, 0]*Y[0, 0], X[0, 0]*Y[0, 1], X[0, 1]*Y[0, 0], X[0, 1]*Y[0, 1]], + [X[0, 0]*Y[1, 0], X[0, 0]*Y[1, 1], X[0, 1]*Y[1, 0], X[0, 1]*Y[1, 1]], + [X[1, 0]*Y[0, 0], X[1, 0]*Y[0, 1], X[1, 1]*Y[0, 0], X[1, 1]*Y[0, 1]], + [X[1, 0]*Y[1, 0], X[1, 0]*Y[1, 1], X[1, 1]*Y[1, 0], X[1, 1]*Y[1, 1]] + ] + ) + + +def test_tensor_product_adjoint(): + assert KroneckerProduct(I*A, B).adjoint() == \ + -I*KroneckerProduct(A.adjoint(), B.adjoint()) + assert KroneckerProduct(mat1, mat2).adjoint() == \ + kronecker_product(mat1.adjoint(), mat2.adjoint()) + + +def test_tensor_product_conjugate(): + assert KroneckerProduct(I*A, B).conjugate() == \ + -I*KroneckerProduct(A.conjugate(), B.conjugate()) + assert KroneckerProduct(mat1, mat2).conjugate() == \ + kronecker_product(mat1.conjugate(), mat2.conjugate()) + + +def test_tensor_product_transpose(): + assert KroneckerProduct(I*A, B).transpose() == \ + I*KroneckerProduct(A.transpose(), B.transpose()) + assert KroneckerProduct(mat1, mat2).transpose() == \ + kronecker_product(mat1.transpose(), mat2.transpose()) + + +def test_KroneckerProduct_is_associative(): + assert kronecker_product(A, kronecker_product( + B, C)) == kronecker_product(kronecker_product(A, B), C) + assert kronecker_product(A, kronecker_product( + B, C)) == KroneckerProduct(A, B, C) + + +def test_KroneckerProduct_is_bilinear(): + assert kronecker_product(x*A, B) == x*kronecker_product(A, B) + assert kronecker_product(A, x*B) == x*kronecker_product(A, B) + + +def test_KroneckerProduct_determinant(): + kp = kronecker_product(W, Z) + assert det(kp) == det(W)**n * det(Z)**m + + +def test_KroneckerProduct_trace(): + kp = kronecker_product(W, Z) + assert trace(kp) == trace(W)*trace(Z) + + +def test_KroneckerProduct_isnt_commutative(): + assert KroneckerProduct(A, B) != KroneckerProduct(B, A) + assert KroneckerProduct(A, B).is_commutative is False + + +def test_KroneckerProduct_extracts_commutative_part(): + assert kronecker_product(x * A, 2 * B) == x * \ + 2 * KroneckerProduct(A, B) + + +def test_KroneckerProduct_inverse(): + kp = kronecker_product(W, Z) + assert kp.inverse() == kronecker_product(W.inverse(), Z.inverse()) + + +def test_KroneckerProduct_combine_add(): + kp1 = kronecker_product(A, B) + kp2 = kronecker_product(C, W) + assert combine_kronecker(kp1*kp2) == kronecker_product(A*C, B*W) + + +def test_KroneckerProduct_combine_mul(): + X = MatrixSymbol('X', m, n) + Y = MatrixSymbol('Y', m, n) + kp1 = kronecker_product(A, X) + kp2 = kronecker_product(B, Y) + assert combine_kronecker(kp1+kp2) == kronecker_product(A+B, X+Y) + + +def test_KroneckerProduct_combine_pow(): + X = MatrixSymbol('X', n, n) + Y = MatrixSymbol('Y', n, n) + assert combine_kronecker(KroneckerProduct( + X, Y)**x) == KroneckerProduct(X**x, Y**x) + assert combine_kronecker(x * KroneckerProduct(X, Y) + ** 2) == x * KroneckerProduct(X**2, Y**2) + assert combine_kronecker( + x * (KroneckerProduct(X, Y)**2) * KroneckerProduct(A, B)) == x * KroneckerProduct(X**2 * A, Y**2 * B) + # cannot simplify because of non-square arguments to kronecker product: + assert combine_kronecker(KroneckerProduct(A, B.T) ** m) == KroneckerProduct(A, B.T) ** m + + +def test_KroneckerProduct_expand(): + X = MatrixSymbol('X', n, n) + Y = MatrixSymbol('Y', n, n) + + assert KroneckerProduct(X + Y, Y + Z).expand(kroneckerproduct=True) == \ + KroneckerProduct(X, Y) + KroneckerProduct(X, Z) + \ + KroneckerProduct(Y, Y) + KroneckerProduct(Y, Z) + +def test_KroneckerProduct_entry(): + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', o, p) + + assert KroneckerProduct(A, B)._entry(i, j) == A[Mod(floor(i/o), n), Mod(floor(j/p), m)]*B[Mod(i, o), Mod(j, p)] diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_matadd.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_matadd.py new file mode 100644 index 0000000000000000000000000000000000000000..43229ae8c2e42f0253a5f3eceefa5fffe7a99f29 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_matadd.py @@ -0,0 +1,58 @@ +from sympy.matrices.expressions import MatrixSymbol, MatAdd, MatPow, MatMul +from sympy.matrices.expressions.special import GenericZeroMatrix, ZeroMatrix +from sympy.matrices.exceptions import ShapeError +from sympy.matrices import eye, ImmutableMatrix +from sympy.core import Add, Basic, S +from sympy.core.add import add +from sympy.testing.pytest import XFAIL, raises + +X = MatrixSymbol('X', 2, 2) +Y = MatrixSymbol('Y', 2, 2) + +def test_evaluate(): + assert MatAdd(X, X, evaluate=True) == add(X, X, evaluate=True) == MatAdd(X, X).doit() + +def test_sort_key(): + assert MatAdd(Y, X).doit().args == add(Y, X).doit().args == (X, Y) + + +def test_matadd_sympify(): + assert isinstance(MatAdd(eye(1), eye(1)).args[0], Basic) + assert isinstance(add(eye(1), eye(1)).args[0], Basic) + + +def test_matadd_of_matrices(): + assert MatAdd(eye(2), 4*eye(2), eye(2)).doit() == ImmutableMatrix(6*eye(2)) + assert add(eye(2), 4*eye(2), eye(2)).doit() == ImmutableMatrix(6*eye(2)) + + +def test_doit_args(): + A = ImmutableMatrix([[1, 2], [3, 4]]) + B = ImmutableMatrix([[2, 3], [4, 5]]) + assert MatAdd(A, MatPow(B, 2)).doit() == A + B**2 + assert MatAdd(A, MatMul(A, B)).doit() == A + A*B + assert (MatAdd(A, X, MatMul(A, B), Y, MatAdd(2*A, B)).doit() == + add(A, X, MatMul(A, B), Y, add(2*A, B)).doit() == + MatAdd(3*A + A*B + B, X, Y)) + + +def test_generic_identity(): + assert MatAdd.identity == GenericZeroMatrix() + assert MatAdd.identity != S.Zero + + +def test_zero_matrix_add(): + assert Add(ZeroMatrix(2, 2), ZeroMatrix(2, 2)) == ZeroMatrix(2, 2) + +@XFAIL +def test_matrix_Add_with_scalar(): + raises(TypeError, lambda: Add(0, ZeroMatrix(2, 2))) + + +def test_shape_error(): + A = MatrixSymbol('A', 2, 3) + B = MatrixSymbol('B', 3, 3) + raises(ShapeError, lambda: MatAdd(A, B)) + + A = MatrixSymbol('A', 3, 2) + raises(ShapeError, lambda: MatAdd(A, B)) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_matexpr.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_matexpr.py new file mode 100644 index 0000000000000000000000000000000000000000..f2319e8d8097c2ad3519eab783c4665623c55b80 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_matexpr.py @@ -0,0 +1,592 @@ +from sympy.concrete.summations import Sum +from sympy.core.exprtools import gcd_terms +from sympy.core.function import (diff, expand) +from sympy.core.relational import Eq +from sympy.core.symbol import (Dummy, Symbol, Str) +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.matrices.dense import zeros +from sympy.polys.polytools import factor + +from sympy.core import (S, symbols, Add, Mul, SympifyError, Rational, + Function) +from sympy.functions import sin, cos, tan, sqrt, cbrt, exp +from sympy.simplify import simplify +from sympy.matrices import (ImmutableMatrix, Inverse, MatAdd, MatMul, + MatPow, Matrix, MatrixExpr, MatrixSymbol, + SparseMatrix, Transpose, Adjoint, MatrixSet) +from sympy.matrices.exceptions import NonSquareMatrixError +from sympy.matrices.expressions.determinant import Determinant, det +from sympy.matrices.expressions.matexpr import MatrixElement +from sympy.matrices.expressions.special import ZeroMatrix, Identity +from sympy.testing.pytest import raises, XFAIL, skip +from importlib.metadata import version + +n, m, l, k, p = symbols('n m l k p', integer=True) +x = symbols('x') +A = MatrixSymbol('A', n, m) +B = MatrixSymbol('B', m, l) +C = MatrixSymbol('C', n, n) +D = MatrixSymbol('D', n, n) +E = MatrixSymbol('E', m, n) +w = MatrixSymbol('w', n, 1) + + +def test_matrix_symbol_creation(): + assert MatrixSymbol('A', 2, 2) + assert MatrixSymbol('A', 0, 0) + raises(ValueError, lambda: MatrixSymbol('A', -1, 2)) + raises(ValueError, lambda: MatrixSymbol('A', 2.0, 2)) + raises(ValueError, lambda: MatrixSymbol('A', 2j, 2)) + raises(ValueError, lambda: MatrixSymbol('A', 2, -1)) + raises(ValueError, lambda: MatrixSymbol('A', 2, 2.0)) + raises(ValueError, lambda: MatrixSymbol('A', 2, 2j)) + + n = symbols('n') + assert MatrixSymbol('A', n, n) + n = symbols('n', integer=False) + raises(ValueError, lambda: MatrixSymbol('A', n, n)) + n = symbols('n', negative=True) + raises(ValueError, lambda: MatrixSymbol('A', n, n)) + + +def test_matexpr_properties(): + assert A.shape == (n, m) + assert (A * B).shape == (n, l) + assert A[0, 1].indices == (0, 1) + assert A[0, 0].symbol == A + assert A[0, 0].symbol.name == 'A' + + +def test_matexpr(): + assert (x*A).shape == A.shape + assert (x*A).__class__ == MatMul + assert 2*A - A - A == ZeroMatrix(*A.shape) + assert (A*B).shape == (n, l) + + +def test_matexpr_subs(): + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', m, l) + C = MatrixSymbol('C', m, l) + + assert A.subs(n, m).shape == (m, m) + assert (A*B).subs(B, C) == A*C + assert (A*B).subs(l, n).is_square + + W = MatrixSymbol("W", 3, 3) + X = MatrixSymbol("X", 2, 2) + Y = MatrixSymbol("Y", 1, 2) + Z = MatrixSymbol("Z", n, 2) + # no restrictions on Symbol replacement + assert X.subs(X, Y) == Y + # it might be better to just change the name + y = Str('y') + assert X.subs(Str("X"), y).args == (y, 2, 2) + # it's ok to introduce a wider matrix + assert X[1, 1].subs(X, W) == W[1, 1] + # but for a given MatrixExpression, only change + # name if indexing on the new shape is valid. + # Here, X is 2,2; Y is 1,2 and Y[1, 1] is out + # of range so an error is raised + raises(IndexError, lambda: X[1, 1].subs(X, Y)) + # here, [0, 1] is in range so the subs succeeds + assert X[0, 1].subs(X, Y) == Y[0, 1] + # and here the size of n will accept any index + # in the first position + assert W[2, 1].subs(W, Z) == Z[2, 1] + # but not in the second position + raises(IndexError, lambda: W[2, 2].subs(W, Z)) + # any matrix should raise if invalid + raises(IndexError, lambda: W[2, 2].subs(W, zeros(2))) + + A = SparseMatrix([[1, 2], [3, 4]]) + B = Matrix([[1, 2], [3, 4]]) + C, D = MatrixSymbol('C', 2, 2), MatrixSymbol('D', 2, 2) + + assert (C*D).subs({C: A, D: B}) == MatMul(A, B) + + +def test_addition(): + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', n, m) + + assert isinstance(A + B, MatAdd) + assert (A + B).shape == A.shape + assert isinstance(A - A + 2*B, MatMul) + + raises(TypeError, lambda: A + 1) + raises(TypeError, lambda: 5 + A) + raises(TypeError, lambda: 5 - A) + + assert A + ZeroMatrix(n, m) - A == ZeroMatrix(n, m) + raises(TypeError, lambda: ZeroMatrix(n, m) + S.Zero) + + +def test_multiplication(): + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', m, l) + C = MatrixSymbol('C', n, n) + + assert (2*A*B).shape == (n, l) + assert (A*0*B) == ZeroMatrix(n, l) + assert (2*A).shape == A.shape + + assert A * ZeroMatrix(m, m) * B == ZeroMatrix(n, l) + + assert C * Identity(n) * C.I == Identity(n) + + assert B/2 == S.Half*B + raises(NotImplementedError, lambda: 2/B) + + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', n, n) + assert Identity(n) * (A + B) == A + B + + assert A**2*A == A**3 + assert A**2*(A.I)**3 == A.I + assert A**3*(A.I)**2 == A + + +def test_MatPow(): + A = MatrixSymbol('A', n, n) + + AA = MatPow(A, 2) + assert AA.exp == 2 + assert AA.base == A + assert (A**n).exp == n + + assert A**0 == Identity(n) + assert A**1 == A + assert A**2 == AA + assert A**-1 == Inverse(A) + assert (A**-1)**-1 == A + assert (A**2)**3 == A**6 + assert A**S.Half == sqrt(A) + assert A**Rational(1, 3) == cbrt(A) + raises(NonSquareMatrixError, lambda: MatrixSymbol('B', 3, 2)**2) + + +def test_MatrixSymbol(): + n, m, t = symbols('n,m,t') + X = MatrixSymbol('X', n, m) + assert X.shape == (n, m) + raises(TypeError, lambda: MatrixSymbol('X', n, m)(t)) # issue 5855 + assert X.doit() == X + + +def test_dense_conversion(): + X = MatrixSymbol('X', 2, 2) + assert ImmutableMatrix(X) == ImmutableMatrix(2, 2, lambda i, j: X[i, j]) + assert Matrix(X) == Matrix(2, 2, lambda i, j: X[i, j]) + + +def test_free_symbols(): + assert (C*D).free_symbols == {C, D} + + +def test_zero_matmul(): + assert isinstance(S.Zero * MatrixSymbol('X', 2, 2), MatrixExpr) + + +def test_matadd_simplify(): + A = MatrixSymbol('A', 1, 1) + assert simplify(MatAdd(A, ImmutableMatrix([[sin(x)**2 + cos(x)**2]]))) == \ + MatAdd(A, Matrix([[1]])) + + +def test_matmul_simplify(): + A = MatrixSymbol('A', 1, 1) + assert simplify(MatMul(A, ImmutableMatrix([[sin(x)**2 + cos(x)**2]]))) == \ + MatMul(A, Matrix([[1]])) + + +def test_invariants(): + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', m, l) + X = MatrixSymbol('X', n, n) + objs = [Identity(n), ZeroMatrix(m, n), A, MatMul(A, B), MatAdd(A, A), + Transpose(A), Adjoint(A), Inverse(X), MatPow(X, 2), MatPow(X, -1), + MatPow(X, 0)] + for obj in objs: + assert obj == obj.__class__(*obj.args) + + +def test_matexpr_indexing(): + A = MatrixSymbol('A', n, m) + A[1, 2] + A[l, k] + A[l + 1, k + 1] + A = MatrixSymbol('A', 2, 1) + for i in range(-2, 2): + for j in range(-1, 1): + A[i, j] + + +def test_single_indexing(): + A = MatrixSymbol('A', 2, 3) + assert A[1] == A[0, 1] + assert A[int(1)] == A[0, 1] + assert A[3] == A[1, 0] + assert list(A[:2, :2]) == [A[0, 0], A[0, 1], A[1, 0], A[1, 1]] + raises(IndexError, lambda: A[6]) + raises(IndexError, lambda: A[n]) + B = MatrixSymbol('B', n, m) + raises(IndexError, lambda: B[1]) + B = MatrixSymbol('B', n, 3) + assert B[3] == B[1, 0] + + +def test_MatrixElement_commutative(): + assert A[0, 1]*A[1, 0] == A[1, 0]*A[0, 1] + + +def test_MatrixSymbol_determinant(): + A = MatrixSymbol('A', 4, 4) + assert A.as_explicit().det() == A[0, 0]*A[1, 1]*A[2, 2]*A[3, 3] - \ + A[0, 0]*A[1, 1]*A[2, 3]*A[3, 2] - A[0, 0]*A[1, 2]*A[2, 1]*A[3, 3] + \ + A[0, 0]*A[1, 2]*A[2, 3]*A[3, 1] + A[0, 0]*A[1, 3]*A[2, 1]*A[3, 2] - \ + A[0, 0]*A[1, 3]*A[2, 2]*A[3, 1] - A[0, 1]*A[1, 0]*A[2, 2]*A[3, 3] + \ + A[0, 1]*A[1, 0]*A[2, 3]*A[3, 2] + A[0, 1]*A[1, 2]*A[2, 0]*A[3, 3] - \ + A[0, 1]*A[1, 2]*A[2, 3]*A[3, 0] - A[0, 1]*A[1, 3]*A[2, 0]*A[3, 2] + \ + A[0, 1]*A[1, 3]*A[2, 2]*A[3, 0] + A[0, 2]*A[1, 0]*A[2, 1]*A[3, 3] - \ + A[0, 2]*A[1, 0]*A[2, 3]*A[3, 1] - A[0, 2]*A[1, 1]*A[2, 0]*A[3, 3] + \ + A[0, 2]*A[1, 1]*A[2, 3]*A[3, 0] + A[0, 2]*A[1, 3]*A[2, 0]*A[3, 1] - \ + A[0, 2]*A[1, 3]*A[2, 1]*A[3, 0] - A[0, 3]*A[1, 0]*A[2, 1]*A[3, 2] + \ + A[0, 3]*A[1, 0]*A[2, 2]*A[3, 1] + A[0, 3]*A[1, 1]*A[2, 0]*A[3, 2] - \ + A[0, 3]*A[1, 1]*A[2, 2]*A[3, 0] - A[0, 3]*A[1, 2]*A[2, 0]*A[3, 1] + \ + A[0, 3]*A[1, 2]*A[2, 1]*A[3, 0] + + B = MatrixSymbol('B', 4, 4) + assert Determinant(A + B).doit() == det(A + B) == (A + B).det() + + +def test_MatrixElement_diff(): + assert (A[3, 0]*A[0, 0]).diff(A[0, 0]) == A[3, 0] + + +def test_MatrixElement_doit(): + u = MatrixSymbol('u', 2, 1) + v = ImmutableMatrix([3, 5]) + assert u[0, 0].subs(u, v).doit() == v[0, 0] + + +def test_identity_powers(): + M = Identity(n) + assert MatPow(M, 3).doit() == M**3 + assert M**n == M + assert MatPow(M, 0).doit() == M**2 + assert M**-2 == M + assert MatPow(M, -2).doit() == M**0 + N = Identity(3) + assert MatPow(N, 2).doit() == N**n + assert MatPow(N, 3).doit() == N + assert MatPow(N, -2).doit() == N**4 + assert MatPow(N, 2).doit() == N**0 + + +def test_Zero_power(): + z1 = ZeroMatrix(n, n) + assert z1**4 == z1 + raises(ValueError, lambda:z1**-2) + assert z1**0 == Identity(n) + assert MatPow(z1, 2).doit() == z1**2 + raises(ValueError, lambda:MatPow(z1, -2).doit()) + z2 = ZeroMatrix(3, 3) + assert MatPow(z2, 4).doit() == z2**4 + raises(ValueError, lambda:z2**-3) + assert z2**3 == MatPow(z2, 3).doit() + assert z2**0 == Identity(3) + raises(ValueError, lambda:MatPow(z2, -1).doit()) + + +def test_matrixelement_diff(): + dexpr = diff((D*w)[k,0], w[p,0]) + + assert w[k, p].diff(w[k, p]) == 1 + assert w[k, p].diff(w[0, 0]) == KroneckerDelta(0, k, (0, n-1))*KroneckerDelta(0, p, (0, 0)) + _i_1 = Dummy("_i_1") + assert dexpr.dummy_eq(Sum(KroneckerDelta(_i_1, p, (0, n-1))*D[k, _i_1], (_i_1, 0, n - 1))) + assert dexpr.doit() == D[k, p] + + +def test_MatrixElement_with_values(): + x, y, z, w = symbols("x y z w") + M = Matrix([[x, y], [z, w]]) + i, j = symbols("i, j") + Mij = M[i, j] + assert isinstance(Mij, MatrixElement) + Ms = SparseMatrix([[2, 3], [4, 5]]) + msij = Ms[i, j] + assert isinstance(msij, MatrixElement) + for oi, oj in [(0, 0), (0, 1), (1, 0), (1, 1)]: + assert Mij.subs({i: oi, j: oj}) == M[oi, oj] + assert msij.subs({i: oi, j: oj}) == Ms[oi, oj] + A = MatrixSymbol("A", 2, 2) + assert A[0, 0].subs(A, M) == x + assert A[i, j].subs(A, M) == M[i, j] + assert M[i, j].subs(M, A) == A[i, j] + + assert isinstance(M[3*i - 2, j], MatrixElement) + assert M[3*i - 2, j].subs({i: 1, j: 0}) == M[1, 0] + assert isinstance(M[i, 0], MatrixElement) + assert M[i, 0].subs(i, 0) == M[0, 0] + assert M[0, i].subs(i, 1) == M[0, 1] + + assert M[i, j].diff(x) == Matrix([[1, 0], [0, 0]])[i, j] + + raises(ValueError, lambda: M[i, 2]) + raises(ValueError, lambda: M[i, -1]) + raises(ValueError, lambda: M[2, i]) + raises(ValueError, lambda: M[-1, i]) + + +def test_inv(): + B = MatrixSymbol('B', 3, 3) + assert B.inv() == B**-1 + + # https://github.com/sympy/sympy/issues/19162 + X = MatrixSymbol('X', 1, 1).as_explicit() + assert X.inv() == Matrix([[1/X[0, 0]]]) + + X = MatrixSymbol('X', 2, 2).as_explicit() + detX = X[0, 0]*X[1, 1] - X[0, 1]*X[1, 0] + invX = Matrix([[ X[1, 1], -X[0, 1]], + [-X[1, 0], X[0, 0]]]) / detX + assert X.inv() == invX + + +@XFAIL +def test_factor_expand(): + A = MatrixSymbol("A", n, n) + B = MatrixSymbol("B", n, n) + expr1 = (A + B)*(C + D) + expr2 = A*C + B*C + A*D + B*D + assert expr1 != expr2 + assert expand(expr1) == expr2 + assert factor(expr2) == expr1 + + expr = B**(-1)*(A**(-1)*B**(-1) - A**(-1)*C*B**(-1))**(-1)*A**(-1) + I = Identity(n) + # Ideally we get the first, but we at least don't want a wrong answer + assert factor(expr) in [I - C, B**-1*(A**-1*(I - C)*B**-1)**-1*A**-1] + +def test_numpy_conversion(): + try: + from numpy import array, array_equal + except ImportError: + skip('NumPy must be available to test creating matrices from ndarrays') + A = MatrixSymbol('A', 2, 2) + np_array = array([[MatrixElement(A, 0, 0), MatrixElement(A, 0, 1)], + [MatrixElement(A, 1, 0), MatrixElement(A, 1, 1)]]) + assert array_equal(array(A), np_array) + assert array_equal(array(A, copy=True), np_array) + if(int(version('numpy').split('.')[0]) >= 2): #run this test only if numpy is new enough that copy variable is passed properly. + raises(TypeError, lambda: array(A, copy=False)) + +def test_issue_2749(): + A = MatrixSymbol("A", 5, 2) + assert (A.T * A).I.as_explicit() == Matrix([[(A.T * A).I[0, 0], (A.T * A).I[0, 1]], \ + [(A.T * A).I[1, 0], (A.T * A).I[1, 1]]]) + + +def test_issue_2750(): + x = MatrixSymbol('x', 1, 1) + assert (x.T*x).as_explicit()**-1 == Matrix([[x[0, 0]**(-2)]]) + + +def test_issue_7842(): + A = MatrixSymbol('A', 3, 1) + B = MatrixSymbol('B', 2, 1) + assert Eq(A, B) == False + assert Eq(A[1,0], B[1, 0]).func is Eq + A = ZeroMatrix(2, 3) + B = ZeroMatrix(2, 3) + assert Eq(A, B) == True + + +def test_issue_21195(): + t = symbols('t') + x = Function('x')(t) + dx = x.diff(t) + exp1 = cos(x) + cos(x)*dx + exp2 = sin(x) + tan(x)*(dx.diff(t)) + exp3 = sin(x)*sin(t)*(dx.diff(t)).diff(t) + A = Matrix([[exp1], [exp2], [exp3]]) + B = Matrix([[exp1.diff(x)], [exp2.diff(x)], [exp3.diff(x)]]) + assert A.diff(x) == B + + +def test_issue_24859(): + A = MatrixSymbol('A', 2, 3) + B = MatrixSymbol('B', 3, 2) + J = A*B + Jinv = Matrix(J).adjugate() + u = MatrixSymbol('u', 2, 3) + Jk = Jinv.subs(A, A + x*u) + + expected = B[0, 1]*u[1, 0] + B[1, 1]*u[1, 1] + B[2, 1]*u[1, 2] + assert Jk[0, 0].diff(x) == expected + assert diff(Jk[0, 0], x).doit() == expected + + +def test_MatMul_postprocessor(): + z = zeros(2) + z1 = ZeroMatrix(2, 2) + assert Mul(0, z) == Mul(z, 0) in [z, z1] + + M = Matrix([[1, 2], [3, 4]]) + Mx = Matrix([[x, 2*x], [3*x, 4*x]]) + assert Mul(x, M) == Mul(M, x) == Mx + + A = MatrixSymbol("A", 2, 2) + assert Mul(A, M) == MatMul(A, M) + assert Mul(M, A) == MatMul(M, A) + # Scalars should be absorbed into constant matrices + a = Mul(x, M, A) + b = Mul(M, x, A) + c = Mul(M, A, x) + assert a == b == c == MatMul(Mx, A) + a = Mul(x, A, M) + b = Mul(A, x, M) + c = Mul(A, M, x) + assert a == b == c == MatMul(A, Mx) + assert Mul(M, M) == M**2 + assert Mul(A, M, M) == MatMul(A, M**2) + assert Mul(M, M, A) == MatMul(M**2, A) + assert Mul(M, A, M) == MatMul(M, A, M) + + assert Mul(A, x, M, M, x) == MatMul(A, Mx**2) + + +@XFAIL +def test_MatAdd_postprocessor_xfail(): + # This is difficult to get working because of the way that Add processes + # its args. + z = zeros(2) + assert Add(z, S.NaN) == Add(S.NaN, z) + + +def test_MatAdd_postprocessor(): + # Some of these are nonsensical, but we do not raise errors for Add + # because that breaks algorithms that want to replace matrices with dummy + # symbols. + + z = zeros(2) + + assert Add(0, z) == Add(z, 0) == z + + a = Add(S.Infinity, z) + assert a == Add(z, S.Infinity) + assert isinstance(a, Add) + assert a.args == (S.Infinity, z) + + a = Add(S.ComplexInfinity, z) + assert a == Add(z, S.ComplexInfinity) + assert isinstance(a, Add) + assert a.args == (S.ComplexInfinity, z) + + a = Add(z, S.NaN) + # assert a == Add(S.NaN, z) # See the XFAIL above + assert isinstance(a, Add) + assert a.args == (S.NaN, z) + + M = Matrix([[1, 2], [3, 4]]) + a = Add(x, M) + assert a == Add(M, x) + assert isinstance(a, Add) + assert a.args == (x, M) + + A = MatrixSymbol("A", 2, 2) + assert Add(A, M) == Add(M, A) == A + M + + # Scalars should be absorbed into constant matrices (producing an error) + a = Add(x, M, A) + assert a == Add(M, x, A) == Add(M, A, x) == Add(x, A, M) == Add(A, x, M) == Add(A, M, x) + assert isinstance(a, Add) + assert a.args == (x, A + M) + + assert Add(M, M) == 2*M + assert Add(M, A, M) == Add(M, M, A) == Add(A, M, M) == A + 2*M + + a = Add(A, x, M, M, x) + assert isinstance(a, Add) + assert a.args == (2*x, A + 2*M) + + +def test_simplify_matrix_expressions(): + # Various simplification functions + assert type(gcd_terms(C*D + D*C)) == MatAdd + a = gcd_terms(2*C*D + 4*D*C) + assert type(a) == MatAdd + assert a.args == (2*C*D, 4*D*C) + + +def test_exp(): + A = MatrixSymbol('A', 2, 2) + B = MatrixSymbol('B', 2, 2) + expr1 = exp(A)*exp(B) + expr2 = exp(B)*exp(A) + assert expr1 != expr2 + assert expr1 - expr2 != 0 + assert not isinstance(expr1, exp) + assert not isinstance(expr2, exp) + + +def test_invalid_args(): + raises(SympifyError, lambda: MatrixSymbol(1, 2, 'A')) + + +def test_matrixsymbol_from_symbol(): + # The label should be preserved during doit and subs + A_label = Symbol('A', complex=True) + A = MatrixSymbol(A_label, 2, 2) + + A_1 = A.doit() + A_2 = A.subs(2, 3) + assert A_1.args == A.args + assert A_2.args[0] == A.args[0] + + +def test_as_explicit(): + Z = MatrixSymbol('Z', 2, 3) + assert Z.as_explicit() == ImmutableMatrix([ + [Z[0, 0], Z[0, 1], Z[0, 2]], + [Z[1, 0], Z[1, 1], Z[1, 2]], + ]) + raises(ValueError, lambda: A.as_explicit()) + + +def test_MatrixSet(): + M = MatrixSet(2, 2, set=S.Reals) + assert M.shape == (2, 2) + assert M.set == S.Reals + X = Matrix([[1, 2], [3, 4]]) + assert X in M + X = ZeroMatrix(2, 2) + assert X in M + raises(TypeError, lambda: A in M) + raises(TypeError, lambda: 1 in M) + M = MatrixSet(n, m, set=S.Reals) + assert A in M + raises(TypeError, lambda: C in M) + raises(TypeError, lambda: X in M) + M = MatrixSet(2, 2, set={1, 2, 3}) + X = Matrix([[1, 2], [3, 4]]) + Y = Matrix([[1, 2]]) + assert (X in M) == S.false + assert (Y in M) == S.false + raises(ValueError, lambda: MatrixSet(2, -2, S.Reals)) + raises(ValueError, lambda: MatrixSet(2.4, -1, S.Reals)) + raises(TypeError, lambda: MatrixSet(2, 2, (1, 2, 3))) + + +def test_matrixsymbol_solving(): + A = MatrixSymbol('A', 2, 2) + B = MatrixSymbol('B', 2, 2) + Z = ZeroMatrix(2, 2) + assert -(-A + B) - A + B == Z + assert (-(-A + B) - A + B).simplify() == Z + assert (-(-A + B) - A + B).expand() == Z + assert (-(-A + B) - A + B - Z).simplify() == Z + assert (-(-A + B) - A + B - Z).expand() == Z + assert (A*(A + B) + B*(A.T + B.T)).expand() == A**2 + A*B + B*A.T + B*B.T diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_matmul.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..813926e2c83e27716f4f894ebebd09b2a576f046 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_matmul.py @@ -0,0 +1,193 @@ +from sympy.core import I, symbols, Basic, Mul, S +from sympy.core.mul import mul +from sympy.functions import adjoint, transpose +from sympy.matrices.exceptions import ShapeError +from sympy.matrices import (Identity, Inverse, Matrix, MatrixSymbol, ZeroMatrix, + eye, ImmutableMatrix) +from sympy.matrices.expressions import Adjoint, Transpose, det, MatPow +from sympy.matrices.expressions.special import GenericIdentity +from sympy.matrices.expressions.matmul import (factor_in_front, remove_ids, + MatMul, combine_powers, any_zeros, unpack, only_squares) +from sympy.strategies import null_safe +from sympy.assumptions.ask import Q +from sympy.assumptions.refine import refine +from sympy.core.symbol import Symbol + +from sympy.testing.pytest import XFAIL, raises + +n, m, l, k = symbols('n m l k', integer=True) +x = symbols('x') +A = MatrixSymbol('A', n, m) +B = MatrixSymbol('B', m, l) +C = MatrixSymbol('C', n, n) +D = MatrixSymbol('D', n, n) +E = MatrixSymbol('E', m, n) + +def test_evaluate(): + assert MatMul(C, C, evaluate=True) == MatMul(C, C).doit() + +def test_adjoint(): + assert adjoint(A*B) == Adjoint(B)*Adjoint(A) + assert adjoint(2*A*B) == 2*Adjoint(B)*Adjoint(A) + assert adjoint(2*I*C) == -2*I*Adjoint(C) + + M = Matrix(2, 2, [1, 2 + I, 3, 4]) + MA = Matrix(2, 2, [1, 3, 2 - I, 4]) + assert adjoint(M) == MA + assert adjoint(2*M) == 2*MA + assert adjoint(MatMul(2, M)) == MatMul(2, MA).doit() + + +def test_transpose(): + assert transpose(A*B) == Transpose(B)*Transpose(A) + assert transpose(2*A*B) == 2*Transpose(B)*Transpose(A) + assert transpose(2*I*C) == 2*I*Transpose(C) + + M = Matrix(2, 2, [1, 2 + I, 3, 4]) + MT = Matrix(2, 2, [1, 3, 2 + I, 4]) + assert transpose(M) == MT + assert transpose(2*M) == 2*MT + assert transpose(x*M) == x*MT + assert transpose(MatMul(2, M)) == MatMul(2, MT).doit() + + +def test_factor_in_front(): + assert factor_in_front(MatMul(A, 2, B, evaluate=False)) ==\ + MatMul(2, A, B, evaluate=False) + + +def test_remove_ids(): + assert remove_ids(MatMul(A, Identity(m), B, evaluate=False)) == \ + MatMul(A, B, evaluate=False) + assert null_safe(remove_ids)(MatMul(Identity(n), evaluate=False)) == \ + MatMul(Identity(n), evaluate=False) + + +def test_combine_powers(): + assert combine_powers(MatMul(D, Inverse(D), D, evaluate=False)) == \ + MatMul(Identity(n), D, evaluate=False) + assert combine_powers(MatMul(B.T, Inverse(E*A), E, A, B, evaluate=False)) == \ + MatMul(B.T, Identity(m), B, evaluate=False) + assert combine_powers(MatMul(A, E, Inverse(A*E), D, evaluate=False)) == \ + MatMul(Identity(n), D, evaluate=False) + + +def test_any_zeros(): + assert any_zeros(MatMul(A, ZeroMatrix(m, k), evaluate=False)) == \ + ZeroMatrix(n, k) + + +def test_unpack(): + assert unpack(MatMul(A, evaluate=False)) == A + x = MatMul(A, B) + assert unpack(x) == x + + +def test_only_squares(): + assert only_squares(C) == [C] + assert only_squares(C, D) == [C, D] + assert only_squares(C, A, A.T, D) == [C, A*A.T, D] + + +def test_determinant(): + assert det(2*C) == 2**n*det(C) + assert det(2*C*D) == 2**n*det(C)*det(D) + assert det(3*C*A*A.T*D) == 3**n*det(C)*det(A*A.T)*det(D) + + +def test_doit(): + assert MatMul(C, 2, D).args == (C, 2, D) + assert MatMul(C, 2, D).doit().args == (2, C, D) + assert MatMul(C, Transpose(D*C)).args == (C, Transpose(D*C)) + assert MatMul(C, Transpose(D*C)).doit(deep=True).args == (C, C.T, D.T) + + +def test_doit_drills_down(): + X = ImmutableMatrix([[1, 2], [3, 4]]) + Y = ImmutableMatrix([[2, 3], [4, 5]]) + assert MatMul(X, MatPow(Y, 2)).doit() == X*Y**2 + assert MatMul(C, Transpose(D*C)).doit().args == (C, C.T, D.T) + + +def test_doit_deep_false_still_canonical(): + assert (MatMul(C, Transpose(D*C), 2).doit(deep=False).args == + (2, C, Transpose(D*C))) + + +def test_matmul_scalar_Matrix_doit(): + # Issue 9053 + X = Matrix([[1, 2], [3, 4]]) + assert MatMul(2, X).doit() == 2*X + + +def test_matmul_sympify(): + assert isinstance(MatMul(eye(1), eye(1)).args[0], Basic) + + +def test_collapse_MatrixBase(): + A = Matrix([[1, 1], [1, 1]]) + B = Matrix([[1, 2], [3, 4]]) + assert MatMul(A, B).doit() == ImmutableMatrix([[4, 6], [4, 6]]) + + +def test_refine(): + assert refine(C*C.T*D, Q.orthogonal(C)).doit() == D + + kC = k*C + assert refine(kC*C.T, Q.orthogonal(C)).doit() == k*Identity(n) + assert refine(kC* kC.T, Q.orthogonal(C)).doit() == (k**2)*Identity(n) + +def test_matmul_no_matrices(): + assert MatMul(1) == 1 + assert MatMul(n, m) == n*m + assert not isinstance(MatMul(n, m), MatMul) + +def test_matmul_args_cnc(): + assert MatMul(n, A, A.T).args_cnc() == [[n], [A, A.T]] + assert MatMul(A, A.T).args_cnc() == [[], [A, A.T]] + +@XFAIL +def test_matmul_args_cnc_symbols(): + # Not currently supported + a, b = symbols('a b', commutative=False) + assert MatMul(n, a, b, A, A.T).args_cnc() == [[n], [a, b, A, A.T]] + assert MatMul(n, a, A, b, A.T).args_cnc() == [[n], [a, A, b, A.T]] + +def test_issue_12950(): + M = Matrix([[Symbol("x")]]) * MatrixSymbol("A", 1, 1) + assert MatrixSymbol("A", 1, 1).as_explicit()[0]*Symbol('x') == M.as_explicit()[0] + +def test_construction_with_Mul(): + assert Mul(C, D) == MatMul(C, D) + assert Mul(D, C) == MatMul(D, C) + +def test_construction_with_mul(): + assert mul(C, D) == MatMul(C, D) + assert mul(D, C) == MatMul(D, C) + assert mul(C, D) != MatMul(D, C) + +def test_generic_identity(): + assert MatMul.identity == GenericIdentity() + assert MatMul.identity != S.One + + +def test_issue_23519(): + N = Symbol("N", integer=True) + M1 = MatrixSymbol("M1", N, N) + M2 = MatrixSymbol("M2", N, N) + I = Identity(N) + z = (M2 + 2 * (M2 + I) * M1 + I) + assert z.coeff(M1) == 2*I + 2*M2 + + +def test_shape_error(): + A = MatrixSymbol('A', 2, 2) + B = MatrixSymbol('B', 3, 3) + raises(ShapeError, lambda: MatMul(A, B)) + + +def test_matmul_transpose(): + # https://github.com/sympy/sympy/issues/9503 + M = Matrix(2, 2, [1, 2 + I, 3, 4]) + a = Symbol('a') + assert (MatMul(a, M).T).expand() == (a*Matrix([[1, 3],[2 + I, 4]])).expand() diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_matpow.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_matpow.py new file mode 100644 index 0000000000000000000000000000000000000000..2afb5fdc2aa652c321de52aba43db63da60941fd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_matpow.py @@ -0,0 +1,217 @@ +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.simplify.powsimp import powsimp +from sympy.testing.pytest import raises +from sympy.core.expr import unchanged +from sympy.core import symbols, S +from sympy.matrices import Identity, MatrixSymbol, ImmutableMatrix, ZeroMatrix, OneMatrix, Matrix +from sympy.matrices.exceptions import NonSquareMatrixError +from sympy.matrices.expressions import MatPow, MatAdd, MatMul +from sympy.matrices.expressions.inverse import Inverse +from sympy.matrices.expressions.matexpr import MatrixElement + +n, m, l, k = symbols('n m l k', integer=True) +A = MatrixSymbol('A', n, m) +B = MatrixSymbol('B', m, l) +C = MatrixSymbol('C', n, n) +D = MatrixSymbol('D', n, n) +E = MatrixSymbol('E', m, n) + + +def test_entry_matrix(): + X = ImmutableMatrix([[1, 2], [3, 4]]) + assert MatPow(X, 0)[0, 0] == 1 + assert MatPow(X, 0)[0, 1] == 0 + assert MatPow(X, 1)[0, 0] == 1 + assert MatPow(X, 1)[0, 1] == 2 + assert MatPow(X, 2)[0, 0] == 7 + + +def test_entry_symbol(): + from sympy.concrete import Sum + assert MatPow(C, 0)[0, 0] == 1 + assert MatPow(C, 0)[0, 1] == 0 + assert MatPow(C, 1)[0, 0] == C[0, 0] + assert isinstance(MatPow(C, 2)[0, 0], Sum) + assert isinstance(MatPow(C, n)[0, 0], MatrixElement) + + +def test_as_explicit_symbol(): + X = MatrixSymbol('X', 2, 2) + assert MatPow(X, 0).as_explicit() == ImmutableMatrix(Identity(2)) + assert MatPow(X, 1).as_explicit() == X.as_explicit() + assert MatPow(X, 2).as_explicit() == (X.as_explicit())**2 + assert MatPow(X, n).as_explicit() == ImmutableMatrix([ + [(X ** n)[0, 0], (X ** n)[0, 1]], + [(X ** n)[1, 0], (X ** n)[1, 1]], + ]) + + a = MatrixSymbol("a", 3, 1) + b = MatrixSymbol("b", 3, 1) + c = MatrixSymbol("c", 3, 1) + + expr = (a.T*b)**S.Half + assert expr.as_explicit() == Matrix([[sqrt(a[0, 0]*b[0, 0] + a[1, 0]*b[1, 0] + a[2, 0]*b[2, 0])]]) + + expr = c*(a.T*b)**S.Half + m = sqrt(a[0, 0]*b[0, 0] + a[1, 0]*b[1, 0] + a[2, 0]*b[2, 0]) + assert expr.as_explicit() == Matrix([[c[0, 0]*m], [c[1, 0]*m], [c[2, 0]*m]]) + + expr = (a*b.T)**S.Half + denom = sqrt(a[0, 0]*b[0, 0] + a[1, 0]*b[1, 0] + a[2, 0]*b[2, 0]) + expected = (a*b.T).as_explicit()/denom + assert expr.as_explicit() == expected + + expr = X**-1 + det = X[0, 0]*X[1, 1] - X[1, 0]*X[0, 1] + expected = Matrix([[X[1, 1], -X[0, 1]], [-X[1, 0], X[0, 0]]])/det + assert expr.as_explicit() == expected + + expr = X**m + assert expr.as_explicit() == X.as_explicit()**m + + +def test_as_explicit_matrix(): + A = ImmutableMatrix([[1, 2], [3, 4]]) + assert MatPow(A, 0).as_explicit() == ImmutableMatrix(Identity(2)) + assert MatPow(A, 1).as_explicit() == A + assert MatPow(A, 2).as_explicit() == A**2 + assert MatPow(A, -1).as_explicit() == A.inv() + assert MatPow(A, -2).as_explicit() == (A.inv())**2 + # less expensive than testing on a 2x2 + A = ImmutableMatrix([4]) + assert MatPow(A, S.Half).as_explicit() == A**S.Half + + +def test_doit_symbol(): + assert MatPow(C, 0).doit() == Identity(n) + assert MatPow(C, 1).doit() == C + assert MatPow(C, -1).doit() == C.I + for r in [2, S.Half, S.Pi, n]: + assert MatPow(C, r).doit() == MatPow(C, r) + + +def test_doit_matrix(): + X = ImmutableMatrix([[1, 2], [3, 4]]) + assert MatPow(X, 0).doit() == ImmutableMatrix(Identity(2)) + assert MatPow(X, 1).doit() == X + assert MatPow(X, 2).doit() == X**2 + assert MatPow(X, -1).doit() == X.inv() + assert MatPow(X, -2).doit() == (X.inv())**2 + # less expensive than testing on a 2x2 + assert MatPow(ImmutableMatrix([4]), S.Half).doit() == ImmutableMatrix([2]) + X = ImmutableMatrix([[0, 2], [0, 4]]) # det() == 0 + raises(ValueError, lambda: MatPow(X,-1).doit()) + raises(ValueError, lambda: MatPow(X,-2).doit()) + + +def test_nonsquare(): + A = MatrixSymbol('A', 2, 3) + B = ImmutableMatrix([[1, 2, 3], [4, 5, 6]]) + for r in [-1, 0, 1, 2, S.Half, S.Pi, n]: + raises(NonSquareMatrixError, lambda: MatPow(A, r)) + raises(NonSquareMatrixError, lambda: MatPow(B, r)) + + +def test_doit_equals_pow(): #17179 + X = ImmutableMatrix ([[1,0],[0,1]]) + assert MatPow(X, n).doit() == X**n == X + + +def test_doit_nested_MatrixExpr(): + X = ImmutableMatrix([[1, 2], [3, 4]]) + Y = ImmutableMatrix([[2, 3], [4, 5]]) + assert MatPow(MatMul(X, Y), 2).doit() == (X*Y)**2 + assert MatPow(MatAdd(X, Y), 2).doit() == (X + Y)**2 + + +def test_identity_power(): + k = Identity(n) + assert MatPow(k, 4).doit() == k + assert MatPow(k, n).doit() == k + assert MatPow(k, -3).doit() == k + assert MatPow(k, 0).doit() == k + l = Identity(3) + assert MatPow(l, n).doit() == l + assert MatPow(l, -1).doit() == l + assert MatPow(l, 0).doit() == l + + +def test_zero_power(): + z1 = ZeroMatrix(n, n) + assert MatPow(z1, 3).doit() == z1 + raises(ValueError, lambda:MatPow(z1, -1).doit()) + assert MatPow(z1, 0).doit() == Identity(n) + assert MatPow(z1, n).doit() == z1 + raises(ValueError, lambda:MatPow(z1, -2).doit()) + z2 = ZeroMatrix(4, 4) + assert MatPow(z2, n).doit() == z2 + raises(ValueError, lambda:MatPow(z2, -3).doit()) + assert MatPow(z2, 2).doit() == z2 + assert MatPow(z2, 0).doit() == Identity(4) + raises(ValueError, lambda:MatPow(z2, -1).doit()) + + +def test_OneMatrix_power(): + o = OneMatrix(3, 3) + assert o ** 0 == Identity(3) + assert o ** 1 == o + assert o * o == o ** 2 == 3 * o + assert o * o * o == o ** 3 == 9 * o + + o = OneMatrix(n, n) + assert o * o == o ** 2 == n * o + # powsimp necessary as n ** (n - 2) * n does not produce n ** (n - 1) + assert powsimp(o ** (n - 1) * o) == o ** n == n ** (n - 1) * o + + +def test_transpose_power(): + from sympy.matrices.expressions.transpose import Transpose as TP + + assert (C*D).T**5 == ((C*D)**5).T == (D.T * C.T)**5 + assert ((C*D).T**5).T == (C*D)**5 + + assert (C.T.I.T)**7 == C**-7 + assert (C.T**l).T**k == C**(l*k) + + assert ((E.T * A.T)**5).T == (A*E)**5 + assert ((A*E).T**5).T**7 == (A*E)**35 + assert TP(TP(C**2 * D**3)**5).doit() == (C**2 * D**3)**5 + + assert ((D*C)**-5).T**-5 == ((D*C)**25).T + assert (((D*C)**l).T**k).T == (D*C)**(l*k) + + +def test_Inverse(): + assert Inverse(MatPow(C, 0)).doit() == Identity(n) + assert Inverse(MatPow(C, 1)).doit() == Inverse(C) + assert Inverse(MatPow(C, 2)).doit() == MatPow(C, -2) + assert Inverse(MatPow(C, -1)).doit() == C + + assert MatPow(Inverse(C), 0).doit() == Identity(n) + assert MatPow(Inverse(C), 1).doit() == Inverse(C) + assert MatPow(Inverse(C), 2).doit() == MatPow(C, -2) + assert MatPow(Inverse(C), -1).doit() == C + + +def test_combine_powers(): + assert (C ** 1) ** 1 == C + assert (C ** 2) ** 3 == MatPow(C, 6) + assert (C ** -2) ** -3 == MatPow(C, 6) + assert (C ** -1) ** -1 == C + assert (((C ** 2) ** 3) ** 4) ** 5 == MatPow(C, 120) + assert (C ** n) ** n == C ** (n ** 2) + + +def test_unchanged(): + assert unchanged(MatPow, C, 0) + assert unchanged(MatPow, C, 1) + assert unchanged(MatPow, Inverse(C), -1) + assert unchanged(Inverse, MatPow(C, -1), -1) + assert unchanged(MatPow, MatPow(C, -1), -1) + assert unchanged(MatPow, MatPow(C, 1), 1) + + +def test_no_exponentiation(): + # if this passes, Pow.as_numer_denom should recognize + # MatAdd as exponent + raises(NotImplementedError, lambda: 3**(-2*C)) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_permutation.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_permutation.py new file mode 100644 index 0000000000000000000000000000000000000000..41a924f6636afb2e5b6560987e38a0fa0c861f1e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_permutation.py @@ -0,0 +1,166 @@ +from sympy.combinatorics import Permutation +from sympy.core.expr import unchanged +from sympy.matrices import Matrix +from sympy.matrices.expressions import \ + MatMul, BlockDiagMatrix, Determinant, Inverse +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import ZeroMatrix, OneMatrix, Identity +from sympy.matrices.expressions.permutation import \ + MatrixPermute, PermutationMatrix +from sympy.testing.pytest import raises +from sympy.core.symbol import Symbol + + +def test_PermutationMatrix_basic(): + p = Permutation([1, 0]) + assert unchanged(PermutationMatrix, p) + raises(ValueError, lambda: PermutationMatrix((0, 1, 2))) + assert PermutationMatrix(p).as_explicit() == Matrix([[0, 1], [1, 0]]) + assert isinstance(PermutationMatrix(p)*MatrixSymbol('A', 2, 2), MatMul) + + +def test_PermutationMatrix_matmul(): + p = Permutation([1, 2, 0]) + P = PermutationMatrix(p) + M = Matrix([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) + assert (P*M).as_explicit() == P.as_explicit()*M + assert (M*P).as_explicit() == M*P.as_explicit() + + P1 = PermutationMatrix(Permutation([1, 2, 0])) + P2 = PermutationMatrix(Permutation([2, 1, 0])) + P3 = PermutationMatrix(Permutation([1, 0, 2])) + assert P1*P2 == P3 + + +def test_PermutationMatrix_matpow(): + p1 = Permutation([1, 2, 0]) + P1 = PermutationMatrix(p1) + p2 = Permutation([2, 0, 1]) + P2 = PermutationMatrix(p2) + assert P1**2 == P2 + assert P1**3 == Identity(3) + + +def test_PermutationMatrix_identity(): + p = Permutation([0, 1]) + assert PermutationMatrix(p).is_Identity + + p = Permutation([1, 0]) + assert not PermutationMatrix(p).is_Identity + + +def test_PermutationMatrix_determinant(): + P = PermutationMatrix(Permutation([0, 1, 2])) + assert Determinant(P).doit() == 1 + P = PermutationMatrix(Permutation([0, 2, 1])) + assert Determinant(P).doit() == -1 + P = PermutationMatrix(Permutation([2, 0, 1])) + assert Determinant(P).doit() == 1 + + +def test_PermutationMatrix_inverse(): + P = PermutationMatrix(Permutation(0, 1, 2)) + assert Inverse(P).doit() == PermutationMatrix(Permutation(0, 2, 1)) + + +def test_PermutationMatrix_rewrite_BlockDiagMatrix(): + P = PermutationMatrix(Permutation([0, 1, 2, 3, 4, 5])) + P0 = PermutationMatrix(Permutation([0])) + assert P.rewrite(BlockDiagMatrix) == \ + BlockDiagMatrix(P0, P0, P0, P0, P0, P0) + + P = PermutationMatrix(Permutation([0, 1, 3, 2, 4, 5])) + P10 = PermutationMatrix(Permutation(0, 1)) + assert P.rewrite(BlockDiagMatrix) == \ + BlockDiagMatrix(P0, P0, P10, P0, P0) + + P = PermutationMatrix(Permutation([1, 0, 3, 2, 5, 4])) + assert P.rewrite(BlockDiagMatrix) == \ + BlockDiagMatrix(P10, P10, P10) + + P = PermutationMatrix(Permutation([0, 4, 3, 2, 1, 5])) + P3210 = PermutationMatrix(Permutation([3, 2, 1, 0])) + assert P.rewrite(BlockDiagMatrix) == \ + BlockDiagMatrix(P0, P3210, P0) + + P = PermutationMatrix(Permutation([0, 4, 2, 3, 1, 5])) + P3120 = PermutationMatrix(Permutation([3, 1, 2, 0])) + assert P.rewrite(BlockDiagMatrix) == \ + BlockDiagMatrix(P0, P3120, P0) + + P = PermutationMatrix(Permutation(0, 3)(1, 4)(2, 5)) + assert P.rewrite(BlockDiagMatrix) == BlockDiagMatrix(P) + + +def test_MartrixPermute_basic(): + p = Permutation(0, 1) + P = PermutationMatrix(p) + A = MatrixSymbol('A', 2, 2) + + raises(ValueError, lambda: MatrixPermute(Symbol('x'), p)) + raises(ValueError, lambda: MatrixPermute(A, Symbol('x'))) + + assert MatrixPermute(A, P) == MatrixPermute(A, p) + raises(ValueError, lambda: MatrixPermute(A, p, 2)) + + pp = Permutation(0, 1, size=3) + assert MatrixPermute(A, pp) == MatrixPermute(A, p) + pp = Permutation(0, 1, 2) + raises(ValueError, lambda: MatrixPermute(A, pp)) + + +def test_MatrixPermute_shape(): + p = Permutation(0, 1) + A = MatrixSymbol('A', 2, 3) + assert MatrixPermute(A, p).shape == (2, 3) + + +def test_MatrixPermute_explicit(): + p = Permutation(0, 1, 2) + A = MatrixSymbol('A', 3, 3) + AA = A.as_explicit() + assert MatrixPermute(A, p, 0).as_explicit() == \ + AA.permute(p, orientation='rows') + assert MatrixPermute(A, p, 1).as_explicit() == \ + AA.permute(p, orientation='cols') + + +def test_MatrixPermute_rewrite_MatMul(): + p = Permutation(0, 1, 2) + A = MatrixSymbol('A', 3, 3) + + assert MatrixPermute(A, p, 0).rewrite(MatMul).as_explicit() == \ + MatrixPermute(A, p, 0).as_explicit() + assert MatrixPermute(A, p, 1).rewrite(MatMul).as_explicit() == \ + MatrixPermute(A, p, 1).as_explicit() + + +def test_MatrixPermute_doit(): + p = Permutation(0, 1, 2) + A = MatrixSymbol('A', 3, 3) + assert MatrixPermute(A, p).doit() == MatrixPermute(A, p) + + p = Permutation(0, size=3) + A = MatrixSymbol('A', 3, 3) + assert MatrixPermute(A, p).doit().as_explicit() == \ + MatrixPermute(A, p).as_explicit() + + p = Permutation(0, 1, 2) + A = Identity(3) + assert MatrixPermute(A, p, 0).doit().as_explicit() == \ + MatrixPermute(A, p, 0).as_explicit() + assert MatrixPermute(A, p, 1).doit().as_explicit() == \ + MatrixPermute(A, p, 1).as_explicit() + + A = ZeroMatrix(3, 3) + assert MatrixPermute(A, p).doit() == A + A = OneMatrix(3, 3) + assert MatrixPermute(A, p).doit() == A + + A = MatrixSymbol('A', 4, 4) + p1 = Permutation(0, 1, 2, 3) + p2 = Permutation(0, 2, 3, 1) + expr = MatrixPermute(MatrixPermute(A, p1, 0), p2, 0) + assert expr.as_explicit() == expr.doit().as_explicit() + expr = MatrixPermute(MatrixPermute(A, p1, 1), p2, 1) + assert expr.as_explicit() == expr.doit().as_explicit() diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_sets.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_sets.py new file mode 100644 index 0000000000000000000000000000000000000000..e811c7968c5a22d65f1c99e995aaa7e5e59d15c4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_sets.py @@ -0,0 +1,42 @@ +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.matrices import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.sets import MatrixSet +from sympy.matrices.expressions.special import ZeroMatrix +from sympy.testing.pytest import raises +from sympy.sets.sets import SetKind +from sympy.matrices.kind import MatrixKind +from sympy.core.kind import NumberKind + + +def test_MatrixSet(): + n, m = symbols('n m', integer=True) + A = MatrixSymbol('A', n, m) + C = MatrixSymbol('C', n, n) + + M = MatrixSet(2, 2, set=S.Reals) + assert M.shape == (2, 2) + assert M.set == S.Reals + X = Matrix([[1, 2], [3, 4]]) + assert X in M + X = ZeroMatrix(2, 2) + assert X in M + raises(TypeError, lambda: A in M) + raises(TypeError, lambda: 1 in M) + M = MatrixSet(n, m, set=S.Reals) + assert A in M + raises(TypeError, lambda: C in M) + raises(TypeError, lambda: X in M) + M = MatrixSet(2, 2, set={1, 2, 3}) + X = Matrix([[1, 2], [3, 4]]) + Y = Matrix([[1, 2]]) + assert (X in M) == S.false + assert (Y in M) == S.false + raises(ValueError, lambda: MatrixSet(2, -2, S.Reals)) + raises(ValueError, lambda: MatrixSet(2.4, -1, S.Reals)) + raises(TypeError, lambda: MatrixSet(2, 2, (1, 2, 3))) + + +def test_SetKind_MatrixSet(): + assert MatrixSet(2, 2, set=S.Reals).kind is SetKind(MatrixKind(NumberKind)) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_slice.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_slice.py new file mode 100644 index 0000000000000000000000000000000000000000..36490719e26908b9e913ed99b7673d602647c492 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_slice.py @@ -0,0 +1,65 @@ +from sympy.matrices.expressions.slice import MatrixSlice +from sympy.matrices.expressions import MatrixSymbol +from sympy.abc import a, b, c, d, k, l, m, n +from sympy.testing.pytest import raises, XFAIL +from sympy.functions.elementary.integers import floor +from sympy.assumptions import assuming, Q + + +X = MatrixSymbol('X', n, m) +Y = MatrixSymbol('Y', m, k) + +def test_shape(): + B = MatrixSlice(X, (a, b), (c, d)) + assert B.shape == (b - a, d - c) + +def test_entry(): + B = MatrixSlice(X, (a, b), (c, d)) + assert B[0,0] == X[a, c] + assert B[k,l] == X[a+k, c+l] + raises(IndexError, lambda : MatrixSlice(X, 1, (2, 5))[1, 0]) + + assert X[1::2, :][1, 3] == X[1+2, 3] + assert X[:, 1::2][3, 1] == X[3, 1+2] + +def test_on_diag(): + assert not MatrixSlice(X, (a, b), (c, d)).on_diag + assert MatrixSlice(X, (a, b), (a, b)).on_diag + +def test_inputs(): + assert MatrixSlice(X, 1, (2, 5)) == MatrixSlice(X, (1, 2), (2, 5)) + assert MatrixSlice(X, 1, (2, 5)).shape == (1, 3) + +def test_slicing(): + assert X[1:5, 2:4] == MatrixSlice(X, (1, 5), (2, 4)) + assert X[1, 2:4] == MatrixSlice(X, 1, (2, 4)) + assert X[1:5, :].shape == (4, X.shape[1]) + assert X[:, 1:5].shape == (X.shape[0], 4) + + assert X[::2, ::2].shape == (floor(n/2), floor(m/2)) + assert X[2, :] == MatrixSlice(X, 2, (0, m)) + assert X[k, :] == MatrixSlice(X, k, (0, m)) + +def test_exceptions(): + X = MatrixSymbol('x', 10, 20) + raises(IndexError, lambda: X[0:12, 2]) + raises(IndexError, lambda: X[0:9, 22]) + raises(IndexError, lambda: X[-1:5, 2]) + +@XFAIL +def test_symmetry(): + X = MatrixSymbol('x', 10, 10) + Y = X[:5, 5:] + with assuming(Q.symmetric(X)): + assert Y.T == X[5:, :5] + +def test_slice_of_slice(): + X = MatrixSymbol('x', 10, 10) + assert X[2, :][:, 3][0, 0] == X[2, 3] + assert X[:5, :5][:4, :4] == X[:4, :4] + assert X[1:5, 2:6][1:3, 2] == X[2:4, 4] + assert X[1:9:2, 2:6][1:3, 2] == X[3:7:2, 4] + +def test_negative_index(): + X = MatrixSymbol('x', 10, 10) + assert X[-1, :] == X[9, :] diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_special.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_special.py new file mode 100644 index 0000000000000000000000000000000000000000..beeaf1d76a63673b6622709cda598dfcb295bba4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_special.py @@ -0,0 +1,228 @@ +from sympy.core.add import Add +from sympy.core.expr import unchanged +from sympy.core.mul import Mul +from sympy.core.symbol import symbols +from sympy.core.relational import Eq +from sympy.concrete.summations import Sum +from sympy.functions.elementary.complexes import im, re +from sympy.functions.elementary.piecewise import Piecewise +from sympy.matrices.immutable import ImmutableDenseMatrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.matadd import MatAdd +from sympy.matrices.expressions.special import ( + ZeroMatrix, GenericZeroMatrix, Identity, GenericIdentity, OneMatrix) +from sympy.matrices.expressions.matmul import MatMul +from sympy.testing.pytest import raises + + +def test_zero_matrix_creation(): + assert unchanged(ZeroMatrix, 2, 2) + assert unchanged(ZeroMatrix, 0, 0) + raises(ValueError, lambda: ZeroMatrix(-1, 2)) + raises(ValueError, lambda: ZeroMatrix(2.0, 2)) + raises(ValueError, lambda: ZeroMatrix(2j, 2)) + raises(ValueError, lambda: ZeroMatrix(2, -1)) + raises(ValueError, lambda: ZeroMatrix(2, 2.0)) + raises(ValueError, lambda: ZeroMatrix(2, 2j)) + + n = symbols('n') + assert unchanged(ZeroMatrix, n, n) + n = symbols('n', integer=False) + raises(ValueError, lambda: ZeroMatrix(n, n)) + n = symbols('n', negative=True) + raises(ValueError, lambda: ZeroMatrix(n, n)) + + +def test_generic_zero_matrix(): + z = GenericZeroMatrix() + n = symbols('n', integer=True) + A = MatrixSymbol("A", n, n) + + assert z == z + assert z != A + assert A != z + + assert z.is_ZeroMatrix + + raises(TypeError, lambda: z.shape) + raises(TypeError, lambda: z.rows) + raises(TypeError, lambda: z.cols) + + assert MatAdd() == z + assert MatAdd(z, A) == MatAdd(A) + # Make sure it is hashable + hash(z) + + +def test_identity_matrix_creation(): + assert Identity(2) + assert Identity(0) + raises(ValueError, lambda: Identity(-1)) + raises(ValueError, lambda: Identity(2.0)) + raises(ValueError, lambda: Identity(2j)) + + n = symbols('n') + assert Identity(n) + n = symbols('n', integer=False) + raises(ValueError, lambda: Identity(n)) + n = symbols('n', negative=True) + raises(ValueError, lambda: Identity(n)) + + +def test_generic_identity(): + I = GenericIdentity() + n = symbols('n', integer=True) + A = MatrixSymbol("A", n, n) + + assert I == I + assert I != A + assert A != I + + assert I.is_Identity + assert I**-1 == I + + raises(TypeError, lambda: I.shape) + raises(TypeError, lambda: I.rows) + raises(TypeError, lambda: I.cols) + + assert MatMul() == I + assert MatMul(I, A) == MatMul(A) + # Make sure it is hashable + hash(I) + + +def test_one_matrix_creation(): + assert OneMatrix(2, 2) + assert OneMatrix(0, 0) + assert Eq(OneMatrix(1, 1), Identity(1)) + raises(ValueError, lambda: OneMatrix(-1, 2)) + raises(ValueError, lambda: OneMatrix(2.0, 2)) + raises(ValueError, lambda: OneMatrix(2j, 2)) + raises(ValueError, lambda: OneMatrix(2, -1)) + raises(ValueError, lambda: OneMatrix(2, 2.0)) + raises(ValueError, lambda: OneMatrix(2, 2j)) + + n = symbols('n') + assert OneMatrix(n, n) + n = symbols('n', integer=False) + raises(ValueError, lambda: OneMatrix(n, n)) + n = symbols('n', negative=True) + raises(ValueError, lambda: OneMatrix(n, n)) + + +def test_ZeroMatrix(): + n, m = symbols('n m', integer=True) + A = MatrixSymbol('A', n, m) + Z = ZeroMatrix(n, m) + + assert A + Z == A + assert A*Z.T == ZeroMatrix(n, n) + assert Z*A.T == ZeroMatrix(n, n) + assert A - A == ZeroMatrix(*A.shape) + + assert Z + + assert Z.transpose() == ZeroMatrix(m, n) + assert Z.conjugate() == Z + assert Z.adjoint() == ZeroMatrix(m, n) + assert re(Z) == Z + assert im(Z) == Z + + assert ZeroMatrix(n, n)**0 == Identity(n) + assert ZeroMatrix(3, 3).as_explicit() == ImmutableDenseMatrix.zeros(3, 3) + + +def test_ZeroMatrix_doit(): + n = symbols('n', integer=True) + Znn = ZeroMatrix(Add(n, n, evaluate=False), n) + assert isinstance(Znn.rows, Add) + assert Znn.doit() == ZeroMatrix(2*n, n) + assert isinstance(Znn.doit().rows, Mul) + + +def test_OneMatrix(): + n, m = symbols('n m', integer=True) + A = MatrixSymbol('A', n, m) + U = OneMatrix(n, m) + + assert U.shape == (n, m) + assert isinstance(A + U, Add) + assert U.transpose() == OneMatrix(m, n) + assert U.conjugate() == U + assert U.adjoint() == OneMatrix(m, n) + assert re(U) == U + assert im(U) == ZeroMatrix(n, m) + + assert OneMatrix(n, n) ** 0 == Identity(n) + + U = OneMatrix(n, n) + assert U[1, 2] == 1 + + U = OneMatrix(2, 3) + assert U.as_explicit() == ImmutableDenseMatrix.ones(2, 3) + + +def test_OneMatrix_doit(): + n = symbols('n', integer=True) + Unn = OneMatrix(Add(n, n, evaluate=False), n) + assert isinstance(Unn.rows, Add) + assert Unn.doit() == OneMatrix(2 * n, n) + assert isinstance(Unn.doit().rows, Mul) + + +def test_OneMatrix_mul(): + n, m, k = symbols('n m k', integer=True) + w = MatrixSymbol('w', n, 1) + assert OneMatrix(n, m) * OneMatrix(m, k) == OneMatrix(n, k) * m + assert w * OneMatrix(1, 1) == w + assert OneMatrix(1, 1) * w.T == w.T + + +def test_Identity(): + n, m = symbols('n m', integer=True) + A = MatrixSymbol('A', n, m) + i, j = symbols('i j') + + In = Identity(n) + Im = Identity(m) + + assert A*Im == A + assert In*A == A + + assert In.transpose() == In + assert In.inverse() == In + assert In.conjugate() == In + assert In.adjoint() == In + assert re(In) == In + assert im(In) == ZeroMatrix(n, n) + + assert In[i, j] != 0 + assert Sum(In[i, j], (i, 0, n-1), (j, 0, n-1)).subs(n,3).doit() == 3 + assert Sum(Sum(In[i, j], (i, 0, n-1)), (j, 0, n-1)).subs(n,3).doit() == 3 + + # If range exceeds the limit `(0, n-1)`, do not remove `Piecewise`: + expr = Sum(In[i, j], (i, 0, n-1)) + assert expr.doit() == 1 + expr = Sum(In[i, j], (i, 0, n-2)) + assert expr.doit().dummy_eq( + Piecewise( + (1, (j >= 0) & (j <= n-2)), + (0, True) + ) + ) + expr = Sum(In[i, j], (i, 1, n-1)) + assert expr.doit().dummy_eq( + Piecewise( + (1, (j >= 1) & (j <= n-1)), + (0, True) + ) + ) + assert Identity(3).as_explicit() == ImmutableDenseMatrix.eye(3) + + +def test_Identity_doit(): + n = symbols('n', integer=True) + Inn = Identity(Add(n, n, evaluate=False)) + assert isinstance(Inn.rows, Add) + assert Inn.doit() == Identity(2*n) + assert isinstance(Inn.doit().rows, Mul) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_trace.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..3bd66bec2377dae634ff486f42cc474eda7b23b1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_trace.py @@ -0,0 +1,116 @@ +from sympy.core import Lambda, S, symbols +from sympy.concrete import Sum +from sympy.functions import adjoint, conjugate, transpose +from sympy.matrices import eye, Matrix, ShapeError, ImmutableMatrix +from sympy.matrices.expressions import ( + Adjoint, Identity, FunctionMatrix, MatrixExpr, MatrixSymbol, Trace, + ZeroMatrix, trace, MatPow, MatAdd, MatMul +) +from sympy.matrices.expressions.special import OneMatrix +from sympy.testing.pytest import raises +from sympy.abc import i + + +n = symbols('n', integer=True) +A = MatrixSymbol('A', n, n) +B = MatrixSymbol('B', n, n) +C = MatrixSymbol('C', 3, 4) + + +def test_Trace(): + assert isinstance(Trace(A), Trace) + assert not isinstance(Trace(A), MatrixExpr) + raises(ShapeError, lambda: Trace(C)) + assert trace(eye(3)) == 3 + assert trace(Matrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9])) == 15 + + assert adjoint(Trace(A)) == trace(Adjoint(A)) + assert conjugate(Trace(A)) == trace(Adjoint(A)) + assert transpose(Trace(A)) == Trace(A) + + _ = A / Trace(A) # Make sure this is possible + + # Some easy simplifications + assert trace(Identity(5)) == 5 + assert trace(ZeroMatrix(5, 5)) == 0 + assert trace(OneMatrix(1, 1)) == 1 + assert trace(OneMatrix(2, 2)) == 2 + assert trace(OneMatrix(n, n)) == n + assert trace(2*A*B) == 2*Trace(A*B) + assert trace(A.T) == trace(A) + + i, j = symbols('i j') + F = FunctionMatrix(3, 3, Lambda((i, j), i + j)) + assert trace(F) == (0 + 0) + (1 + 1) + (2 + 2) + + raises(TypeError, lambda: Trace(S.One)) + + assert Trace(A).arg is A + + assert str(trace(A)) == str(Trace(A).doit()) + + assert Trace(A).is_commutative is True + +def test_Trace_A_plus_B(): + assert trace(A + B) == Trace(A) + Trace(B) + assert Trace(A + B).arg == MatAdd(A, B) + assert Trace(A + B).doit() == Trace(A) + Trace(B) + + +def test_Trace_MatAdd_doit(): + # See issue #9028 + X = ImmutableMatrix([[1, 2, 3]]*3) + Y = MatrixSymbol('Y', 3, 3) + q = MatAdd(X, 2*X, Y, -3*Y) + assert Trace(q).arg == q + assert Trace(q).doit() == 18 - 2*Trace(Y) + + +def test_Trace_MatPow_doit(): + X = Matrix([[1, 2], [3, 4]]) + assert Trace(X).doit() == 5 + q = MatPow(X, 2) + assert Trace(q).arg == q + assert Trace(q).doit() == 29 + + +def test_Trace_MutableMatrix_plus(): + # See issue #9043 + X = Matrix([[1, 2], [3, 4]]) + assert Trace(X) + Trace(X) == 2*Trace(X) + + +def test_Trace_doit_deep_False(): + X = Matrix([[1, 2], [3, 4]]) + q = MatPow(X, 2) + assert Trace(q).doit(deep=False).arg == q + q = MatAdd(X, 2*X) + assert Trace(q).doit(deep=False).arg == q + q = MatMul(X, 2*X) + assert Trace(q).doit(deep=False).arg == q + + +def test_trace_constant_factor(): + # Issue 9052: gave 2*Trace(MatMul(A)) instead of 2*Trace(A) + assert trace(2*A) == 2*Trace(A) + X = ImmutableMatrix([[1, 2], [3, 4]]) + assert trace(MatMul(2, X)) == 10 + + +def test_trace_rewrite(): + assert trace(A).rewrite(Sum) == Sum(A[i, i], (i, 0, n - 1)) + assert trace(eye(3)).rewrite(Sum) == 3 + + +def test_trace_normalize(): + assert Trace(B*A) != Trace(A*B) + assert Trace(B*A)._normalize() == Trace(A*B) + assert Trace(B*A.T)._normalize() == Trace(A*B.T) + + +def test_trace_as_explicit(): + raises(ValueError, lambda: Trace(A).as_explicit()) + + X = MatrixSymbol("X", 3, 3) + assert Trace(X).as_explicit() == X[0, 0] + X[1, 1] + X[2, 2] + assert Trace(eye(3)).as_explicit() == 3 diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_transpose.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_transpose.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a6113873426d99bacf85484d3b66781f300af7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/tests/test_transpose.py @@ -0,0 +1,69 @@ +from sympy.functions import adjoint, conjugate, transpose +from sympy.matrices.expressions import MatrixSymbol, Adjoint, trace, Transpose +from sympy.matrices import eye, Matrix +from sympy.assumptions.ask import Q +from sympy.assumptions.refine import refine +from sympy.core.singleton import S +from sympy.core.symbol import symbols + +n, m, l, k, p = symbols('n m l k p', integer=True) +A = MatrixSymbol('A', n, m) +B = MatrixSymbol('B', m, l) +C = MatrixSymbol('C', n, n) + + +def test_transpose(): + Sq = MatrixSymbol('Sq', n, n) + + assert transpose(A) == Transpose(A) + assert Transpose(A).shape == (m, n) + assert Transpose(A*B).shape == (l, n) + assert transpose(Transpose(A)) == A + assert isinstance(Transpose(Transpose(A)), Transpose) + + assert adjoint(Transpose(A)) == Adjoint(Transpose(A)) + assert conjugate(Transpose(A)) == Adjoint(A) + + assert Transpose(eye(3)).doit() == eye(3) + + assert Transpose(S(5)).doit() == S(5) + + assert Transpose(Matrix([[1, 2], [3, 4]])).doit() == Matrix([[1, 3], [2, 4]]) + + assert transpose(trace(Sq)) == trace(Sq) + assert trace(Transpose(Sq)) == trace(Sq) + + assert Transpose(Sq)[0, 1] == Sq[1, 0] + + assert Transpose(A*B).doit() == Transpose(B) * Transpose(A) + + +def test_transpose_MatAdd_MatMul(): + # Issue 16807 + from sympy.functions.elementary.trigonometric import cos + + x = symbols('x') + M = MatrixSymbol('M', 3, 3) + N = MatrixSymbol('N', 3, 3) + + assert (N + (cos(x) * M)).T == cos(x)*M.T + N.T + + +def test_refine(): + assert refine(C.T, Q.symmetric(C)) == C + + +def test_transpose1x1(): + m = MatrixSymbol('m', 1, 1) + assert m == refine(m.T) + assert m == refine(m.T.T) + +def test_issue_9817(): + from sympy.matrices.expressions import Identity + v = MatrixSymbol('v', 3, 1) + A = MatrixSymbol('A', 3, 3) + x = Matrix([i + 1 for i in range(3)]) + X = Identity(3) + quadratic = v.T * A * v + subbed = quadratic.xreplace({v:x, A:X}) + assert subbed.as_explicit() == Matrix([[14]]) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/trace.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/trace.py new file mode 100644 index 0000000000000000000000000000000000000000..b5f9f94ea7486dc21b47c2e2e783a93280b180e0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/trace.py @@ -0,0 +1,167 @@ +from sympy.core.basic import Basic +from sympy.core.expr import Expr, ExprBuilder +from sympy.core.singleton import S +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import uniquely_named_symbol +from sympy.core.sympify import sympify +from sympy.matrices.matrixbase import MatrixBase +from sympy.matrices.exceptions import NonSquareMatrixError + + +class Trace(Expr): + """Matrix Trace + + Represents the trace of a matrix expression. + + Examples + ======== + + >>> from sympy import MatrixSymbol, Trace, eye + >>> A = MatrixSymbol('A', 3, 3) + >>> Trace(A) + Trace(A) + >>> Trace(eye(3)) + Trace(Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]])) + >>> Trace(eye(3)).simplify() + 3 + """ + is_Trace = True + is_commutative = True + + def __new__(cls, mat): + mat = sympify(mat) + + if not mat.is_Matrix: + raise TypeError("input to Trace, %s, is not a matrix" % str(mat)) + + if mat.is_square is False: + raise NonSquareMatrixError("Trace of a non-square matrix") + + return Basic.__new__(cls, mat) + + def _eval_transpose(self): + return self + + def _eval_derivative(self, v): + from sympy.concrete.summations import Sum + from .matexpr import MatrixElement + if isinstance(v, MatrixElement): + return self.rewrite(Sum).diff(v) + expr = self.doit() + if isinstance(expr, Trace): + # Avoid looping infinitely: + raise NotImplementedError + return expr._eval_derivative(v) + + def _eval_derivative_matrix_lines(self, x): + from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayContraction + r = self.args[0]._eval_derivative_matrix_lines(x) + for lr in r: + if lr.higher == 1: + lr.higher = ExprBuilder( + ArrayContraction, + [ + ExprBuilder( + ArrayTensorProduct, + [ + lr._lines[0], + lr._lines[1], + ] + ), + (1, 3), + ], + validator=ArrayContraction._validate + ) + else: + # This is not a matrix line: + lr.higher = ExprBuilder( + ArrayContraction, + [ + ExprBuilder( + ArrayTensorProduct, + [ + lr._lines[0], + lr._lines[1], + lr.higher, + ] + ), + (1, 3), (0, 2) + ] + ) + lr._lines = [S.One, S.One] + lr._first_pointer_parent = lr._lines + lr._second_pointer_parent = lr._lines + lr._first_pointer_index = 0 + lr._second_pointer_index = 1 + return r + + @property + def arg(self): + return self.args[0] + + def doit(self, **hints): + if hints.get('deep', True): + arg = self.arg.doit(**hints) + result = arg._eval_trace() + if result is not None: + return result + else: + return Trace(arg) + else: + # _eval_trace would go too deep here + if isinstance(self.arg, MatrixBase): + return trace(self.arg) + else: + return Trace(self.arg) + + def as_explicit(self): + return Trace(self.arg.as_explicit()).doit() + + def _normalize(self): + # Normalization of trace of matrix products. Use transposition and + # cyclic properties of traces to make sure the arguments of the matrix + # product are sorted and the first argument is not a transposition. + from sympy.matrices.expressions.matmul import MatMul + from sympy.matrices.expressions.transpose import Transpose + trace_arg = self.arg + if isinstance(trace_arg, MatMul): + + def get_arg_key(x): + a = trace_arg.args[x] + if isinstance(a, Transpose): + a = a.arg + return default_sort_key(a) + + indmin = min(range(len(trace_arg.args)), key=get_arg_key) + if isinstance(trace_arg.args[indmin], Transpose): + trace_arg = Transpose(trace_arg).doit() + indmin = min(range(len(trace_arg.args)), key=lambda x: default_sort_key(trace_arg.args[x])) + trace_arg = MatMul.fromiter(trace_arg.args[indmin:] + trace_arg.args[:indmin]) + return Trace(trace_arg) + return self + + def _eval_rewrite_as_Sum(self, expr, **kwargs): + from sympy.concrete.summations import Sum + i = uniquely_named_symbol('i', [expr]) + s = Sum(self.arg[i, i], (i, 0, self.arg.rows - 1)) + return s.doit() + + +def trace(expr): + """Trace of a Matrix. Sum of the diagonal elements. + + Examples + ======== + + >>> from sympy import trace, Symbol, MatrixSymbol, eye + >>> n = Symbol('n') + >>> X = MatrixSymbol('X', n, n) # A square matrix + >>> trace(2*X) + 2*Trace(X) + >>> trace(eye(3)) + 3 + """ + return Trace(expr).doit() diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/transpose.py b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/transpose.py new file mode 100644 index 0000000000000000000000000000000000000000..b11f7fc21490aab219420610ca529d81d6995d40 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/expressions/transpose.py @@ -0,0 +1,103 @@ +from sympy.core.basic import Basic +from sympy.matrices.expressions.matexpr import MatrixExpr + + +class Transpose(MatrixExpr): + """ + The transpose of a matrix expression. + + This is a symbolic object that simply stores its argument without + evaluating it. To actually compute the transpose, use the ``transpose()`` + function, or the ``.T`` attribute of matrices. + + Examples + ======== + + >>> from sympy import MatrixSymbol, Transpose, transpose + >>> A = MatrixSymbol('A', 3, 5) + >>> B = MatrixSymbol('B', 5, 3) + >>> Transpose(A) + A.T + >>> A.T == transpose(A) == Transpose(A) + True + >>> Transpose(A*B) + (A*B).T + >>> transpose(A*B) + B.T*A.T + + """ + is_Transpose = True + + def doit(self, **hints): + arg = self.arg + if hints.get('deep', True) and isinstance(arg, Basic): + arg = arg.doit(**hints) + _eval_transpose = getattr(arg, '_eval_transpose', None) + if _eval_transpose is not None: + result = _eval_transpose() + return result if result is not None else Transpose(arg) + else: + return Transpose(arg) + + @property + def arg(self): + return self.args[0] + + @property + def shape(self): + return self.arg.shape[::-1] + + def _entry(self, i, j, expand=False, **kwargs): + return self.arg._entry(j, i, expand=expand, **kwargs) + + def _eval_adjoint(self): + return self.arg.conjugate() + + def _eval_conjugate(self): + return self.arg.adjoint() + + def _eval_transpose(self): + return self.arg + + def _eval_trace(self): + from .trace import Trace + return Trace(self.arg) # Trace(X.T) => Trace(X) + + def _eval_determinant(self): + from sympy.matrices.expressions.determinant import det + return det(self.arg) + + def _eval_derivative(self, x): + # x is a scalar: + return self.arg._eval_derivative(x) + + def _eval_derivative_matrix_lines(self, x): + lines = self.args[0]._eval_derivative_matrix_lines(x) + return [i.transpose() for i in lines] + + +def transpose(expr): + """Matrix transpose""" + return Transpose(expr).doit(deep=False) + + +from sympy.assumptions.ask import ask, Q +from sympy.assumptions.refine import handlers_dict + + +def refine_Transpose(expr, assumptions): + """ + >>> from sympy import MatrixSymbol, Q, assuming, refine + >>> X = MatrixSymbol('X', 2, 2) + >>> X.T + X.T + >>> with assuming(Q.symmetric(X)): + ... print(refine(X.T)) + X + """ + if ask(Q.symmetric(expr), assumptions): + return expr.arg + + return expr + +handlers_dict['Transpose'] = refine_Transpose diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/graph.py b/.venv/lib/python3.13/site-packages/sympy/matrices/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..4c6356db884cfcd3c759ada07ac559f43dbcbbcb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/graph.py @@ -0,0 +1,279 @@ +from sympy.utilities.iterables import \ + flatten, connected_components, strongly_connected_components +from .exceptions import NonSquareMatrixError + + +def _connected_components(M): + """Returns the list of connected vertices of the graph when + a square matrix is viewed as a weighted graph. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([ + ... [66, 0, 0, 68, 0, 0, 0, 0, 67], + ... [0, 55, 0, 0, 0, 0, 54, 53, 0], + ... [0, 0, 0, 0, 1, 2, 0, 0, 0], + ... [86, 0, 0, 88, 0, 0, 0, 0, 87], + ... [0, 0, 10, 0, 11, 12, 0, 0, 0], + ... [0, 0, 20, 0, 21, 22, 0, 0, 0], + ... [0, 45, 0, 0, 0, 0, 44, 43, 0], + ... [0, 35, 0, 0, 0, 0, 34, 33, 0], + ... [76, 0, 0, 78, 0, 0, 0, 0, 77]]) + >>> A.connected_components() + [[0, 3, 8], [1, 6, 7], [2, 4, 5]] + + Notes + ===== + + Even if any symbolic elements of the matrix can be indeterminate + to be zero mathematically, this only takes the account of the + structural aspect of the matrix, so they will considered to be + nonzero. + """ + if not M.is_square: + raise NonSquareMatrixError + + V = range(M.rows) + E = sorted(M.todok().keys()) + return connected_components((V, E)) + + +def _strongly_connected_components(M): + """Returns the list of strongly connected vertices of the graph when + a square matrix is viewed as a weighted graph. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([ + ... [44, 0, 0, 0, 43, 0, 45, 0, 0], + ... [0, 66, 62, 61, 0, 68, 0, 60, 67], + ... [0, 0, 22, 21, 0, 0, 0, 20, 0], + ... [0, 0, 12, 11, 0, 0, 0, 10, 0], + ... [34, 0, 0, 0, 33, 0, 35, 0, 0], + ... [0, 86, 82, 81, 0, 88, 0, 80, 87], + ... [54, 0, 0, 0, 53, 0, 55, 0, 0], + ... [0, 0, 2, 1, 0, 0, 0, 0, 0], + ... [0, 76, 72, 71, 0, 78, 0, 70, 77]]) + >>> A.strongly_connected_components() + [[0, 4, 6], [2, 3, 7], [1, 5, 8]] + """ + if not M.is_square: + raise NonSquareMatrixError + + # RepMatrix uses the more efficient DomainMatrix.scc() method + rep = getattr(M, '_rep', None) + if rep is not None: + return rep.scc() + + V = range(M.rows) + E = sorted(M.todok().keys()) + return strongly_connected_components((V, E)) + + +def _connected_components_decomposition(M): + """Decomposes a square matrix into block diagonal form only + using the permutations. + + Explanation + =========== + + The decomposition is in a form of $A = P^{-1} B P$ where $P$ is a + permutation matrix and $B$ is a block diagonal matrix. + + Returns + ======= + + P, B : PermutationMatrix, BlockDiagMatrix + *P* is a permutation matrix for the similarity transform + as in the explanation. And *B* is the block diagonal matrix of + the result of the permutation. + + If you would like to get the diagonal blocks from the + BlockDiagMatrix, see + :meth:`~sympy.matrices.expressions.blockmatrix.BlockDiagMatrix.get_diag_blocks`. + + Examples + ======== + + >>> from sympy import Matrix, pprint + >>> A = Matrix([ + ... [66, 0, 0, 68, 0, 0, 0, 0, 67], + ... [0, 55, 0, 0, 0, 0, 54, 53, 0], + ... [0, 0, 0, 0, 1, 2, 0, 0, 0], + ... [86, 0, 0, 88, 0, 0, 0, 0, 87], + ... [0, 0, 10, 0, 11, 12, 0, 0, 0], + ... [0, 0, 20, 0, 21, 22, 0, 0, 0], + ... [0, 45, 0, 0, 0, 0, 44, 43, 0], + ... [0, 35, 0, 0, 0, 0, 34, 33, 0], + ... [76, 0, 0, 78, 0, 0, 0, 0, 77]]) + + >>> P, B = A.connected_components_decomposition() + >>> pprint(P) + PermutationMatrix((1 3)(2 8 5 7 4 6)) + >>> pprint(B) + [[66 68 67] ] + [[ ] ] + [[86 88 87] 0 0 ] + [[ ] ] + [[76 78 77] ] + [ ] + [ [55 54 53] ] + [ [ ] ] + [ 0 [45 44 43] 0 ] + [ [ ] ] + [ [35 34 33] ] + [ ] + [ [0 1 2 ]] + [ [ ]] + [ 0 0 [10 11 12]] + [ [ ]] + [ [20 21 22]] + + >>> P = P.as_explicit() + >>> B = B.as_explicit() + >>> P.T*B*P == A + True + + Notes + ===== + + This problem corresponds to the finding of the connected components + of a graph, when a matrix is viewed as a weighted graph. + """ + from sympy.combinatorics.permutations import Permutation + from sympy.matrices.expressions.blockmatrix import BlockDiagMatrix + from sympy.matrices.expressions.permutation import PermutationMatrix + + iblocks = M.connected_components() + + p = Permutation(flatten(iblocks)) + P = PermutationMatrix(p) + + blocks = [] + for b in iblocks: + blocks.append(M[b, b]) + B = BlockDiagMatrix(*blocks) + return P, B + + +def _strongly_connected_components_decomposition(M, lower=True): + """Decomposes a square matrix into block triangular form only + using the permutations. + + Explanation + =========== + + The decomposition is in a form of $A = P^{-1} B P$ where $P$ is a + permutation matrix and $B$ is a block diagonal matrix. + + Parameters + ========== + + lower : bool + Makes $B$ lower block triangular when ``True``. + Otherwise, makes $B$ upper block triangular. + + Returns + ======= + + P, B : PermutationMatrix, BlockMatrix + *P* is a permutation matrix for the similarity transform + as in the explanation. And *B* is the block triangular matrix of + the result of the permutation. + + Examples + ======== + + >>> from sympy import Matrix, pprint + >>> A = Matrix([ + ... [44, 0, 0, 0, 43, 0, 45, 0, 0], + ... [0, 66, 62, 61, 0, 68, 0, 60, 67], + ... [0, 0, 22, 21, 0, 0, 0, 20, 0], + ... [0, 0, 12, 11, 0, 0, 0, 10, 0], + ... [34, 0, 0, 0, 33, 0, 35, 0, 0], + ... [0, 86, 82, 81, 0, 88, 0, 80, 87], + ... [54, 0, 0, 0, 53, 0, 55, 0, 0], + ... [0, 0, 2, 1, 0, 0, 0, 0, 0], + ... [0, 76, 72, 71, 0, 78, 0, 70, 77]]) + + A lower block triangular decomposition: + + >>> P, B = A.strongly_connected_components_decomposition() + >>> pprint(P) + PermutationMatrix((8)(1 4 3 2 6)(5 7)) + >>> pprint(B) + [[44 43 45] [0 0 0] [0 0 0] ] + [[ ] [ ] [ ] ] + [[34 33 35] [0 0 0] [0 0 0] ] + [[ ] [ ] [ ] ] + [[54 53 55] [0 0 0] [0 0 0] ] + [ ] + [ [0 0 0] [22 21 20] [0 0 0] ] + [ [ ] [ ] [ ] ] + [ [0 0 0] [12 11 10] [0 0 0] ] + [ [ ] [ ] [ ] ] + [ [0 0 0] [2 1 0 ] [0 0 0] ] + [ ] + [ [0 0 0] [62 61 60] [66 68 67]] + [ [ ] [ ] [ ]] + [ [0 0 0] [82 81 80] [86 88 87]] + [ [ ] [ ] [ ]] + [ [0 0 0] [72 71 70] [76 78 77]] + + >>> P = P.as_explicit() + >>> B = B.as_explicit() + >>> P.T * B * P == A + True + + An upper block triangular decomposition: + + >>> P, B = A.strongly_connected_components_decomposition(lower=False) + >>> pprint(P) + PermutationMatrix((0 1 5 7 4 3 2 8 6)) + >>> pprint(B) + [[66 68 67] [62 61 60] [0 0 0] ] + [[ ] [ ] [ ] ] + [[86 88 87] [82 81 80] [0 0 0] ] + [[ ] [ ] [ ] ] + [[76 78 77] [72 71 70] [0 0 0] ] + [ ] + [ [0 0 0] [22 21 20] [0 0 0] ] + [ [ ] [ ] [ ] ] + [ [0 0 0] [12 11 10] [0 0 0] ] + [ [ ] [ ] [ ] ] + [ [0 0 0] [2 1 0 ] [0 0 0] ] + [ ] + [ [0 0 0] [0 0 0] [44 43 45]] + [ [ ] [ ] [ ]] + [ [0 0 0] [0 0 0] [34 33 35]] + [ [ ] [ ] [ ]] + [ [0 0 0] [0 0 0] [54 53 55]] + + >>> P = P.as_explicit() + >>> B = B.as_explicit() + >>> P.T * B * P == A + True + """ + from sympy.combinatorics.permutations import Permutation + from sympy.matrices.expressions.blockmatrix import BlockMatrix + from sympy.matrices.expressions.permutation import PermutationMatrix + + iblocks = M.strongly_connected_components() + if not lower: + iblocks = list(reversed(iblocks)) + + p = Permutation(flatten(iblocks)) + P = PermutationMatrix(p) + + rows = [] + for a in iblocks: + cols = [] + for b in iblocks: + cols.append(M[a, b]) + rows.append(cols) + B = BlockMatrix(rows) + return P, B diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/immutable.py b/.venv/lib/python3.13/site-packages/sympy/matrices/immutable.py new file mode 100644 index 0000000000000000000000000000000000000000..7ec2174bf1c785e1a4698e1b55078d300e62dafe --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/immutable.py @@ -0,0 +1,196 @@ +from mpmath.matrices.matrices import _matrix + +from sympy.core import Basic, Dict, Tuple +from sympy.core.numbers import Integer +from sympy.core.cache import cacheit +from sympy.core.sympify import _sympy_converter as sympify_converter, _sympify +from sympy.matrices.dense import DenseMatrix +from sympy.matrices.expressions import MatrixExpr +from sympy.matrices.matrixbase import MatrixBase +from sympy.matrices.repmatrix import RepMatrix +from sympy.matrices.sparse import SparseRepMatrix +from sympy.multipledispatch import dispatch + + +def sympify_matrix(arg): + return arg.as_immutable() + + +sympify_converter[MatrixBase] = sympify_matrix + + +def sympify_mpmath_matrix(arg): + mat = [_sympify(x) for x in arg] + return ImmutableDenseMatrix(arg.rows, arg.cols, mat) + + +sympify_converter[_matrix] = sympify_mpmath_matrix + + +class ImmutableRepMatrix(RepMatrix, MatrixExpr): # type: ignore + """Immutable matrix based on RepMatrix + + Uses DomainMAtrix as the internal representation. + """ + + # + # This is a subclass of RepMatrix that adds/overrides some methods to make + # the instances Basic and immutable. ImmutableRepMatrix is a superclass for + # both ImmutableDenseMatrix and ImmutableSparseMatrix. + # + + def __new__(cls, *args, **kwargs): + return cls._new(*args, **kwargs) + + __hash__ = MatrixExpr.__hash__ + + def copy(self): + return self + + @property + def cols(self): + return self._cols + + @property + def rows(self): + return self._rows + + @property + def shape(self): + return self._rows, self._cols + + def as_immutable(self): + return self + + def _entry(self, i, j, **kwargs): + return self[i, j] + + def __setitem__(self, *args): + raise TypeError("Cannot set values of {}".format(self.__class__)) + + def is_diagonalizable(self, reals_only=False, **kwargs): + return super().is_diagonalizable( + reals_only=reals_only, **kwargs) + + is_diagonalizable.__doc__ = SparseRepMatrix.is_diagonalizable.__doc__ + is_diagonalizable = cacheit(is_diagonalizable) + + def analytic_func(self, f, x): + return self.as_mutable().analytic_func(f, x).as_immutable() + + +class ImmutableDenseMatrix(DenseMatrix, ImmutableRepMatrix): # type: ignore + """Create an immutable version of a matrix. + + Examples + ======== + + >>> from sympy import eye, ImmutableMatrix + >>> ImmutableMatrix(eye(3)) + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + >>> _[0, 0] = 42 + Traceback (most recent call last): + ... + TypeError: Cannot set values of ImmutableDenseMatrix + """ + + # MatrixExpr is set as NotIterable, but we want explicit matrices to be + # iterable + _iterable = True + _class_priority = 8 + _op_priority = 10.001 + + @classmethod + def _new(cls, *args, **kwargs): + if len(args) == 1 and isinstance(args[0], ImmutableDenseMatrix): + return args[0] + if kwargs.get('copy', True) is False: + if len(args) != 3: + raise TypeError("'copy=False' requires a matrix be initialized as rows,cols,[list]") + rows, cols, flat_list = args + else: + rows, cols, flat_list = cls._handle_creation_inputs(*args, **kwargs) + flat_list = list(flat_list) # create a shallow copy + + rep = cls._flat_list_to_DomainMatrix(rows, cols, flat_list) + + return cls._fromrep(rep) + + @classmethod + def _fromrep(cls, rep): + rows, cols = rep.shape + flat_list = rep.to_sympy().to_list_flat() + obj = Basic.__new__(cls, + Integer(rows), + Integer(cols), + Tuple(*flat_list, sympify=False)) + obj._rows = rows + obj._cols = cols + obj._rep = rep + return obj + + +# make sure ImmutableDenseMatrix is aliased as ImmutableMatrix +ImmutableMatrix = ImmutableDenseMatrix + + +class ImmutableSparseMatrix(SparseRepMatrix, ImmutableRepMatrix): # type:ignore + """Create an immutable version of a sparse matrix. + + Examples + ======== + + >>> from sympy import eye, ImmutableSparseMatrix + >>> ImmutableSparseMatrix(1, 1, {}) + Matrix([[0]]) + >>> ImmutableSparseMatrix(eye(3)) + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + >>> _[0, 0] = 42 + Traceback (most recent call last): + ... + TypeError: Cannot set values of ImmutableSparseMatrix + >>> _.shape + (3, 3) + """ + is_Matrix = True + _class_priority = 9 + + @classmethod + def _new(cls, *args, **kwargs): + rows, cols, smat = cls._handle_creation_inputs(*args, **kwargs) + + rep = cls._smat_to_DomainMatrix(rows, cols, smat) + + return cls._fromrep(rep) + + @classmethod + def _fromrep(cls, rep): + rows, cols = rep.shape + smat = rep.to_sympy().to_dok() + obj = Basic.__new__(cls, Integer(rows), Integer(cols), Dict(smat)) + obj._rows = rows + obj._cols = cols + obj._rep = rep + return obj + + +@dispatch(ImmutableDenseMatrix, ImmutableDenseMatrix) +def _eval_is_eq(lhs, rhs): # noqa:F811 + """Helper method for Equality with matrices.sympy. + + Relational automatically converts matrices to ImmutableDenseMatrix + instances, so this method only applies here. Returns True if the + matrices are definitively the same, False if they are definitively + different, and None if undetermined (e.g. if they contain Symbols). + Returning None triggers default handling of Equalities. + + """ + if lhs.shape != rhs.shape: + return False + return (lhs - rhs).is_zero_matrix diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/inverse.py b/.venv/lib/python3.13/site-packages/sympy/matrices/inverse.py new file mode 100644 index 0000000000000000000000000000000000000000..61d9e12edf013d2f5555d61786343aa3840edfd3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/inverse.py @@ -0,0 +1,524 @@ +from sympy.polys.matrices.exceptions import DMNonInvertibleMatrixError +from sympy.polys.domains import EX + +from .exceptions import MatrixError, NonSquareMatrixError, NonInvertibleMatrixError +from .utilities import _iszero + + +def _pinv_full_rank(M): + """Subroutine for full row or column rank matrices. + + For full row rank matrices, inverse of ``A * A.H`` Exists. + For full column rank matrices, inverse of ``A.H * A`` Exists. + + This routine can apply for both cases by checking the shape + and have small decision. + """ + + if M.is_zero_matrix: + return M.H + + if M.rows >= M.cols: + return M.H.multiply(M).inv().multiply(M.H) + else: + return M.H.multiply(M.multiply(M.H).inv()) + +def _pinv_rank_decomposition(M): + """Subroutine for rank decomposition + + With rank decompositions, `A` can be decomposed into two full- + rank matrices, and each matrix can take pseudoinverse + individually. + """ + + if M.is_zero_matrix: + return M.H + + B, C = M.rank_decomposition() + + Bp = _pinv_full_rank(B) + Cp = _pinv_full_rank(C) + + return Cp.multiply(Bp) + +def _pinv_diagonalization(M): + """Subroutine using diagonalization + + This routine can sometimes fail if SymPy's eigenvalue + computation is not reliable. + """ + + if M.is_zero_matrix: + return M.H + + A = M + AH = M.H + + try: + if M.rows >= M.cols: + P, D = AH.multiply(A).diagonalize(normalize=True) + D_pinv = D.applyfunc(lambda x: 0 if _iszero(x) else 1 / x) + + return P.multiply(D_pinv).multiply(P.H).multiply(AH) + + else: + P, D = A.multiply(AH).diagonalize( + normalize=True) + D_pinv = D.applyfunc(lambda x: 0 if _iszero(x) else 1 / x) + + return AH.multiply(P).multiply(D_pinv).multiply(P.H) + + except MatrixError: + raise NotImplementedError( + 'pinv for rank-deficient matrices where ' + 'diagonalization of A.H*A fails is not supported yet.') + +def _pinv(M, method='RD'): + """Calculate the Moore-Penrose pseudoinverse of the matrix. + + The Moore-Penrose pseudoinverse exists and is unique for any matrix. + If the matrix is invertible, the pseudoinverse is the same as the + inverse. + + Parameters + ========== + + method : String, optional + Specifies the method for computing the pseudoinverse. + + If ``'RD'``, Rank-Decomposition will be used. + + If ``'ED'``, Diagonalization will be used. + + Examples + ======== + + Computing pseudoinverse by rank decomposition : + + >>> from sympy import Matrix + >>> A = Matrix([[1, 2, 3], [4, 5, 6]]) + >>> A.pinv() + Matrix([ + [-17/18, 4/9], + [ -1/9, 1/9], + [ 13/18, -2/9]]) + + Computing pseudoinverse by diagonalization : + + >>> B = A.pinv(method='ED') + >>> B.simplify() + >>> B + Matrix([ + [-17/18, 4/9], + [ -1/9, 1/9], + [ 13/18, -2/9]]) + + See Also + ======== + + inv + pinv_solve + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Moore-Penrose_pseudoinverse + + """ + + # Trivial case: pseudoinverse of all-zero matrix is its transpose. + if M.is_zero_matrix: + return M.H + + if method == 'RD': + return _pinv_rank_decomposition(M) + elif method == 'ED': + return _pinv_diagonalization(M) + else: + raise ValueError('invalid pinv method %s' % repr(method)) + + +def _verify_invertible(M, iszerofunc=_iszero): + """Initial check to see if a matrix is invertible. Raises or returns + determinant for use in _inv_ADJ.""" + + if not M.is_square: + raise NonSquareMatrixError("A Matrix must be square to invert.") + + d = M.det(method='berkowitz') + zero = d.equals(0) + + if zero is None: # if equals() can't decide, will rref be able to? + ok = M.rref(simplify=True)[0] + zero = any(iszerofunc(ok[j, j]) for j in range(ok.rows)) + + if zero: + raise NonInvertibleMatrixError("Matrix det == 0; not invertible.") + + return d + +def _inv_ADJ(M, iszerofunc=_iszero): + """Calculates the inverse using the adjugate matrix and a determinant. + + See Also + ======== + + inv + inverse_GE + inverse_LU + inverse_CH + inverse_LDL + """ + + d = _verify_invertible(M, iszerofunc=iszerofunc) + + return M.adjugate() / d + +def _inv_GE(M, iszerofunc=_iszero): + """Calculates the inverse using Gaussian elimination. + + See Also + ======== + + inv + inverse_ADJ + inverse_LU + inverse_CH + inverse_LDL + """ + + from .dense import Matrix + + if not M.is_square: + raise NonSquareMatrixError("A Matrix must be square to invert.") + + big = Matrix.hstack(M.as_mutable(), Matrix.eye(M.rows)) + red = big.rref(iszerofunc=iszerofunc, simplify=True)[0] + + if any(iszerofunc(red[j, j]) for j in range(red.rows)): + raise NonInvertibleMatrixError("Matrix det == 0; not invertible.") + + return M._new(red[:, big.rows:]) + +def _inv_LU(M, iszerofunc=_iszero): + """Calculates the inverse using LU decomposition. + + See Also + ======== + + inv + inverse_ADJ + inverse_GE + inverse_CH + inverse_LDL + """ + + if not M.is_square: + raise NonSquareMatrixError("A Matrix must be square to invert.") + if M.free_symbols: + _verify_invertible(M, iszerofunc=iszerofunc) + + return M.LUsolve(M.eye(M.rows), iszerofunc=_iszero) + +def _inv_CH(M, iszerofunc=_iszero): + """Calculates the inverse using cholesky decomposition. + + See Also + ======== + + inv + inverse_ADJ + inverse_GE + inverse_LU + inverse_LDL + """ + + _verify_invertible(M, iszerofunc=iszerofunc) + + return M.cholesky_solve(M.eye(M.rows)) + +def _inv_LDL(M, iszerofunc=_iszero): + """Calculates the inverse using LDL decomposition. + + See Also + ======== + + inv + inverse_ADJ + inverse_GE + inverse_LU + inverse_CH + """ + + _verify_invertible(M, iszerofunc=iszerofunc) + + return M.LDLsolve(M.eye(M.rows)) + +def _inv_QR(M, iszerofunc=_iszero): + """Calculates the inverse using QR decomposition. + + See Also + ======== + + inv + inverse_ADJ + inverse_GE + inverse_CH + inverse_LDL + """ + + _verify_invertible(M, iszerofunc=iszerofunc) + + return M.QRsolve(M.eye(M.rows)) + +def _try_DM(M, use_EX=False): + """Try to convert a matrix to a ``DomainMatrix``.""" + dM = M.to_DM() + K = dM.domain + + # Return DomainMatrix if a domain is found. Only use EX if use_EX=True. + if not use_EX and K.is_EXRAW: + return None + elif K.is_EXRAW: + return dM.convert_to(EX) + else: + return dM + + +def _use_exact_domain(dom): + """Check whether to convert to an exact domain.""" + # DomainMatrix can handle RR and CC with partial pivoting. Other inexact + # domains like RR[a,b,...] can only be handled by converting to an exact + # domain like QQ[a,b,...] + if dom.is_RR or dom.is_CC: + return False + else: + return not dom.is_Exact + + +def _inv_DM(dM, cancel=True): + """Calculates the inverse using ``DomainMatrix``. + + See Also + ======== + + inv + inverse_ADJ + inverse_GE + inverse_CH + inverse_LDL + sympy.polys.matrices.domainmatrix.DomainMatrix.inv + """ + m, n = dM.shape + dom = dM.domain + + if m != n: + raise NonSquareMatrixError("A Matrix must be square to invert.") + + # Convert RR[a,b,...] to QQ[a,b,...] + use_exact = _use_exact_domain(dom) + + if use_exact: + dom_exact = dom.get_exact() + dM = dM.convert_to(dom_exact) + + try: + dMi, den = dM.inv_den() + except DMNonInvertibleMatrixError: + raise NonInvertibleMatrixError("Matrix det == 0; not invertible.") + + if use_exact: + dMi = dMi.convert_to(dom) + den = dom.convert_from(den, dom_exact) + + if cancel: + # Convert to field and cancel with the denominator. + if not dMi.domain.is_Field: + dMi = dMi.to_field() + Mi = (dMi / den).to_Matrix() + else: + # Convert to Matrix and divide without cancelling + Mi = dMi.to_Matrix() / dMi.domain.to_sympy(den) + + return Mi + +def _inv_block(M, iszerofunc=_iszero): + """Calculates the inverse using BLOCKWISE inversion. + + See Also + ======== + + inv + inverse_ADJ + inverse_GE + inverse_CH + inverse_LDL + """ + from sympy.matrices.expressions.blockmatrix import BlockMatrix + i = M.shape[0] + if i <= 20 : + return M.inv(method="LU", iszerofunc=_iszero) + A = M[:i // 2, :i //2] + B = M[:i // 2, i // 2:] + C = M[i // 2:, :i // 2] + D = M[i // 2:, i // 2:] + try: + D_inv = _inv_block(D) + except NonInvertibleMatrixError: + return M.inv(method="LU", iszerofunc=_iszero) + B_D_i = B*D_inv + BDC = B_D_i*C + A_n = A - BDC + try: + A_n = _inv_block(A_n) + except NonInvertibleMatrixError: + return M.inv(method="LU", iszerofunc=_iszero) + B_n = -A_n*B_D_i + dc = D_inv*C + C_n = -dc*A_n + D_n = D_inv + dc*-B_n + nn = BlockMatrix([[A_n, B_n], [C_n, D_n]]).as_explicit() + return nn + +def _inv(M, method=None, iszerofunc=_iszero, try_block_diag=False): + """ + Return the inverse of a matrix using the method indicated. The default + is DM if a suitable domain is found or otherwise GE for dense matrices + LDL for sparse matrices. + + Parameters + ========== + + method : ('DM', 'DMNC', 'GE', 'LU', 'ADJ', 'CH', 'LDL', 'QR') + + iszerofunc : function, optional + Zero-testing function to use. + + try_block_diag : bool, optional + If True then will try to form block diagonal matrices using the + method get_diag_blocks(), invert these individually, and then + reconstruct the full inverse matrix. + + Examples + ======== + + >>> from sympy import SparseMatrix, Matrix + >>> A = SparseMatrix([ + ... [ 2, -1, 0], + ... [-1, 2, -1], + ... [ 0, 0, 2]]) + >>> A.inv('CH') + Matrix([ + [2/3, 1/3, 1/6], + [1/3, 2/3, 1/3], + [ 0, 0, 1/2]]) + >>> A.inv(method='LDL') # use of 'method=' is optional + Matrix([ + [2/3, 1/3, 1/6], + [1/3, 2/3, 1/3], + [ 0, 0, 1/2]]) + >>> A * _ + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + >>> A = Matrix(A) + >>> A.inv('CH') + Matrix([ + [2/3, 1/3, 1/6], + [1/3, 2/3, 1/3], + [ 0, 0, 1/2]]) + >>> A.inv('ADJ') == A.inv('GE') == A.inv('LU') == A.inv('CH') == A.inv('LDL') == A.inv('QR') + True + + Notes + ===== + + According to the ``method`` keyword, it calls the appropriate method: + + DM .... Use DomainMatrix ``inv_den`` method + DMNC .... Use DomainMatrix ``inv_den`` method without cancellation + GE .... inverse_GE(); default for dense matrices + LU .... inverse_LU() + ADJ ... inverse_ADJ() + CH ... inverse_CH() + LDL ... inverse_LDL(); default for sparse matrices + QR ... inverse_QR() + + Note, the GE and LU methods may require the matrix to be simplified + before it is inverted in order to properly detect zeros during + pivoting. In difficult cases a custom zero detection function can + be provided by setting the ``iszerofunc`` argument to a function that + should return True if its argument is zero. The ADJ routine computes + the determinant and uses that to detect singular matrices in addition + to testing for zeros on the diagonal. + + See Also + ======== + + inverse_ADJ + inverse_GE + inverse_LU + inverse_CH + inverse_LDL + + Raises + ====== + + ValueError + If the determinant of the matrix is zero. + """ + + from sympy.matrices import diag, SparseMatrix + + if not M.is_square: + raise NonSquareMatrixError("A Matrix must be square to invert.") + + if try_block_diag: + blocks = M.get_diag_blocks() + r = [] + + for block in blocks: + r.append(block.inv(method=method, iszerofunc=iszerofunc)) + + return diag(*r) + + # Default: Use DomainMatrix if the domain is not EX. + # If DM is requested explicitly then use it even if the domain is EX. + if method is None and iszerofunc is _iszero: + dM = _try_DM(M, use_EX=False) + if dM is not None: + method = 'DM' + elif method in ("DM", "DMNC"): + dM = _try_DM(M, use_EX=True) + + # A suitable domain was not found, fall back to GE for dense matrices + # and LDL for sparse matrices. + if method is None: + if isinstance(M, SparseMatrix): + method = 'LDL' + else: + method = 'GE' + + if method == "DM": + rv = _inv_DM(dM) + elif method == "DMNC": + rv = _inv_DM(dM, cancel=False) + elif method == "GE": + rv = M.inverse_GE(iszerofunc=iszerofunc) + elif method == "LU": + rv = M.inverse_LU(iszerofunc=iszerofunc) + elif method == "ADJ": + rv = M.inverse_ADJ(iszerofunc=iszerofunc) + elif method == "CH": + rv = M.inverse_CH(iszerofunc=iszerofunc) + elif method == "LDL": + rv = M.inverse_LDL(iszerofunc=iszerofunc) + elif method == "QR": + rv = M.inverse_QR(iszerofunc=iszerofunc) + elif method == "BLOCK": + rv = M.inverse_BLOCK(iszerofunc=iszerofunc) + else: + raise ValueError("Inversion method unrecognized") + + return M._new(rv) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/kind.py b/.venv/lib/python3.13/site-packages/sympy/matrices/kind.py new file mode 100644 index 0000000000000000000000000000000000000000..f9f53ffe16f7cbde60213e49071a2a74e80e5c6c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/kind.py @@ -0,0 +1,97 @@ +# sympy.matrices.kind + +from sympy.core.kind import Kind, _NumberKind, NumberKind +from sympy.core.mul import Mul + + +class MatrixKind(Kind): + """ + Kind for all matrices in SymPy. + + Basic class for this kind is ``MatrixBase`` and ``MatrixExpr``, + but any expression representing the matrix can have this. + + Parameters + ========== + + element_kind : Kind + Kind of the element. Default is + :class:`sympy.core.kind.NumberKind`, + which means that the matrix contains only numbers. + + Examples + ======== + + Any instance of matrix class has kind ``MatrixKind``: + + >>> from sympy import MatrixSymbol + >>> A = MatrixSymbol('A', 2, 2) + >>> A.kind + MatrixKind(NumberKind) + + An expression representing a matrix may not be an instance of + the Matrix class, but it will have kind ``MatrixKind``: + + >>> from sympy import MatrixExpr, Integral + >>> from sympy.abc import x + >>> intM = Integral(A, x) + >>> isinstance(intM, MatrixExpr) + False + >>> intM.kind + MatrixKind(NumberKind) + + Use ``isinstance()`` to check for ``MatrixKind`` without specifying the + element kind. Use ``is`` to check the kind including the element kind: + + >>> from sympy import Matrix + >>> from sympy.core import NumberKind + >>> from sympy.matrices import MatrixKind + >>> M = Matrix([1, 2]) + >>> isinstance(M.kind, MatrixKind) + True + >>> M.kind is MatrixKind(NumberKind) + True + + See Also + ======== + + sympy.core.kind.NumberKind + sympy.core.kind.UndefinedKind + sympy.core.containers.TupleKind + sympy.sets.sets.SetKind + + """ + def __new__(cls, element_kind=NumberKind): + obj = super().__new__(cls, element_kind) + obj.element_kind = element_kind + return obj + + def __repr__(self): + return "MatrixKind(%s)" % self.element_kind + + +@Mul._kind_dispatcher.register(_NumberKind, MatrixKind) +def num_mat_mul(k1, k2): + """ + Return MatrixKind. The element kind is selected by recursive dispatching. + Do not need to dispatch in reversed order because KindDispatcher + searches for this automatically. + """ + # Deal with Mul._kind_dispatcher's commutativity + # XXX: this function is called with either k1 or k2 as MatrixKind because + # the Mul kind dispatcher is commutative. Maybe it shouldn't be. Need to + # swap the args here because NumberKind does not have an element_kind + # attribute. + if not isinstance(k2, MatrixKind): + k1, k2 = k2, k1 + elemk = Mul._kind_dispatcher(k1, k2.element_kind) + return MatrixKind(elemk) + + +@Mul._kind_dispatcher.register(MatrixKind, MatrixKind) +def mat_mat_mul(k1, k2): + """ + Return MatrixKind. The element kind is selected by recursive dispatching. + """ + elemk = Mul._kind_dispatcher(k1.element_kind, k2.element_kind) + return MatrixKind(elemk) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/matrices.py b/.venv/lib/python3.13/site-packages/sympy/matrices/matrices.py new file mode 100644 index 0000000000000000000000000000000000000000..fed41a626cb395ac0529071317630d853e0d3a96 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/matrices.py @@ -0,0 +1,687 @@ +# +# A module consisting of deprecated matrix classes. New code should not be +# added here. +# +from sympy.core.basic import Basic +from sympy.core.symbol import Dummy + +from .common import MatrixCommon + +from .exceptions import NonSquareMatrixError + +from .utilities import _iszero, _is_zero_after_expand_mul, _simplify + +from .determinant import ( + _find_reasonable_pivot, _find_reasonable_pivot_naive, + _adjugate, _charpoly, _cofactor, _cofactor_matrix, _per, + _det, _det_bareiss, _det_berkowitz, _det_bird, _det_laplace, _det_LU, + _minor, _minor_submatrix) + +from .reductions import _is_echelon, _echelon_form, _rank, _rref +from .subspaces import _columnspace, _nullspace, _rowspace, _orthogonalize + +from .eigen import ( + _eigenvals, _eigenvects, + _bidiagonalize, _bidiagonal_decomposition, + _is_diagonalizable, _diagonalize, + _is_positive_definite, _is_positive_semidefinite, + _is_negative_definite, _is_negative_semidefinite, _is_indefinite, + _jordan_form, _left_eigenvects, _singular_values) + + +# This class was previously defined in this module, but was moved to +# sympy.matrices.matrixbase. We import it here for backwards compatibility in +# case someone was importing it from here. +from .matrixbase import MatrixBase + + +__doctest_requires__ = { + ('MatrixEigen.is_indefinite', + 'MatrixEigen.is_negative_definite', + 'MatrixEigen.is_negative_semidefinite', + 'MatrixEigen.is_positive_definite', + 'MatrixEigen.is_positive_semidefinite'): ['matplotlib'], +} + + +class MatrixDeterminant(MatrixCommon): + """Provides basic matrix determinant operations. Should not be instantiated + directly. See ``determinant.py`` for their implementations.""" + + def _eval_det_bareiss(self, iszerofunc=_is_zero_after_expand_mul): + return _det_bareiss(self, iszerofunc=iszerofunc) + + def _eval_det_berkowitz(self): + return _det_berkowitz(self) + + def _eval_det_lu(self, iszerofunc=_iszero, simpfunc=None): + return _det_LU(self, iszerofunc=iszerofunc, simpfunc=simpfunc) + + def _eval_det_bird(self): + return _det_bird(self) + + def _eval_det_laplace(self): + return _det_laplace(self) + + def _eval_determinant(self): # for expressions.determinant.Determinant + return _det(self) + + def adjugate(self, method="berkowitz"): + return _adjugate(self, method=method) + + def charpoly(self, x='lambda', simplify=_simplify): + return _charpoly(self, x=x, simplify=simplify) + + def cofactor(self, i, j, method="berkowitz"): + return _cofactor(self, i, j, method=method) + + def cofactor_matrix(self, method="berkowitz"): + return _cofactor_matrix(self, method=method) + + def det(self, method="bareiss", iszerofunc=None): + return _det(self, method=method, iszerofunc=iszerofunc) + + def per(self): + return _per(self) + + def minor(self, i, j, method="berkowitz"): + return _minor(self, i, j, method=method) + + def minor_submatrix(self, i, j): + return _minor_submatrix(self, i, j) + + _find_reasonable_pivot.__doc__ = _find_reasonable_pivot.__doc__ + _find_reasonable_pivot_naive.__doc__ = _find_reasonable_pivot_naive.__doc__ + _eval_det_bareiss.__doc__ = _det_bareiss.__doc__ + _eval_det_berkowitz.__doc__ = _det_berkowitz.__doc__ + _eval_det_bird.__doc__ = _det_bird.__doc__ + _eval_det_laplace.__doc__ = _det_laplace.__doc__ + _eval_det_lu.__doc__ = _det_LU.__doc__ + _eval_determinant.__doc__ = _det.__doc__ + adjugate.__doc__ = _adjugate.__doc__ + charpoly.__doc__ = _charpoly.__doc__ + cofactor.__doc__ = _cofactor.__doc__ + cofactor_matrix.__doc__ = _cofactor_matrix.__doc__ + det.__doc__ = _det.__doc__ + per.__doc__ = _per.__doc__ + minor.__doc__ = _minor.__doc__ + minor_submatrix.__doc__ = _minor_submatrix.__doc__ + + +class MatrixReductions(MatrixDeterminant): + """Provides basic matrix row/column operations. Should not be instantiated + directly. See ``reductions.py`` for some of their implementations.""" + + def echelon_form(self, iszerofunc=_iszero, simplify=False, with_pivots=False): + return _echelon_form(self, iszerofunc=iszerofunc, simplify=simplify, + with_pivots=with_pivots) + + @property + def is_echelon(self): + return _is_echelon(self) + + def rank(self, iszerofunc=_iszero, simplify=False): + return _rank(self, iszerofunc=iszerofunc, simplify=simplify) + + def rref_rhs(self, rhs): + """Return reduced row-echelon form of matrix, matrix showing + rhs after reduction steps. ``rhs`` must have the same number + of rows as ``self``. + + Examples + ======== + + >>> from sympy import Matrix, symbols + >>> r1, r2 = symbols('r1 r2') + >>> Matrix([[1, 1], [2, 1]]).rref_rhs(Matrix([r1, r2])) + (Matrix([ + [1, 0], + [0, 1]]), Matrix([ + [ -r1 + r2], + [2*r1 - r2]])) + """ + r, _ = _rref(self.hstack(self, self.eye(self.rows), rhs)) + return r[:, :self.cols], r[:, -rhs.cols:] + + def rref(self, iszerofunc=_iszero, simplify=False, pivots=True, + normalize_last=True): + return _rref(self, iszerofunc=iszerofunc, simplify=simplify, + pivots=pivots, normalize_last=normalize_last) + + echelon_form.__doc__ = _echelon_form.__doc__ + is_echelon.__doc__ = _is_echelon.__doc__ + rank.__doc__ = _rank.__doc__ + rref.__doc__ = _rref.__doc__ + + def _normalize_op_args(self, op, col, k, col1, col2, error_str="col"): + """Validate the arguments for a row/column operation. ``error_str`` + can be one of "row" or "col" depending on the arguments being parsed.""" + if op not in ["n->kn", "n<->m", "n->n+km"]: + raise ValueError("Unknown {} operation '{}'. Valid col operations " + "are 'n->kn', 'n<->m', 'n->n+km'".format(error_str, op)) + + # define self_col according to error_str + self_cols = self.cols if error_str == 'col' else self.rows + + # normalize and validate the arguments + if op == "n->kn": + col = col if col is not None else col1 + if col is None or k is None: + raise ValueError("For a {0} operation 'n->kn' you must provide the " + "kwargs `{0}` and `k`".format(error_str)) + if not 0 <= col < self_cols: + raise ValueError("This matrix does not have a {} '{}'".format(error_str, col)) + + elif op == "n<->m": + # we need two cols to swap. It does not matter + # how they were specified, so gather them together and + # remove `None` + cols = {col, k, col1, col2}.difference([None]) + if len(cols) > 2: + # maybe the user left `k` by mistake? + cols = {col, col1, col2}.difference([None]) + if len(cols) != 2: + raise ValueError("For a {0} operation 'n<->m' you must provide the " + "kwargs `{0}1` and `{0}2`".format(error_str)) + col1, col2 = cols + if not 0 <= col1 < self_cols: + raise ValueError("This matrix does not have a {} '{}'".format(error_str, col1)) + if not 0 <= col2 < self_cols: + raise ValueError("This matrix does not have a {} '{}'".format(error_str, col2)) + + elif op == "n->n+km": + col = col1 if col is None else col + col2 = col1 if col2 is None else col2 + if col is None or col2 is None or k is None: + raise ValueError("For a {0} operation 'n->n+km' you must provide the " + "kwargs `{0}`, `k`, and `{0}2`".format(error_str)) + if col == col2: + raise ValueError("For a {0} operation 'n->n+km' `{0}` and `{0}2` must " + "be different.".format(error_str)) + if not 0 <= col < self_cols: + raise ValueError("This matrix does not have a {} '{}'".format(error_str, col)) + if not 0 <= col2 < self_cols: + raise ValueError("This matrix does not have a {} '{}'".format(error_str, col2)) + + else: + raise ValueError('invalid operation %s' % repr(op)) + + return op, col, k, col1, col2 + + def _eval_col_op_multiply_col_by_const(self, col, k): + def entry(i, j): + if j == col: + return k * self[i, j] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def _eval_col_op_swap(self, col1, col2): + def entry(i, j): + if j == col1: + return self[i, col2] + elif j == col2: + return self[i, col1] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def _eval_col_op_add_multiple_to_other_col(self, col, k, col2): + def entry(i, j): + if j == col: + return self[i, j] + k * self[i, col2] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def _eval_row_op_swap(self, row1, row2): + def entry(i, j): + if i == row1: + return self[row2, j] + elif i == row2: + return self[row1, j] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def _eval_row_op_multiply_row_by_const(self, row, k): + def entry(i, j): + if i == row: + return k * self[i, j] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def _eval_row_op_add_multiple_to_other_row(self, row, k, row2): + def entry(i, j): + if i == row: + return self[i, j] + k * self[row2, j] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def elementary_col_op(self, op="n->kn", col=None, k=None, col1=None, col2=None): + """Performs the elementary column operation `op`. + + `op` may be one of + + * ``"n->kn"`` (column n goes to k*n) + * ``"n<->m"`` (swap column n and column m) + * ``"n->n+km"`` (column n goes to column n + k*column m) + + Parameters + ========== + + op : string; the elementary row operation + col : the column to apply the column operation + k : the multiple to apply in the column operation + col1 : one column of a column swap + col2 : second column of a column swap or column "m" in the column operation + "n->n+km" + """ + + op, col, k, col1, col2 = self._normalize_op_args(op, col, k, col1, col2, "col") + + # now that we've validated, we're all good to dispatch + if op == "n->kn": + return self._eval_col_op_multiply_col_by_const(col, k) + if op == "n<->m": + return self._eval_col_op_swap(col1, col2) + if op == "n->n+km": + return self._eval_col_op_add_multiple_to_other_col(col, k, col2) + + def elementary_row_op(self, op="n->kn", row=None, k=None, row1=None, row2=None): + """Performs the elementary row operation `op`. + + `op` may be one of + + * ``"n->kn"`` (row n goes to k*n) + * ``"n<->m"`` (swap row n and row m) + * ``"n->n+km"`` (row n goes to row n + k*row m) + + Parameters + ========== + + op : string; the elementary row operation + row : the row to apply the row operation + k : the multiple to apply in the row operation + row1 : one row of a row swap + row2 : second row of a row swap or row "m" in the row operation + "n->n+km" + """ + + op, row, k, row1, row2 = self._normalize_op_args(op, row, k, row1, row2, "row") + + # now that we've validated, we're all good to dispatch + if op == "n->kn": + return self._eval_row_op_multiply_row_by_const(row, k) + if op == "n<->m": + return self._eval_row_op_swap(row1, row2) + if op == "n->n+km": + return self._eval_row_op_add_multiple_to_other_row(row, k, row2) + + +class MatrixSubspaces(MatrixReductions): + """Provides methods relating to the fundamental subspaces of a matrix. + Should not be instantiated directly. See ``subspaces.py`` for their + implementations.""" + + def columnspace(self, simplify=False): + return _columnspace(self, simplify=simplify) + + def nullspace(self, simplify=False, iszerofunc=_iszero): + return _nullspace(self, simplify=simplify, iszerofunc=iszerofunc) + + def rowspace(self, simplify=False): + return _rowspace(self, simplify=simplify) + + # This is a classmethod but is converted to such later in order to allow + # assignment of __doc__ since that does not work for already wrapped + # classmethods in Python 3.6. + def orthogonalize(cls, *vecs, **kwargs): + return _orthogonalize(cls, *vecs, **kwargs) + + columnspace.__doc__ = _columnspace.__doc__ + nullspace.__doc__ = _nullspace.__doc__ + rowspace.__doc__ = _rowspace.__doc__ + orthogonalize.__doc__ = _orthogonalize.__doc__ + + orthogonalize = classmethod(orthogonalize) # type:ignore + + +class MatrixEigen(MatrixSubspaces): + """Provides basic matrix eigenvalue/vector operations. + Should not be instantiated directly. See ``eigen.py`` for their + implementations.""" + + def eigenvals(self, error_when_incomplete=True, **flags): + return _eigenvals(self, error_when_incomplete=error_when_incomplete, **flags) + + def eigenvects(self, error_when_incomplete=True, iszerofunc=_iszero, **flags): + return _eigenvects(self, error_when_incomplete=error_when_incomplete, + iszerofunc=iszerofunc, **flags) + + def is_diagonalizable(self, reals_only=False, **kwargs): + return _is_diagonalizable(self, reals_only=reals_only, **kwargs) + + def diagonalize(self, reals_only=False, sort=False, normalize=False): + return _diagonalize(self, reals_only=reals_only, sort=sort, + normalize=normalize) + + def bidiagonalize(self, upper=True): + return _bidiagonalize(self, upper=upper) + + def bidiagonal_decomposition(self, upper=True): + return _bidiagonal_decomposition(self, upper=upper) + + @property + def is_positive_definite(self): + return _is_positive_definite(self) + + @property + def is_positive_semidefinite(self): + return _is_positive_semidefinite(self) + + @property + def is_negative_definite(self): + return _is_negative_definite(self) + + @property + def is_negative_semidefinite(self): + return _is_negative_semidefinite(self) + + @property + def is_indefinite(self): + return _is_indefinite(self) + + def jordan_form(self, calc_transform=True, **kwargs): + return _jordan_form(self, calc_transform=calc_transform, **kwargs) + + def left_eigenvects(self, **flags): + return _left_eigenvects(self, **flags) + + def singular_values(self): + return _singular_values(self) + + eigenvals.__doc__ = _eigenvals.__doc__ + eigenvects.__doc__ = _eigenvects.__doc__ + is_diagonalizable.__doc__ = _is_diagonalizable.__doc__ + diagonalize.__doc__ = _diagonalize.__doc__ + is_positive_definite.__doc__ = _is_positive_definite.__doc__ + is_positive_semidefinite.__doc__ = _is_positive_semidefinite.__doc__ + is_negative_definite.__doc__ = _is_negative_definite.__doc__ + is_negative_semidefinite.__doc__ = _is_negative_semidefinite.__doc__ + is_indefinite.__doc__ = _is_indefinite.__doc__ + jordan_form.__doc__ = _jordan_form.__doc__ + left_eigenvects.__doc__ = _left_eigenvects.__doc__ + singular_values.__doc__ = _singular_values.__doc__ + bidiagonalize.__doc__ = _bidiagonalize.__doc__ + bidiagonal_decomposition.__doc__ = _bidiagonal_decomposition.__doc__ + + +class MatrixCalculus(MatrixCommon): + """Provides calculus-related matrix operations.""" + + def diff(self, *args, evaluate=True, **kwargs): + """Calculate the derivative of each element in the matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x, y + >>> M = Matrix([[x, y], [1, 0]]) + >>> M.diff(x) + Matrix([ + [1, 0], + [0, 0]]) + + See Also + ======== + + integrate + limit + """ + # XXX this should be handled here rather than in Derivative + from sympy.tensor.array.array_derivatives import ArrayDerivative + deriv = ArrayDerivative(self, *args, evaluate=evaluate) + # XXX This can rather changed to always return immutable matrix + if not isinstance(self, Basic) and evaluate: + return deriv.as_mutable() + return deriv + + def _eval_derivative(self, arg): + return self.applyfunc(lambda x: x.diff(arg)) + + def integrate(self, *args, **kwargs): + """Integrate each element of the matrix. ``args`` will + be passed to the ``integrate`` function. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x, y + >>> M = Matrix([[x, y], [1, 0]]) + >>> M.integrate((x, )) + Matrix([ + [x**2/2, x*y], + [ x, 0]]) + >>> M.integrate((x, 0, 2)) + Matrix([ + [2, 2*y], + [2, 0]]) + + See Also + ======== + + limit + diff + """ + return self.applyfunc(lambda x: x.integrate(*args, **kwargs)) + + def jacobian(self, X): + """Calculates the Jacobian matrix (derivative of a vector-valued function). + + Parameters + ========== + + ``self`` : vector of expressions representing functions f_i(x_1, ..., x_n). + X : set of x_i's in order, it can be a list or a Matrix + + Both ``self`` and X can be a row or a column matrix in any order + (i.e., jacobian() should always work). + + Examples + ======== + + >>> from sympy import sin, cos, Matrix + >>> from sympy.abc import rho, phi + >>> X = Matrix([rho*cos(phi), rho*sin(phi), rho**2]) + >>> Y = Matrix([rho, phi]) + >>> X.jacobian(Y) + Matrix([ + [cos(phi), -rho*sin(phi)], + [sin(phi), rho*cos(phi)], + [ 2*rho, 0]]) + >>> X = Matrix([rho*cos(phi), rho*sin(phi)]) + >>> X.jacobian(Y) + Matrix([ + [cos(phi), -rho*sin(phi)], + [sin(phi), rho*cos(phi)]]) + + See Also + ======== + + hessian + wronskian + """ + if not isinstance(X, MatrixBase): + X = self._new(X) + # Both X and ``self`` can be a row or a column matrix, so we need to make + # sure all valid combinations work, but everything else fails: + if self.shape[0] == 1: + m = self.shape[1] + elif self.shape[1] == 1: + m = self.shape[0] + else: + raise TypeError("``self`` must be a row or a column matrix") + if X.shape[0] == 1: + n = X.shape[1] + elif X.shape[1] == 1: + n = X.shape[0] + else: + raise TypeError("X must be a row or a column matrix") + + # m is the number of functions and n is the number of variables + # computing the Jacobian is now easy: + return self._new(m, n, lambda j, i: self[j].diff(X[i])) + + def limit(self, *args): + """Calculate the limit of each element in the matrix. + ``args`` will be passed to the ``limit`` function. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x, y + >>> M = Matrix([[x, y], [1, 0]]) + >>> M.limit(x, 2) + Matrix([ + [2, y], + [1, 0]]) + + See Also + ======== + + integrate + diff + """ + return self.applyfunc(lambda x: x.limit(*args)) + + +# https://github.com/sympy/sympy/pull/12854 +class MatrixDeprecated(MatrixCommon): + """A class to house deprecated matrix methods.""" + def berkowitz_charpoly(self, x=Dummy('lambda'), simplify=_simplify): + return self.charpoly(x=x) + + def berkowitz_det(self): + """Computes determinant using Berkowitz method. + + See Also + ======== + + det + berkowitz + """ + return self.det(method='berkowitz') + + def berkowitz_eigenvals(self, **flags): + """Computes eigenvalues of a Matrix using Berkowitz method. + + See Also + ======== + + berkowitz + """ + return self.eigenvals(**flags) + + def berkowitz_minors(self): + """Computes principal minors using Berkowitz method. + + See Also + ======== + + berkowitz + """ + sign, minors = self.one, [] + + for poly in self.berkowitz(): + minors.append(sign * poly[-1]) + sign = -sign + + return tuple(minors) + + def berkowitz(self): + from sympy.matrices import zeros + berk = ((1,),) + if not self: + return berk + + if not self.is_square: + raise NonSquareMatrixError() + + A, N = self, self.rows + transforms = [0] * (N - 1) + + for n in range(N, 1, -1): + T, k = zeros(n + 1, n), n - 1 + + R, C = -A[k, :k], A[:k, k] + A, a = A[:k, :k], -A[k, k] + + items = [C] + + for i in range(0, n - 2): + items.append(A * items[i]) + + for i, B in enumerate(items): + items[i] = (R * B)[0, 0] + + items = [self.one, a] + items + + for i in range(n): + T[i:, i] = items[:n - i + 1] + + transforms[k - 1] = T + + polys = [self._new([self.one, -A[0, 0]])] + + for i, T in enumerate(transforms): + polys.append(T * polys[i]) + + return berk + tuple(map(tuple, polys)) + + def cofactorMatrix(self, method="berkowitz"): + return self.cofactor_matrix(method=method) + + def det_bareis(self): + return _det_bareiss(self) + + def det_LU_decomposition(self): + """Compute matrix determinant using LU decomposition. + + + Note that this method fails if the LU decomposition itself + fails. In particular, if the matrix has no inverse this method + will fail. + + TODO: Implement algorithm for sparse matrices (SFF), + https://www.eecis.udel.edu/~saunders/papers/sffge/it5.ps + + See Also + ======== + + + det + det_bareiss + berkowitz_det + """ + return self.det(method='lu') + + def jordan_cell(self, eigenval, n): + return self.jordan_block(size=n, eigenvalue=eigenval) + + def jordan_cells(self, calc_transformation=True): + P, J = self.jordan_form() + return P, J.get_diag_blocks() + + def minorEntry(self, i, j, method="berkowitz"): + return self.minor(i, j, method=method) + + def minorMatrix(self, i, j): + return self.minor_submatrix(i, j) + + def permuteBkwd(self, perm): + """Permute the rows of the matrix with the given permutation in reverse.""" + return self.permute_rows(perm, direction='backward') + + def permuteFwd(self, perm): + """Permute the rows of the matrix with the given permutation.""" + return self.permute_rows(perm, direction='forward') diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/matrixbase.py b/.venv/lib/python3.13/site-packages/sympy/matrices/matrixbase.py new file mode 100644 index 0000000000000000000000000000000000000000..49acc04043b30e003f7eed256f2e06e6a6556401 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/matrixbase.py @@ -0,0 +1,5428 @@ +from __future__ import annotations +from collections import defaultdict +from collections.abc import Iterable +from inspect import isfunction +from functools import reduce + +from sympy.assumptions.refine import refine +from sympy.core import SympifyError, Add +from sympy.core.basic import Atom, Basic +from sympy.core.kind import UndefinedKind +from sympy.core.numbers import Integer +from sympy.core.mod import Mod +from sympy.core.symbol import Symbol, Dummy +from sympy.core.sympify import sympify, _sympify +from sympy.core.function import diff +from sympy.polys import cancel +from sympy.functions.elementary.complexes import Abs, re, im +from sympy.printing import sstr +from sympy.functions.elementary.miscellaneous import Max, Min, sqrt +from sympy.functions.special.tensor_functions import KroneckerDelta, LeviCivita +from sympy.core.singleton import S +from sympy.printing.defaults import Printable +from sympy.printing.str import StrPrinter +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.combinatorial.factorials import binomial, factorial + +import mpmath as mp +from collections.abc import Callable +from sympy.utilities.iterables import reshape +from sympy.core.expr import Expr +from sympy.core.power import Pow +from sympy.core.symbol import uniquely_named_symbol + +from .utilities import _dotprodsimp, _simplify as _utilities_simplify +from sympy.polys.polytools import Poly +from sympy.utilities.iterables import flatten, is_sequence +from sympy.utilities.misc import as_int, filldedent +from sympy.core.decorators import call_highest_priority +from sympy.core.logic import fuzzy_and, FuzzyBool +from sympy.tensor.array import NDimArray +from sympy.utilities.iterables import NotIterable + +from .utilities import _get_intermediate_simp_bool + +from .kind import MatrixKind + +from .exceptions import ( + MatrixError, ShapeError, NonSquareMatrixError, NonInvertibleMatrixError, +) + +from .utilities import _iszero, _is_zero_after_expand_mul + +from .determinant import ( + _find_reasonable_pivot, _find_reasonable_pivot_naive, + _adjugate, _charpoly, _cofactor, _cofactor_matrix, _per, + _det, _det_bareiss, _det_berkowitz, _det_bird, _det_laplace, _det_LU, + _minor, _minor_submatrix) + +from .reductions import _is_echelon, _echelon_form, _rank, _rref + +from .solvers import ( + _diagonal_solve, _lower_triangular_solve, _upper_triangular_solve, + _cholesky_solve, _LDLsolve, _LUsolve, _QRsolve, _gauss_jordan_solve, + _pinv_solve, _cramer_solve, _solve, _solve_least_squares) + +from .inverse import ( + _pinv, _inv_ADJ, _inv_GE, _inv_LU, _inv_CH, _inv_LDL, _inv_QR, + _inv, _inv_block) + +from .subspaces import _columnspace, _nullspace, _rowspace, _orthogonalize + +from .eigen import ( + _eigenvals, _eigenvects, + _bidiagonalize, _bidiagonal_decomposition, + _is_diagonalizable, _diagonalize, + _is_positive_definite, _is_positive_semidefinite, + _is_negative_definite, _is_negative_semidefinite, _is_indefinite, + _jordan_form, _left_eigenvects, _singular_values) + +from .decompositions import ( + _rank_decomposition, _cholesky, _LDLdecomposition, + _LUdecomposition, _LUdecomposition_Simple, _LUdecompositionFF, + _singular_value_decomposition, _QRdecomposition, _upper_hessenberg_decomposition) + +from .graph import ( + _connected_components, _connected_components_decomposition, + _strongly_connected_components, _strongly_connected_components_decomposition) + + +__doctest_requires__ = { + ('MatrixBase.is_indefinite', + 'MatrixBase.is_positive_definite', + 'MatrixBase.is_positive_semidefinite', + 'MatrixBase.is_negative_definite', + 'MatrixBase.is_negative_semidefinite'): ['matplotlib'], +} + + +class MatrixBase(Printable): + """All common matrix operations including basic arithmetic, shaping, + and special matrices like `zeros`, and `eye`.""" + + _op_priority = 10.01 + + # Added just for numpy compatibility + __array_priority__ = 11 + + is_Matrix = True + _class_priority = 3 + _sympify = staticmethod(sympify) + zero = S.Zero + one = S.One + + _diff_wrt: bool = True + rows: int + cols: int + _simplify = None + + @classmethod + def _new(cls, *args, **kwargs): + """`_new` must, at minimum, be callable as + `_new(rows, cols, mat) where mat is a flat list of the + elements of the matrix.""" + raise NotImplementedError("Subclasses must implement this.") + + def __eq__(self, other): + raise NotImplementedError("Subclasses must implement this.") + + def __getitem__(self, key): + """Implementations of __getitem__ should accept ints, in which + case the matrix is indexed as a flat list, tuples (i,j) in which + case the (i,j) entry is returned, slices, or mixed tuples (a,b) + where a and b are any combination of slices and integers.""" + raise NotImplementedError("Subclasses must implement this.") + + @property + def shape(self): + """The shape (dimensions) of the matrix as the 2-tuple (rows, cols). + + Examples + ======== + + >>> from sympy import zeros + >>> M = zeros(2, 3) + >>> M.shape + (2, 3) + >>> M.rows + 2 + >>> M.cols + 3 + """ + return (self.rows, self.cols) + + def _eval_col_del(self, col): + def entry(i, j): + return self[i, j] if j < col else self[i, j + 1] + return self._new(self.rows, self.cols - 1, entry) + + def _eval_col_insert(self, pos, other): + + def entry(i, j): + if j < pos: + return self[i, j] + elif pos <= j < pos + other.cols: + return other[i, j - pos] + return self[i, j - other.cols] + + return self._new(self.rows, self.cols + other.cols, entry) + + def _eval_col_join(self, other): + rows = self.rows + + def entry(i, j): + if i < rows: + return self[i, j] + return other[i - rows, j] + + return classof(self, other)._new(self.rows + other.rows, self.cols, + entry) + + def _eval_extract(self, rowsList, colsList): + mat = list(self) + cols = self.cols + indices = (i * cols + j for i in rowsList for j in colsList) + return self._new(len(rowsList), len(colsList), + [mat[i] for i in indices]) + + def _eval_get_diag_blocks(self): + sub_blocks = [] + + def recurse_sub_blocks(M): + for i in range(1, M.shape[0] + 1): + if i == 1: + to_the_right = M[0, i:] + to_the_bottom = M[i:, 0] + else: + to_the_right = M[:i, i:] + to_the_bottom = M[i:, :i] + if any(to_the_right) or any(to_the_bottom): + continue + sub_blocks.append(M[:i, :i]) + if M.shape != M[:i, :i].shape: + recurse_sub_blocks(M[i:, i:]) + return + + recurse_sub_blocks(self) + return sub_blocks + + def _eval_row_del(self, row): + def entry(i, j): + return self[i, j] if i < row else self[i + 1, j] + return self._new(self.rows - 1, self.cols, entry) + + def _eval_row_insert(self, pos, other): + entries = list(self) + insert_pos = pos * self.cols + entries[insert_pos:insert_pos] = list(other) + return self._new(self.rows + other.rows, self.cols, entries) + + def _eval_row_join(self, other): + cols = self.cols + + def entry(i, j): + if j < cols: + return self[i, j] + return other[i, j - cols] + + return classof(self, other)._new(self.rows, self.cols + other.cols, + entry) + + def _eval_tolist(self): + return [list(self[i,:]) for i in range(self.rows)] + + def _eval_todok(self): + dok = {} + rows, cols = self.shape + for i in range(rows): + for j in range(cols): + val = self[i, j] + if val != self.zero: + dok[i, j] = val + return dok + + @classmethod + def _eval_from_dok(cls, rows, cols, dok): + out_flat = [cls.zero] * (rows * cols) + for (i, j), val in dok.items(): + out_flat[i * cols + j] = val + return cls._new(rows, cols, out_flat) + + def _eval_vec(self): + rows = self.rows + + def entry(n, _): + # we want to read off the columns first + j = n // rows + i = n - j * rows + return self[i, j] + + return self._new(len(self), 1, entry) + + def _eval_vech(self, diagonal): + c = self.cols + v = [] + if diagonal: + for j in range(c): + for i in range(j, c): + v.append(self[i, j]) + else: + for j in range(c): + for i in range(j + 1, c): + v.append(self[i, j]) + return self._new(len(v), 1, v) + + def col_del(self, col): + """Delete the specified column.""" + if col < 0: + col += self.cols + if not 0 <= col < self.cols: + raise IndexError("Column {} is out of range.".format(col)) + return self._eval_col_del(col) + + def col_insert(self, pos, other): + """Insert one or more columns at the given column position. + + Examples + ======== + + >>> from sympy import zeros, ones + >>> M = zeros(3) + >>> V = ones(3, 1) + >>> M.col_insert(1, V) + Matrix([ + [0, 1, 0, 0], + [0, 1, 0, 0], + [0, 1, 0, 0]]) + + See Also + ======== + + col + row_insert + """ + # Allows you to build a matrix even if it is null matrix + if not self: + return type(self)(other) + + pos = as_int(pos) + + if pos < 0: + pos = self.cols + pos + if pos < 0: + pos = 0 + elif pos > self.cols: + pos = self.cols + + if self.rows != other.rows: + raise ShapeError( + "The matrices have incompatible number of rows ({} and {})" + .format(self.rows, other.rows)) + + return self._eval_col_insert(pos, other) + + def col_join(self, other): + """Concatenates two matrices along self's last and other's first row. + + Examples + ======== + + >>> from sympy import zeros, ones + >>> M = zeros(3) + >>> V = ones(1, 3) + >>> M.col_join(V) + Matrix([ + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [1, 1, 1]]) + + See Also + ======== + + col + row_join + """ + # A null matrix can always be stacked (see #10770) + if self.rows == 0 and self.cols != other.cols: + return self._new(0, other.cols, []).col_join(other) + + if self.cols != other.cols: + raise ShapeError( + "The matrices have incompatible number of columns ({} and {})" + .format(self.cols, other.cols)) + return self._eval_col_join(other) + + def col(self, j): + """Elementary column selector. + + Examples + ======== + + >>> from sympy import eye + >>> eye(2).col(0) + Matrix([ + [1], + [0]]) + + See Also + ======== + + row + col_del + col_join + col_insert + """ + return self[:, j] + + def extract(self, rowsList, colsList): + r"""Return a submatrix by specifying a list of rows and columns. + Negative indices can be given. All indices must be in the range + $-n \le i < n$ where $n$ is the number of rows or columns. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(4, 3, range(12)) + >>> m + Matrix([ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + [9, 10, 11]]) + >>> m.extract([0, 1, 3], [0, 1]) + Matrix([ + [0, 1], + [3, 4], + [9, 10]]) + + Rows or columns can be repeated: + + >>> m.extract([0, 0, 1], [-1]) + Matrix([ + [2], + [2], + [5]]) + + Every other row can be taken by using range to provide the indices: + + >>> m.extract(range(0, m.rows, 2), [-1]) + Matrix([ + [2], + [8]]) + + RowsList or colsList can also be a list of booleans, in which case + the rows or columns corresponding to the True values will be selected: + + >>> m.extract([0, 1, 2, 3], [True, False, True]) + Matrix([ + [0, 2], + [3, 5], + [6, 8], + [9, 11]]) + """ + + if not is_sequence(rowsList) or not is_sequence(colsList): + raise TypeError("rowsList and colsList must be iterable") + # ensure rowsList and colsList are lists of integers + if rowsList and all(isinstance(i, bool) for i in rowsList): + rowsList = [index for index, item in enumerate(rowsList) if item] + if colsList and all(isinstance(i, bool) for i in colsList): + colsList = [index for index, item in enumerate(colsList) if item] + + # ensure everything is in range + rowsList = [a2idx(k, self.rows) for k in rowsList] + colsList = [a2idx(k, self.cols) for k in colsList] + + return self._eval_extract(rowsList, colsList) + + def get_diag_blocks(self): + """Obtains the square sub-matrices on the main diagonal of a square matrix. + + Useful for inverting symbolic matrices or solving systems of + linear equations which may be decoupled by having a block diagonal + structure. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x, y, z + >>> A = Matrix([[1, 3, 0, 0], [y, z*z, 0, 0], [0, 0, x, 0], [0, 0, 0, 0]]) + >>> a1, a2, a3 = A.get_diag_blocks() + >>> a1 + Matrix([ + [1, 3], + [y, z**2]]) + >>> a2 + Matrix([[x]]) + >>> a3 + Matrix([[0]]) + + """ + return self._eval_get_diag_blocks() + + @classmethod + def hstack(cls, *args): + """Return a matrix formed by joining args horizontally (i.e. + by repeated application of row_join). + + Examples + ======== + + >>> from sympy import Matrix, eye + >>> Matrix.hstack(eye(2), 2*eye(2)) + Matrix([ + [1, 0, 2, 0], + [0, 1, 0, 2]]) + """ + if len(args) == 0: + return cls._new() + + kls = type(args[0]) + return reduce(kls.row_join, args) + + def reshape(self, rows, cols): + """Reshape the matrix. Total number of elements must remain the same. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(2, 3, lambda i, j: 1) + >>> m + Matrix([ + [1, 1, 1], + [1, 1, 1]]) + >>> m.reshape(1, 6) + Matrix([[1, 1, 1, 1, 1, 1]]) + >>> m.reshape(3, 2) + Matrix([ + [1, 1], + [1, 1], + [1, 1]]) + + """ + if self.rows * self.cols != rows * cols: + raise ValueError("Invalid reshape parameters %d %d" % (rows, cols)) + dok = {divmod(i*self.cols + j, cols): + v for (i, j), v in self.todok().items()} + return self._eval_from_dok(rows, cols, dok) + + def row_del(self, row): + """Delete the specified row.""" + if row < 0: + row += self.rows + if not 0 <= row < self.rows: + raise IndexError("Row {} is out of range.".format(row)) + + return self._eval_row_del(row) + + def row_insert(self, pos, other): + """Insert one or more rows at the given row position. + + Examples + ======== + + >>> from sympy import zeros, ones + >>> M = zeros(3) + >>> V = ones(1, 3) + >>> M.row_insert(1, V) + Matrix([ + [0, 0, 0], + [1, 1, 1], + [0, 0, 0], + [0, 0, 0]]) + + See Also + ======== + + row + col_insert + """ + # Allows you to build a matrix even if it is null matrix + if not self: + return self._new(other) + + pos = as_int(pos) + + if pos < 0: + pos = self.rows + pos + if pos < 0: + pos = 0 + elif pos > self.rows: + pos = self.rows + + if self.cols != other.cols: + raise ShapeError( + "The matrices have incompatible number of columns ({} and {})" + .format(self.cols, other.cols)) + + return self._eval_row_insert(pos, other) + + def row_join(self, other): + """Concatenates two matrices along self's last and rhs's first column + + Examples + ======== + + >>> from sympy import zeros, ones + >>> M = zeros(3) + >>> V = ones(3, 1) + >>> M.row_join(V) + Matrix([ + [0, 0, 0, 1], + [0, 0, 0, 1], + [0, 0, 0, 1]]) + + See Also + ======== + + row + col_join + """ + # A null matrix can always be stacked (see #10770) + if self.cols == 0 and self.rows != other.rows: + return self._new(other.rows, 0, []).row_join(other) + + if self.rows != other.rows: + raise ShapeError( + "The matrices have incompatible number of rows ({} and {})" + .format(self.rows, other.rows)) + return self._eval_row_join(other) + + def diagonal(self, k=0): + """Returns the kth diagonal of self. The main diagonal + corresponds to `k=0`; diagonals above and below correspond to + `k > 0` and `k < 0`, respectively. The values of `self[i, j]` + for which `j - i = k`, are returned in order of increasing + `i + j`, starting with `i + j = |k|`. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(3, 3, lambda i, j: j - i); m + Matrix([ + [ 0, 1, 2], + [-1, 0, 1], + [-2, -1, 0]]) + >>> _.diagonal() + Matrix([[0, 0, 0]]) + >>> m.diagonal(1) + Matrix([[1, 1]]) + >>> m.diagonal(-2) + Matrix([[-2]]) + + Even though the diagonal is returned as a Matrix, the element + retrieval can be done with a single index: + + >>> Matrix.diag(1, 2, 3).diagonal()[1] # instead of [0, 1] + 2 + + See Also + ======== + + diag + """ + rv = [] + k = as_int(k) + r = 0 if k > 0 else -k + c = 0 if r else k + while True: + if r == self.rows or c == self.cols: + break + rv.append(self[r, c]) + r += 1 + c += 1 + if not rv: + raise ValueError(filldedent(''' + The %s diagonal is out of range [%s, %s]''' % ( + k, 1 - self.rows, self.cols - 1))) + return self._new(1, len(rv), rv) + + def row(self, i): + """Elementary row selector. + + Examples + ======== + + >>> from sympy import eye + >>> eye(2).row(0) + Matrix([[1, 0]]) + + See Also + ======== + + col + row_del + row_join + row_insert + """ + return self[i, :] + + def todok(self): + """Return the matrix as dictionary of keys. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix.eye(3) + >>> M.todok() + {(0, 0): 1, (1, 1): 1, (2, 2): 1} + """ + return self._eval_todok() + + @classmethod + def from_dok(cls, rows, cols, dok): + """Create a matrix from a dictionary of keys. + + Examples + ======== + + >>> from sympy import Matrix + >>> d = {(0, 0): 1, (1, 2): 3, (2, 1): 4} + >>> Matrix.from_dok(3, 3, d) + Matrix([ + [1, 0, 0], + [0, 0, 3], + [0, 4, 0]]) + """ + dok = {ij: cls._sympify(val) for ij, val in dok.items()} + return cls._eval_from_dok(rows, cols, dok) + + def tolist(self): + """Return the Matrix as a nested Python list. + + Examples + ======== + + >>> from sympy import Matrix, ones + >>> m = Matrix(3, 3, range(9)) + >>> m + Matrix([ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> m.tolist() + [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + >>> ones(3, 0).tolist() + [[], [], []] + + When there are no rows then it will not be possible to tell how + many columns were in the original matrix: + + >>> ones(0, 3).tolist() + [] + + """ + if not self.rows: + return [] + if not self.cols: + return [[] for i in range(self.rows)] + return self._eval_tolist() + + def todod(M): + """Returns matrix as dict of dicts containing non-zero elements of the Matrix + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[0, 1],[0, 3]]) + >>> A + Matrix([ + [0, 1], + [0, 3]]) + >>> A.todod() + {0: {1: 1}, 1: {1: 3}} + + + """ + rowsdict = {} + Mlol = M.tolist() + for i, Mi in enumerate(Mlol): + row = {j: Mij for j, Mij in enumerate(Mi) if Mij} + if row: + rowsdict[i] = row + return rowsdict + + def vec(self): + """Return the Matrix converted into a one column matrix by stacking columns + + Examples + ======== + + >>> from sympy import Matrix + >>> m=Matrix([[1, 3], [2, 4]]) + >>> m + Matrix([ + [1, 3], + [2, 4]]) + >>> m.vec() + Matrix([ + [1], + [2], + [3], + [4]]) + + See Also + ======== + + vech + """ + return self._eval_vec() + + def vech(self, diagonal=True, check_symmetry=True): + """Reshapes the matrix into a column vector by stacking the + elements in the lower triangle. + + Parameters + ========== + + diagonal : bool, optional + If ``True``, it includes the diagonal elements. + + check_symmetry : bool, optional + If ``True``, it checks whether the matrix is symmetric. + + Examples + ======== + + >>> from sympy import Matrix + >>> m=Matrix([[1, 2], [2, 3]]) + >>> m + Matrix([ + [1, 2], + [2, 3]]) + >>> m.vech() + Matrix([ + [1], + [2], + [3]]) + >>> m.vech(diagonal=False) + Matrix([[2]]) + + Notes + ===== + + This should work for symmetric matrices and ``vech`` can + represent symmetric matrices in vector form with less size than + ``vec``. + + See Also + ======== + + vec + """ + if not self.is_square: + raise NonSquareMatrixError + + if check_symmetry and not self.is_symmetric(): + raise ValueError("The matrix is not symmetric.") + + return self._eval_vech(diagonal) + + @classmethod + def vstack(cls, *args): + """Return a matrix formed by joining args vertically (i.e. + by repeated application of col_join). + + Examples + ======== + + >>> from sympy import Matrix, eye + >>> Matrix.vstack(eye(2), 2*eye(2)) + Matrix([ + [1, 0], + [0, 1], + [2, 0], + [0, 2]]) + """ + if len(args) == 0: + return cls._new() + + kls = type(args[0]) + return reduce(kls.col_join, args) + + @classmethod + def _eval_diag(cls, rows, cols, diag_dict): + """diag_dict is a defaultdict containing + all the entries of the diagonal matrix.""" + def entry(i, j): + return diag_dict[(i, j)] + return cls._new(rows, cols, entry) + + @classmethod + def _eval_eye(cls, rows, cols): + vals = [cls.zero]*(rows*cols) + vals[::cols+1] = [cls.one]*min(rows, cols) + return cls._new(rows, cols, vals, copy=False) + + @classmethod + def _eval_jordan_block(cls, size: int, eigenvalue, band='upper'): + if band == 'lower': + def entry(i, j): + if i == j: + return eigenvalue + elif j + 1 == i: + return cls.one + return cls.zero + else: + def entry(i, j): + if i == j: + return eigenvalue + elif i + 1 == j: + return cls.one + return cls.zero + return cls._new(size, size, entry) + + @classmethod + def _eval_ones(cls, rows, cols): + def entry(i, j): + return cls.one + return cls._new(rows, cols, entry) + + @classmethod + def _eval_zeros(cls, rows, cols): + return cls._new(rows, cols, [cls.zero]*(rows*cols), copy=False) + + @classmethod + def _eval_wilkinson(cls, n): + def entry(i, j): + return cls.one if i + 1 == j else cls.zero + + D = cls._new(2*n + 1, 2*n + 1, entry) + + wminus = cls.diag(list(range(-n, n + 1)), unpack=True) + D + D.T + wplus = abs(cls.diag(list(range(-n, n + 1)), unpack=True)) + D + D.T + + return wminus, wplus + + @classmethod + def diag(kls, *args, strict=False, unpack=True, rows=None, cols=None, **kwargs): + """Returns a matrix with the specified diagonal. + If matrices are passed, a block-diagonal matrix + is created (i.e. the "direct sum" of the matrices). + + kwargs + ====== + + rows : rows of the resulting matrix; computed if + not given. + + cols : columns of the resulting matrix; computed if + not given. + + cls : class for the resulting matrix + + unpack : bool which, when True (default), unpacks a single + sequence rather than interpreting it as a Matrix. + + strict : bool which, when False (default), allows Matrices to + have variable-length rows. + + Examples + ======== + + >>> from sympy import Matrix + >>> Matrix.diag(1, 2, 3) + Matrix([ + [1, 0, 0], + [0, 2, 0], + [0, 0, 3]]) + + The current default is to unpack a single sequence. If this is + not desired, set `unpack=False` and it will be interpreted as + a matrix. + + >>> Matrix.diag([1, 2, 3]) == Matrix.diag(1, 2, 3) + True + + When more than one element is passed, each is interpreted as + something to put on the diagonal. Lists are converted to + matrices. Filling of the diagonal always continues from + the bottom right hand corner of the previous item: this + will create a block-diagonal matrix whether the matrices + are square or not. + + >>> col = [1, 2, 3] + >>> row = [[4, 5]] + >>> Matrix.diag(col, row) + Matrix([ + [1, 0, 0], + [2, 0, 0], + [3, 0, 0], + [0, 4, 5]]) + + When `unpack` is False, elements within a list need not all be + of the same length. Setting `strict` to True would raise a + ValueError for the following: + + >>> Matrix.diag([[1, 2, 3], [4, 5], [6]], unpack=False) + Matrix([ + [1, 2, 3], + [4, 5, 0], + [6, 0, 0]]) + + The type of the returned matrix can be set with the ``cls`` + keyword. + + >>> from sympy import ImmutableMatrix + >>> from sympy.utilities.misc import func_name + >>> func_name(Matrix.diag(1, cls=ImmutableMatrix)) + 'ImmutableDenseMatrix' + + A zero dimension matrix can be used to position the start of + the filling at the start of an arbitrary row or column: + + >>> from sympy import ones + >>> r2 = ones(0, 2) + >>> Matrix.diag(r2, 1, 2) + Matrix([ + [0, 0, 1, 0], + [0, 0, 0, 2]]) + + See Also + ======== + eye + diagonal + .dense.diag + .expressions.blockmatrix.BlockMatrix + .sparsetools.banded + """ + from sympy.matrices.matrixbase import MatrixBase + from sympy.matrices.dense import Matrix + from sympy.matrices import SparseMatrix + klass = kwargs.get('cls', kls) + if unpack and len(args) == 1 and is_sequence(args[0]) and \ + not isinstance(args[0], MatrixBase): + args = args[0] + + # fill a default dict with the diagonal entries + diag_entries = defaultdict(int) + rmax = cmax = 0 # keep track of the biggest index seen + for m in args: + if isinstance(m, list): + if strict: + # if malformed, Matrix will raise an error + _ = Matrix(m) + r, c = _.shape + m = _.tolist() + else: + r, c, smat = SparseMatrix._handle_creation_inputs(m) + for (i, j), _ in smat.items(): + diag_entries[(i + rmax, j + cmax)] = _ + m = [] # to skip process below + elif hasattr(m, 'shape'): # a Matrix + # convert to list of lists + r, c = m.shape + m = m.tolist() + else: # in this case, we're a single value + diag_entries[(rmax, cmax)] = m + rmax += 1 + cmax += 1 + continue + # process list of lists + for i, mi in enumerate(m): + for j, _ in enumerate(mi): + diag_entries[(i + rmax, j + cmax)] = _ + rmax += r + cmax += c + if rows is None: + rows, cols = cols, rows + if rows is None: + rows, cols = rmax, cmax + else: + cols = rows if cols is None else cols + if rows < rmax or cols < cmax: + raise ValueError(filldedent(''' + The constructed matrix is {} x {} but a size of {} x {} + was specified.'''.format(rmax, cmax, rows, cols))) + return klass._eval_diag(rows, cols, diag_entries) + + @classmethod + def eye(kls, rows, cols=None, **kwargs): + """Returns an identity matrix. + + Parameters + ========== + + rows : rows of the matrix + cols : cols of the matrix (if None, cols=rows) + + kwargs + ====== + cls : class of the returned matrix + """ + if cols is None: + cols = rows + if rows < 0 or cols < 0: + raise ValueError("Cannot create a {} x {} matrix. " + "Both dimensions must be positive".format(rows, cols)) + klass = kwargs.get('cls', kls) + rows, cols = as_int(rows), as_int(cols) + + return klass._eval_eye(rows, cols) + + @classmethod + def jordan_block(kls, size=None, eigenvalue=None, *, band='upper', **kwargs): + """Returns a Jordan block + + Parameters + ========== + + size : Integer, optional + Specifies the shape of the Jordan block matrix. + + eigenvalue : Number or Symbol + Specifies the value for the main diagonal of the matrix. + + .. note:: + The keyword ``eigenval`` is also specified as an alias + of this keyword, but it is not recommended to use. + + We may deprecate the alias in later release. + + band : 'upper' or 'lower', optional + Specifies the position of the off-diagonal to put `1` s on. + + cls : Matrix, optional + Specifies the matrix class of the output form. + + If it is not specified, the class type where the method is + being executed on will be returned. + + Returns + ======= + + Matrix + A Jordan block matrix. + + Raises + ====== + + ValueError + If insufficient arguments are given for matrix size + specification, or no eigenvalue is given. + + Examples + ======== + + Creating a default Jordan block: + + >>> from sympy import Matrix + >>> from sympy.abc import x + >>> Matrix.jordan_block(4, x) + Matrix([ + [x, 1, 0, 0], + [0, x, 1, 0], + [0, 0, x, 1], + [0, 0, 0, x]]) + + Creating an alternative Jordan block matrix where `1` is on + lower off-diagonal: + + >>> Matrix.jordan_block(4, x, band='lower') + Matrix([ + [x, 0, 0, 0], + [1, x, 0, 0], + [0, 1, x, 0], + [0, 0, 1, x]]) + + Creating a Jordan block with keyword arguments + + >>> Matrix.jordan_block(size=4, eigenvalue=x) + Matrix([ + [x, 1, 0, 0], + [0, x, 1, 0], + [0, 0, x, 1], + [0, 0, 0, x]]) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Jordan_matrix + """ + klass = kwargs.pop('cls', kls) + + eigenval = kwargs.get('eigenval', None) + if eigenvalue is None and eigenval is None: + raise ValueError("Must supply an eigenvalue") + elif eigenvalue != eigenval and None not in (eigenval, eigenvalue): + raise ValueError( + "Inconsistent values are given: 'eigenval'={}, " + "'eigenvalue'={}".format(eigenval, eigenvalue)) + else: + if eigenval is not None: + eigenvalue = eigenval + + if size is None: + raise ValueError("Must supply a matrix size") + + size = as_int(size) + return klass._eval_jordan_block(size, eigenvalue, band) + + @classmethod + def ones(kls, rows, cols=None, **kwargs): + """Returns a matrix of ones. + + Parameters + ========== + + rows : rows of the matrix + cols : cols of the matrix (if None, cols=rows) + + kwargs + ====== + cls : class of the returned matrix + """ + if cols is None: + cols = rows + klass = kwargs.get('cls', kls) + rows, cols = as_int(rows), as_int(cols) + + return klass._eval_ones(rows, cols) + + @classmethod + def zeros(kls, rows, cols=None, **kwargs): + """Returns a matrix of zeros. + + Parameters + ========== + + rows : rows of the matrix + cols : cols of the matrix (if None, cols=rows) + + kwargs + ====== + cls : class of the returned matrix + """ + if cols is None: + cols = rows + if rows < 0 or cols < 0: + raise ValueError("Cannot create a {} x {} matrix. " + "Both dimensions must be positive".format(rows, cols)) + klass = kwargs.get('cls', kls) + rows, cols = as_int(rows), as_int(cols) + + return klass._eval_zeros(rows, cols) + + @classmethod + def companion(kls, poly): + """Returns a companion matrix of a polynomial. + + Examples + ======== + + >>> from sympy import Matrix, Poly, Symbol, symbols + >>> x = Symbol('x') + >>> c0, c1, c2, c3, c4 = symbols('c0:5') + >>> p = Poly(c0 + c1*x + c2*x**2 + c3*x**3 + c4*x**4 + x**5, x) + >>> Matrix.companion(p) + Matrix([ + [0, 0, 0, 0, -c0], + [1, 0, 0, 0, -c1], + [0, 1, 0, 0, -c2], + [0, 0, 1, 0, -c3], + [0, 0, 0, 1, -c4]]) + """ + poly = kls._sympify(poly) + if not isinstance(poly, Poly): + raise ValueError("{} must be a Poly instance.".format(poly)) + if not poly.is_monic: + raise ValueError("{} must be a monic polynomial.".format(poly)) + if not poly.is_univariate: + raise ValueError( + "{} must be a univariate polynomial.".format(poly)) + + size = poly.degree() + if not size >= 1: + raise ValueError( + "{} must have degree not less than 1.".format(poly)) + + coeffs = poly.all_coeffs() + def entry(i, j): + if j == size - 1: + return -coeffs[-1 - i] + elif i == j + 1: + return kls.one + return kls.zero + return kls._new(size, size, entry) + + + @classmethod + def wilkinson(kls, n, **kwargs): + """Returns two square Wilkinson Matrix of size 2*n + 1 + $W_{2n + 1}^-, W_{2n + 1}^+ =$ Wilkinson(n) + + Examples + ======== + + >>> from sympy import Matrix + >>> wminus, wplus = Matrix.wilkinson(3) + >>> wminus + Matrix([ + [-3, 1, 0, 0, 0, 0, 0], + [ 1, -2, 1, 0, 0, 0, 0], + [ 0, 1, -1, 1, 0, 0, 0], + [ 0, 0, 1, 0, 1, 0, 0], + [ 0, 0, 0, 1, 1, 1, 0], + [ 0, 0, 0, 0, 1, 2, 1], + [ 0, 0, 0, 0, 0, 1, 3]]) + >>> wplus + Matrix([ + [3, 1, 0, 0, 0, 0, 0], + [1, 2, 1, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0], + [0, 0, 1, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 1, 2, 1], + [0, 0, 0, 0, 0, 1, 3]]) + + References + ========== + + .. [1] https://blogs.mathworks.com/cleve/2013/04/15/wilkinsons-matrices-2/ + .. [2] J. H. Wilkinson, The Algebraic Eigenvalue Problem, Claredon Press, Oxford, 1965, 662 pp. + + """ + klass = kwargs.get('cls', kls) + n = as_int(n) + return klass._eval_wilkinson(n) + + # The RepMatrix subclass uses more efficient sparse implementations of + # _eval_iter_values and other things. + + def _eval_iter_values(self): + return (i for i in self if i is not S.Zero) + + def _eval_values(self): + return list(self.iter_values()) + + def _eval_iter_items(self): + for i in range(self.rows): + for j in range(self.cols): + if self[i, j]: + yield (i, j), self[i, j] + + def _eval_atoms(self, *types): + values = self.values() + if len(values) < self.rows * self.cols and isinstance(S.Zero, types): + s = {S.Zero} + else: + s = set() + return s.union(*[v.atoms(*types) for v in values]) + + def _eval_free_symbols(self): + return set().union(*(i.free_symbols for i in set(self.values()))) + + def _eval_has(self, *patterns): + return any(a.has(*patterns) for a in self.iter_values()) + + def _eval_is_symbolic(self): + return self.has(Symbol) + + # _eval_is_hermitian is called by some general SymPy + # routines and has a different *args signature. Make + # sure the names don't clash by adding `_matrix_` in name. + def _eval_is_matrix_hermitian(self, simpfunc): + herm = lambda i, j: simpfunc(self[i, j] - self[j, i].adjoint()).is_zero + return fuzzy_and(herm(i, j) for (i, j), v in self.iter_items()) + + def _eval_is_zero_matrix(self): + return fuzzy_and(v.is_zero for v in self.iter_values()) + + def _eval_is_Identity(self) -> FuzzyBool: + one = self.one + zero = self.zero + ident = lambda i, j, v: v is one if i == j else v is zero + return all(ident(i, j, v) for (i, j), v in self.iter_items()) + + def _eval_is_diagonal(self): + return fuzzy_and(v.is_zero for (i, j), v in self.iter_items() if i != j) + + def _eval_is_lower(self): + return all(v.is_zero for (i, j), v in self.iter_items() if i < j) + + def _eval_is_upper(self): + return all(v.is_zero for (i, j), v in self.iter_items() if i > j) + + def _eval_is_lower_hessenberg(self): + return all(v.is_zero for (i, j), v in self.iter_items() if i + 1 < j) + + def _eval_is_upper_hessenberg(self): + return all(v.is_zero for (i, j), v in self.iter_items() if i > j + 1) + + def _eval_is_symmetric(self, simpfunc): + sym = lambda i, j: simpfunc(self[i, j] - self[j, i]).is_zero + return fuzzy_and(sym(i, j) for (i, j), v in self.iter_items()) + + def _eval_is_anti_symmetric(self, simpfunc): + anti = lambda i, j: simpfunc(self[i, j] + self[j, i]).is_zero + return fuzzy_and(anti(i, j) for (i, j), v in self.iter_items()) + + def _has_positive_diagonals(self): + diagonal_entries = (self[i, i] for i in range(self.rows)) + return fuzzy_and(x.is_positive for x in diagonal_entries) + + def _has_nonnegative_diagonals(self): + diagonal_entries = (self[i, i] for i in range(self.rows)) + return fuzzy_and(x.is_nonnegative for x in diagonal_entries) + + def atoms(self, *types): + """Returns the atoms that form the current object. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import Matrix + >>> Matrix([[x]]) + Matrix([[x]]) + >>> _.atoms() + {x} + >>> Matrix([[x, y], [y, x]]) + Matrix([ + [x, y], + [y, x]]) + >>> _.atoms() + {x, y} + """ + + types = tuple(t if isinstance(t, type) else type(t) for t in types) + if not types: + types = (Atom,) + return self._eval_atoms(*types) + + @property + def free_symbols(self): + """Returns the free symbols within the matrix. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import Matrix + >>> Matrix([[x], [1]]).free_symbols + {x} + """ + return self._eval_free_symbols() + + def has(self, *patterns): + """Test whether any subexpression matches any of the patterns. + + Examples + ======== + + >>> from sympy import Matrix, SparseMatrix, Float + >>> from sympy.abc import x, y + >>> A = Matrix(((1, x), (0.2, 3))) + >>> B = SparseMatrix(((1, x), (0.2, 3))) + >>> A.has(x) + True + >>> A.has(y) + False + >>> A.has(Float) + True + >>> B.has(x) + True + >>> B.has(y) + False + >>> B.has(Float) + True + """ + return self._eval_has(*patterns) + + def is_anti_symmetric(self, simplify=True): + """Check if matrix M is an antisymmetric matrix, + that is, M is a square matrix with all M[i, j] == -M[j, i]. + + When ``simplify=True`` (default), the sum M[i, j] + M[j, i] is + simplified before testing to see if it is zero. By default, + the SymPy simplify function is used. To use a custom function + set simplify to a function that accepts a single argument which + returns a simplified expression. To skip simplification, set + simplify to False but note that although this will be faster, + it may induce false negatives. + + Examples + ======== + + >>> from sympy import Matrix, symbols + >>> m = Matrix(2, 2, [0, 1, -1, 0]) + >>> m + Matrix([ + [ 0, 1], + [-1, 0]]) + >>> m.is_anti_symmetric() + True + >>> x, y = symbols('x y') + >>> m = Matrix(2, 3, [0, 0, x, -y, 0, 0]) + >>> m + Matrix([ + [ 0, 0, x], + [-y, 0, 0]]) + >>> m.is_anti_symmetric() + False + + >>> from sympy.abc import x, y + >>> m = Matrix(3, 3, [0, x**2 + 2*x + 1, y, + ... -(x + 1)**2, 0, x*y, + ... -y, -x*y, 0]) + + Simplification of matrix elements is done by default so even + though two elements which should be equal and opposite would not + pass an equality test, the matrix is still reported as + anti-symmetric: + + >>> m[0, 1] == -m[1, 0] + False + >>> m.is_anti_symmetric() + True + + If ``simplify=False`` is used for the case when a Matrix is already + simplified, this will speed things up. Here, we see that without + simplification the matrix does not appear anti-symmetric: + + >>> print(m.is_anti_symmetric(simplify=False)) + None + + But if the matrix were already expanded, then it would appear + anti-symmetric and simplification in the is_anti_symmetric routine + is not needed: + + >>> m = m.expand() + >>> m.is_anti_symmetric(simplify=False) + True + """ + # accept custom simplification + simpfunc = simplify + if not isfunction(simplify): + simpfunc = _utilities_simplify if simplify else lambda x: x + + if not self.is_square: + return False + return self._eval_is_anti_symmetric(simpfunc) + + def is_diagonal(self): + """Check if matrix is diagonal, + that is matrix in which the entries outside the main diagonal are all zero. + + Examples + ======== + + >>> from sympy import Matrix, diag + >>> m = Matrix(2, 2, [1, 0, 0, 2]) + >>> m + Matrix([ + [1, 0], + [0, 2]]) + >>> m.is_diagonal() + True + + >>> m = Matrix(2, 2, [1, 1, 0, 2]) + >>> m + Matrix([ + [1, 1], + [0, 2]]) + >>> m.is_diagonal() + False + + >>> m = diag(1, 2, 3) + >>> m + Matrix([ + [1, 0, 0], + [0, 2, 0], + [0, 0, 3]]) + >>> m.is_diagonal() + True + + See Also + ======== + + is_lower + is_upper + sympy.matrices.matrixbase.MatrixBase.is_diagonalizable + diagonalize + """ + return self._eval_is_diagonal() + + @property + def is_weakly_diagonally_dominant(self): + r"""Tests if the matrix is row weakly diagonally dominant. + + Explanation + =========== + + A $n, n$ matrix $A$ is row weakly diagonally dominant if + + .. math:: + \left|A_{i, i}\right| \ge \sum_{j = 0, j \neq i}^{n-1} + \left|A_{i, j}\right| \quad {\text{for all }} + i \in \{ 0, ..., n-1 \} + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[3, -2, 1], [1, -3, 2], [-1, 2, 4]]) + >>> A.is_weakly_diagonally_dominant + True + + >>> A = Matrix([[-2, 2, 1], [1, 3, 2], [1, -2, 0]]) + >>> A.is_weakly_diagonally_dominant + False + + >>> A = Matrix([[-4, 2, 1], [1, 6, 2], [1, -2, 5]]) + >>> A.is_weakly_diagonally_dominant + True + + Notes + ===== + + If you want to test whether a matrix is column diagonally + dominant, you can apply the test after transposing the matrix. + """ + if not self.is_square: + return False + + rows, cols = self.shape + + def test_row(i): + summation = self.zero + for j in range(cols): + if i != j: + summation += Abs(self[i, j]) + return (Abs(self[i, i]) - summation).is_nonnegative + + return fuzzy_and(test_row(i) for i in range(rows)) + + @property + def is_strongly_diagonally_dominant(self): + r"""Tests if the matrix is row strongly diagonally dominant. + + Explanation + =========== + + A $n, n$ matrix $A$ is row strongly diagonally dominant if + + .. math:: + \left|A_{i, i}\right| > \sum_{j = 0, j \neq i}^{n-1} + \left|A_{i, j}\right| \quad {\text{for all }} + i \in \{ 0, ..., n-1 \} + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[3, -2, 1], [1, -3, 2], [-1, 2, 4]]) + >>> A.is_strongly_diagonally_dominant + False + + >>> A = Matrix([[-2, 2, 1], [1, 3, 2], [1, -2, 0]]) + >>> A.is_strongly_diagonally_dominant + False + + >>> A = Matrix([[-4, 2, 1], [1, 6, 2], [1, -2, 5]]) + >>> A.is_strongly_diagonally_dominant + True + + Notes + ===== + + If you want to test whether a matrix is column diagonally + dominant, you can apply the test after transposing the matrix. + """ + if not self.is_square: + return False + + rows, cols = self.shape + + def test_row(i): + summation = self.zero + for j in range(cols): + if i != j: + summation += Abs(self[i, j]) + return (Abs(self[i, i]) - summation).is_positive + + return fuzzy_and(test_row(i) for i in range(rows)) + + @property + def is_hermitian(self): + """Checks if the matrix is Hermitian. + + In a Hermitian matrix element i,j is the complex conjugate of + element j,i. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy import I + >>> from sympy.abc import x + >>> a = Matrix([[1, I], [-I, 1]]) + >>> a + Matrix([ + [ 1, I], + [-I, 1]]) + >>> a.is_hermitian + True + >>> a[0, 0] = 2*I + >>> a.is_hermitian + False + >>> a[0, 0] = x + >>> a.is_hermitian + >>> a[0, 1] = a[1, 0]*I + >>> a.is_hermitian + False + """ + if not self.is_square: + return False + + return self._eval_is_matrix_hermitian(_utilities_simplify) + + @property + def is_Identity(self) -> FuzzyBool: + if not self.is_square: + return False + return self._eval_is_Identity() + + @property + def is_lower_hessenberg(self): + r"""Checks if the matrix is in the lower-Hessenberg form. + + The lower hessenberg matrix has zero entries + above the first superdiagonal. + + Examples + ======== + + >>> from sympy import Matrix + >>> a = Matrix([[1, 2, 0, 0], [5, 2, 3, 0], [3, 4, 3, 7], [5, 6, 1, 1]]) + >>> a + Matrix([ + [1, 2, 0, 0], + [5, 2, 3, 0], + [3, 4, 3, 7], + [5, 6, 1, 1]]) + >>> a.is_lower_hessenberg + True + + See Also + ======== + + is_upper_hessenberg + is_lower + """ + return self._eval_is_lower_hessenberg() + + @property + def is_lower(self): + """Check if matrix is a lower triangular matrix. True can be returned + even if the matrix is not square. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(2, 2, [1, 0, 0, 1]) + >>> m + Matrix([ + [1, 0], + [0, 1]]) + >>> m.is_lower + True + + >>> m = Matrix(4, 3, [0, 0, 0, 2, 0, 0, 1, 4, 0, 6, 6, 5]) + >>> m + Matrix([ + [0, 0, 0], + [2, 0, 0], + [1, 4, 0], + [6, 6, 5]]) + >>> m.is_lower + True + + >>> from sympy.abc import x, y + >>> m = Matrix(2, 2, [x**2 + y, y**2 + x, 0, x + y]) + >>> m + Matrix([ + [x**2 + y, x + y**2], + [ 0, x + y]]) + >>> m.is_lower + False + + See Also + ======== + + is_upper + is_diagonal + is_lower_hessenberg + """ + return self._eval_is_lower() + + @property + def is_square(self): + """Checks if a matrix is square. + + A matrix is square if the number of rows equals the number of columns. + The empty matrix is square by definition, since the number of rows and + the number of columns are both zero. + + Examples + ======== + + >>> from sympy import Matrix + >>> a = Matrix([[1, 2, 3], [4, 5, 6]]) + >>> b = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> c = Matrix([]) + >>> a.is_square + False + >>> b.is_square + True + >>> c.is_square + True + """ + return self.rows == self.cols + + def is_symbolic(self): + """Checks if any elements contain Symbols. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x, y + >>> M = Matrix([[x, y], [1, 0]]) + >>> M.is_symbolic() + True + + """ + return self._eval_is_symbolic() + + def is_symmetric(self, simplify=True): + """Check if matrix is symmetric matrix, + that is square matrix and is equal to its transpose. + + By default, simplifications occur before testing symmetry. + They can be skipped using 'simplify=False'; while speeding things a bit, + this may however induce false negatives. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(2, 2, [0, 1, 1, 2]) + >>> m + Matrix([ + [0, 1], + [1, 2]]) + >>> m.is_symmetric() + True + + >>> m = Matrix(2, 2, [0, 1, 2, 0]) + >>> m + Matrix([ + [0, 1], + [2, 0]]) + >>> m.is_symmetric() + False + + >>> m = Matrix(2, 3, [0, 0, 0, 0, 0, 0]) + >>> m + Matrix([ + [0, 0, 0], + [0, 0, 0]]) + >>> m.is_symmetric() + False + + >>> from sympy.abc import x, y + >>> m = Matrix(3, 3, [1, x**2 + 2*x + 1, y, (x + 1)**2, 2, 0, y, 0, 3]) + >>> m + Matrix([ + [ 1, x**2 + 2*x + 1, y], + [(x + 1)**2, 2, 0], + [ y, 0, 3]]) + >>> m.is_symmetric() + True + + If the matrix is already simplified, you may speed-up is_symmetric() + test by using 'simplify=False'. + + >>> bool(m.is_symmetric(simplify=False)) + False + >>> m1 = m.expand() + >>> m1.is_symmetric(simplify=False) + True + """ + simpfunc = simplify + if not isfunction(simplify): + simpfunc = _utilities_simplify if simplify else lambda x: x + + if not self.is_square: + return False + + return self._eval_is_symmetric(simpfunc) + + @property + def is_upper_hessenberg(self): + """Checks if the matrix is the upper-Hessenberg form. + + The upper hessenberg matrix has zero entries + below the first subdiagonal. + + Examples + ======== + + >>> from sympy import Matrix + >>> a = Matrix([[1, 4, 2, 3], [3, 4, 1, 7], [0, 2, 3, 4], [0, 0, 1, 3]]) + >>> a + Matrix([ + [1, 4, 2, 3], + [3, 4, 1, 7], + [0, 2, 3, 4], + [0, 0, 1, 3]]) + >>> a.is_upper_hessenberg + True + + See Also + ======== + + is_lower_hessenberg + is_upper + """ + return self._eval_is_upper_hessenberg() + + @property + def is_upper(self): + """Check if matrix is an upper triangular matrix. True can be returned + even if the matrix is not square. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(2, 2, [1, 0, 0, 1]) + >>> m + Matrix([ + [1, 0], + [0, 1]]) + >>> m.is_upper + True + + >>> m = Matrix(4, 3, [5, 1, 9, 0, 4, 6, 0, 0, 5, 0, 0, 0]) + >>> m + Matrix([ + [5, 1, 9], + [0, 4, 6], + [0, 0, 5], + [0, 0, 0]]) + >>> m.is_upper + True + + >>> m = Matrix(2, 3, [4, 2, 5, 6, 1, 1]) + >>> m + Matrix([ + [4, 2, 5], + [6, 1, 1]]) + >>> m.is_upper + False + + See Also + ======== + + is_lower + is_diagonal + is_upper_hessenberg + """ + return self._eval_is_upper() + + @property + def is_zero_matrix(self): + """Checks if a matrix is a zero matrix. + + A matrix is zero if every element is zero. A matrix need not be square + to be considered zero. The empty matrix is zero by the principle of + vacuous truth. For a matrix that may or may not be zero (e.g. + contains a symbol), this will be None + + Examples + ======== + + >>> from sympy import Matrix, zeros + >>> from sympy.abc import x + >>> a = Matrix([[0, 0], [0, 0]]) + >>> b = zeros(3, 4) + >>> c = Matrix([[0, 1], [0, 0]]) + >>> d = Matrix([]) + >>> e = Matrix([[x, 0], [0, 0]]) + >>> a.is_zero_matrix + True + >>> b.is_zero_matrix + True + >>> c.is_zero_matrix + False + >>> d.is_zero_matrix + True + >>> e.is_zero_matrix + """ + return self._eval_is_zero_matrix() + + def values(self): + """Return non-zero values of self. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix([[0, 1], [2, 3]]) + >>> m.values() + [1, 2, 3] + + See Also + ======== + + iter_values + tolist + flat + """ + return self._eval_values() + + def iter_values(self): + """ + Iterate over non-zero values of self. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix([[0, 1], [2, 3]]) + >>> list(m.iter_values()) + [1, 2, 3] + + See Also + ======== + + values + """ + return self._eval_iter_values() + + def iter_items(self): + """Iterate over indices and values of nonzero items. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix([[0, 1], [2, 3]]) + >>> list(m.iter_items()) + [((0, 1), 1), ((1, 0), 2), ((1, 1), 3)] + + See Also + ======== + + iter_values + todok + """ + return self._eval_iter_items() + + def _eval_adjoint(self): + return self.transpose().applyfunc(lambda x: x.adjoint()) + + def _eval_applyfunc(self, f): + cols = self.cols + size = self.rows*self.cols + + dok = self.todok() + valmap = {v: f(v) for v in dok.values()} + + if len(dok) < size and ((fzero := f(S.Zero)) is not S.Zero): + out_flat = [fzero]*size + for (i, j), v in dok.items(): + out_flat[i*cols + j] = valmap[v] + out = self._new(self.rows, self.cols, out_flat) + else: + fdok = {ij: valmap[v] for ij, v in dok.items()} + out = self.from_dok(self.rows, self.cols, fdok) + + return out + + def _eval_as_real_imag(self): # type: ignore + return (self.applyfunc(re), self.applyfunc(im)) + + def _eval_conjugate(self): + return self.applyfunc(lambda x: x.conjugate()) + + def _eval_permute_cols(self, perm): + # apply the permutation to a list + mapping = list(perm) + + def entry(i, j): + return self[i, mapping[j]] + + return self._new(self.rows, self.cols, entry) + + def _eval_permute_rows(self, perm): + # apply the permutation to a list + mapping = list(perm) + + def entry(i, j): + return self[mapping[i], j] + + return self._new(self.rows, self.cols, entry) + + def _eval_trace(self): + return sum(self[i, i] for i in range(self.rows)) + + def _eval_transpose(self): + return self._new(self.cols, self.rows, lambda i, j: self[j, i]) + + def adjoint(self): + """Conjugate transpose or Hermitian conjugation.""" + return self._eval_adjoint() + + def applyfunc(self, f): + """Apply a function to each element of the matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(2, 2, lambda i, j: i*2+j) + >>> m + Matrix([ + [0, 1], + [2, 3]]) + >>> m.applyfunc(lambda i: 2*i) + Matrix([ + [0, 2], + [4, 6]]) + + """ + if not callable(f): + raise TypeError("`f` must be callable.") + + return self._eval_applyfunc(f) + + def as_real_imag(self, deep=True, **hints): + """Returns a tuple containing the (real, imaginary) part of matrix.""" + # XXX: Ignoring deep and hints... + return self._eval_as_real_imag() + + def conjugate(self): + """Return the by-element conjugation. + + Examples + ======== + + >>> from sympy import SparseMatrix, I + >>> a = SparseMatrix(((1, 2 + I), (3, 4), (I, -I))) + >>> a + Matrix([ + [1, 2 + I], + [3, 4], + [I, -I]]) + >>> a.C + Matrix([ + [ 1, 2 - I], + [ 3, 4], + [-I, I]]) + + See Also + ======== + + transpose: Matrix transposition + H: Hermite conjugation + sympy.matrices.matrixbase.MatrixBase.D: Dirac conjugation + """ + return self._eval_conjugate() + + def doit(self, **hints): + return self.applyfunc(lambda x: x.doit(**hints)) + + def evalf(self, n=15, subs=None, maxn=100, chop=False, strict=False, quad=None, verbose=False): + """Apply evalf() to each element of self.""" + options = {'subs':subs, 'maxn':maxn, 'chop':chop, 'strict':strict, + 'quad':quad, 'verbose':verbose} + return self.applyfunc(lambda i: i.evalf(n, **options)) + + def expand(self, deep=True, modulus=None, power_base=True, power_exp=True, + mul=True, log=True, multinomial=True, basic=True, **hints): + """Apply core.function.expand to each entry of the matrix. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import Matrix + >>> Matrix(1, 1, [x*(x+1)]) + Matrix([[x*(x + 1)]]) + >>> _.expand() + Matrix([[x**2 + x]]) + + """ + return self.applyfunc(lambda x: x.expand( + deep, modulus, power_base, power_exp, mul, log, multinomial, basic, + **hints)) + + @property + def H(self): + """Return Hermite conjugate. + + Examples + ======== + + >>> from sympy import Matrix, I + >>> m = Matrix((0, 1 + I, 2, 3)) + >>> m + Matrix([ + [ 0], + [1 + I], + [ 2], + [ 3]]) + >>> m.H + Matrix([[0, 1 - I, 2, 3]]) + + See Also + ======== + + conjugate: By-element conjugation + sympy.matrices.matrixbase.MatrixBase.D: Dirac conjugation + """ + return self.adjoint() + + def permute(self, perm, orientation='rows', direction='forward'): + r"""Permute the rows or columns of a matrix by the given list of + swaps. + + Parameters + ========== + + perm : Permutation, list, or list of lists + A representation for the permutation. + + If it is ``Permutation``, it is used directly with some + resizing with respect to the matrix size. + + If it is specified as list of lists, + (e.g., ``[[0, 1], [0, 2]]``), then the permutation is formed + from applying the product of cycles. The direction how the + cyclic product is applied is described in below. + + If it is specified as a list, the list should represent + an array form of a permutation. (e.g., ``[1, 2, 0]``) which + would would form the swapping function + `0 \mapsto 1, 1 \mapsto 2, 2\mapsto 0`. + + orientation : 'rows', 'cols' + A flag to control whether to permute the rows or the columns + + direction : 'forward', 'backward' + A flag to control whether to apply the permutations from + the start of the list first, or from the back of the list + first. + + For example, if the permutation specification is + ``[[0, 1], [0, 2]]``, + + If the flag is set to ``'forward'``, the cycle would be + formed as `0 \mapsto 2, 2 \mapsto 1, 1 \mapsto 0`. + + If the flag is set to ``'backward'``, the cycle would be + formed as `0 \mapsto 1, 1 \mapsto 2, 2 \mapsto 0`. + + If the argument ``perm`` is not in a form of list of lists, + this flag takes no effect. + + Examples + ======== + + >>> from sympy import eye + >>> M = eye(3) + >>> M.permute([[0, 1], [0, 2]], orientation='rows', direction='forward') + Matrix([ + [0, 0, 1], + [1, 0, 0], + [0, 1, 0]]) + + >>> from sympy import eye + >>> M = eye(3) + >>> M.permute([[0, 1], [0, 2]], orientation='rows', direction='backward') + Matrix([ + [0, 1, 0], + [0, 0, 1], + [1, 0, 0]]) + + Notes + ===== + + If a bijective function + `\sigma : \mathbb{N}_0 \rightarrow \mathbb{N}_0` denotes the + permutation. + + If the matrix `A` is the matrix to permute, represented as + a horizontal or a vertical stack of vectors: + + .. math:: + A = + \begin{bmatrix} + a_0 \\ a_1 \\ \vdots \\ a_{n-1} + \end{bmatrix} = + \begin{bmatrix} + \alpha_0 & \alpha_1 & \cdots & \alpha_{n-1} + \end{bmatrix} + + If the matrix `B` is the result, the permutation of matrix rows + is defined as: + + .. math:: + B := \begin{bmatrix} + a_{\sigma(0)} \\ a_{\sigma(1)} \\ \vdots \\ a_{\sigma(n-1)} + \end{bmatrix} + + And the permutation of matrix columns is defined as: + + .. math:: + B := \begin{bmatrix} + \alpha_{\sigma(0)} & \alpha_{\sigma(1)} & + \cdots & \alpha_{\sigma(n-1)} + \end{bmatrix} + """ + from sympy.combinatorics import Permutation + + # allow british variants and `columns` + if direction == 'forwards': + direction = 'forward' + if direction == 'backwards': + direction = 'backward' + if orientation == 'columns': + orientation = 'cols' + + if direction not in ('forward', 'backward'): + raise TypeError("direction='{}' is an invalid kwarg. " + "Try 'forward' or 'backward'".format(direction)) + if orientation not in ('rows', 'cols'): + raise TypeError("orientation='{}' is an invalid kwarg. " + "Try 'rows' or 'cols'".format(orientation)) + + if not isinstance(perm, (Permutation, Iterable)): + raise ValueError( + "{} must be a list, a list of lists, " + "or a SymPy permutation object.".format(perm)) + + # ensure all swaps are in range + max_index = self.rows if orientation == 'rows' else self.cols + if not all(0 <= t <= max_index for t in flatten(list(perm))): + raise IndexError("`swap` indices out of range.") + + if perm and not isinstance(perm, Permutation) and \ + isinstance(perm[0], Iterable): + if direction == 'forward': + perm = list(reversed(perm)) + perm = Permutation(perm, size=max_index+1) + else: + perm = Permutation(perm, size=max_index+1) + + if orientation == 'rows': + return self._eval_permute_rows(perm) + if orientation == 'cols': + return self._eval_permute_cols(perm) + + def permute_cols(self, swaps, direction='forward'): + """Alias for + ``self.permute(swaps, orientation='cols', direction=direction)`` + + See Also + ======== + + permute + """ + return self.permute(swaps, orientation='cols', direction=direction) + + def permute_rows(self, swaps, direction='forward'): + """Alias for + ``self.permute(swaps, orientation='rows', direction=direction)`` + + See Also + ======== + + permute + """ + return self.permute(swaps, orientation='rows', direction=direction) + + def refine(self, assumptions=True): + """Apply refine to each element of the matrix. + + Examples + ======== + + >>> from sympy import Symbol, Matrix, Abs, sqrt, Q + >>> x = Symbol('x') + >>> Matrix([[Abs(x)**2, sqrt(x**2)],[sqrt(x**2), Abs(x)**2]]) + Matrix([ + [ Abs(x)**2, sqrt(x**2)], + [sqrt(x**2), Abs(x)**2]]) + >>> _.refine(Q.real(x)) + Matrix([ + [ x**2, Abs(x)], + [Abs(x), x**2]]) + + """ + return self.applyfunc(lambda x: refine(x, assumptions)) + + def replace(self, F, G, map=False, simultaneous=True, exact=None): + """Replaces Function F in Matrix entries with Function G. + + Examples + ======== + + >>> from sympy import symbols, Function, Matrix + >>> F, G = symbols('F, G', cls=Function) + >>> M = Matrix(2, 2, lambda i, j: F(i+j)) ; M + Matrix([ + [F(0), F(1)], + [F(1), F(2)]]) + >>> N = M.replace(F,G) + >>> N + Matrix([ + [G(0), G(1)], + [G(1), G(2)]]) + """ + kwargs = {'map': map, 'simultaneous': simultaneous, 'exact': exact} + + if map: + + d = {} + def func(eij): + eij, dij = eij.replace(F, G, **kwargs) + d.update(dij) + return eij + + M = self.applyfunc(func) + return M, d + + else: + return self.applyfunc(lambda i: i.replace(F, G, **kwargs)) + + def rot90(self, k=1): + """Rotates Matrix by 90 degrees + + Parameters + ========== + + k : int + Specifies how many times the matrix is rotated by 90 degrees + (clockwise when positive, counter-clockwise when negative). + + Examples + ======== + + >>> from sympy import Matrix, symbols + >>> A = Matrix(2, 2, symbols('a:d')) + >>> A + Matrix([ + [a, b], + [c, d]]) + + Rotating the matrix clockwise one time: + + >>> A.rot90(1) + Matrix([ + [c, a], + [d, b]]) + + Rotating the matrix anticlockwise two times: + + >>> A.rot90(-2) + Matrix([ + [d, c], + [b, a]]) + """ + + mod = k%4 + if mod == 0: + return self + if mod == 1: + return self[::-1, ::].T + if mod == 2: + return self[::-1, ::-1] + if mod == 3: + return self[::, ::-1].T + + def simplify(self, **kwargs): + """Apply simplify to each element of the matrix. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import SparseMatrix, sin, cos + >>> SparseMatrix(1, 1, [x*sin(y)**2 + x*cos(y)**2]) + Matrix([[x*sin(y)**2 + x*cos(y)**2]]) + >>> _.simplify() + Matrix([[x]]) + """ + return self.applyfunc(lambda x: x.simplify(**kwargs)) + + def subs(self, *args, **kwargs): # should mirror core.basic.subs + """Return a new matrix with subs applied to each entry. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import SparseMatrix, Matrix + >>> SparseMatrix(1, 1, [x]) + Matrix([[x]]) + >>> _.subs(x, y) + Matrix([[y]]) + >>> Matrix(_).subs(y, x) + Matrix([[x]]) + """ + + if len(args) == 1 and not isinstance(args[0], (dict, set)) and iter(args[0]) and not is_sequence(args[0]): + args = (list(args[0]),) + + return self.applyfunc(lambda x: x.subs(*args, **kwargs)) + + def trace(self): + """ + Returns the trace of a square matrix i.e. the sum of the + diagonal elements. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix(2, 2, [1, 2, 3, 4]) + >>> A.trace() + 5 + + """ + if self.rows != self.cols: + raise NonSquareMatrixError() + return self._eval_trace() + + def transpose(self): + """ + Returns the transpose of the matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix(2, 2, [1, 2, 3, 4]) + >>> A.transpose() + Matrix([ + [1, 3], + [2, 4]]) + + >>> from sympy import Matrix, I + >>> m=Matrix(((1, 2+I), (3, 4))) + >>> m + Matrix([ + [1, 2 + I], + [3, 4]]) + >>> m.transpose() + Matrix([ + [ 1, 3], + [2 + I, 4]]) + >>> m.T == m.transpose() + True + + See Also + ======== + + conjugate: By-element conjugation + + """ + return self._eval_transpose() + + @property + def T(self): + '''Matrix transposition''' + return self.transpose() + + @property + def C(self): + '''By-element conjugation''' + return self.conjugate() + + def n(self, *args, **kwargs): + """Apply evalf() to each element of self.""" + return self.evalf(*args, **kwargs) + + def xreplace(self, rule): # should mirror core.basic.xreplace + """Return a new matrix with xreplace applied to each entry. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import SparseMatrix, Matrix + >>> SparseMatrix(1, 1, [x]) + Matrix([[x]]) + >>> _.xreplace({x: y}) + Matrix([[y]]) + >>> Matrix(_).xreplace({y: x}) + Matrix([[x]]) + """ + return self.applyfunc(lambda x: x.xreplace(rule)) + + def _eval_simplify(self, **kwargs): + # XXX: We can't use self.simplify here as mutable subclasses will + # override simplify and have it return None + return self.applyfunc(lambda x: x.simplify(**kwargs)) + + def _eval_trigsimp(self, **opts): + from sympy.simplify.trigsimp import trigsimp + return self.applyfunc(lambda x: trigsimp(x, **opts)) + + def upper_triangular(self, k=0): + """Return the elements on and above the kth diagonal of a matrix. + If k is not specified then simply returns upper-triangular portion + of a matrix + + Examples + ======== + + >>> from sympy import ones + >>> A = ones(4) + >>> A.upper_triangular() + Matrix([ + [1, 1, 1, 1], + [0, 1, 1, 1], + [0, 0, 1, 1], + [0, 0, 0, 1]]) + + >>> A.upper_triangular(2) + Matrix([ + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0], + [0, 0, 0, 0]]) + + >>> A.upper_triangular(-1) + Matrix([ + [1, 1, 1, 1], + [1, 1, 1, 1], + [0, 1, 1, 1], + [0, 0, 1, 1]]) + + """ + + def entry(i, j): + return self[i, j] if i + k <= j else self.zero + + return self._new(self.rows, self.cols, entry) + + def lower_triangular(self, k=0): + """Return the elements on and below the kth diagonal of a matrix. + If k is not specified then simply returns lower-triangular portion + of a matrix + + Examples + ======== + + >>> from sympy import ones + >>> A = ones(4) + >>> A.lower_triangular() + Matrix([ + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1]]) + + >>> A.lower_triangular(-2) + Matrix([ + [0, 0, 0, 0], + [0, 0, 0, 0], + [1, 0, 0, 0], + [1, 1, 0, 0]]) + + >>> A.lower_triangular(1) + Matrix([ + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1], + [1, 1, 1, 1]]) + + """ + + def entry(i, j): + return self[i, j] if i + k >= j else self.zero + + return self._new(self.rows, self.cols, entry) + + def _eval_Abs(self): + return self._new(self.rows, self.cols, lambda i, j: Abs(self[i, j])) + + def _eval_add(self, other): + return self._new(self.rows, self.cols, + lambda i, j: self[i, j] + other[i, j]) + + def _eval_matrix_mul(self, other): + def entry(i, j): + vec = [self[i,k]*other[k,j] for k in range(self.cols)] + try: + return Add(*vec) + except (TypeError, SympifyError): + # Some matrices don't work with `sum` or `Add` + # They don't work with `sum` because `sum` tries to add `0` + # Fall back to a safe way to multiply if the `Add` fails. + return reduce(lambda a, b: a + b, vec) + + return self._new(self.rows, other.cols, entry) + + def _eval_matrix_mul_elementwise(self, other): + return self._new(self.rows, self.cols, lambda i, j: self[i,j]*other[i,j]) + + def _eval_matrix_rmul(self, other): + def entry(i, j): + return sum(other[i,k]*self[k,j] for k in range(other.cols)) + return self._new(other.rows, self.cols, entry) + + def _eval_pow_by_recursion(self, num): + if num == 1: + return self + + if num % 2 == 1: + a, b = self, self._eval_pow_by_recursion(num - 1) + else: + a = b = self._eval_pow_by_recursion(num // 2) + + return a.multiply(b) + + def _eval_pow_by_cayley(self, exp): + from sympy.discrete.recurrences import linrec_coeffs + row = self.shape[0] + p = self.charpoly() + + coeffs = (-p).all_coeffs()[1:] + coeffs = linrec_coeffs(coeffs, exp) + new_mat = self.eye(row) + ans = self.zeros(row) + + for i in range(row): + ans += coeffs[i]*new_mat + new_mat *= self + + return ans + + def _eval_pow_by_recursion_dotprodsimp(self, num, prevsimp=None): + if prevsimp is None: + prevsimp = [True]*len(self) + + if num == 1: + return self + + if num % 2 == 1: + a, b = self, self._eval_pow_by_recursion_dotprodsimp(num - 1, + prevsimp=prevsimp) + else: + a = b = self._eval_pow_by_recursion_dotprodsimp(num // 2, + prevsimp=prevsimp) + + m = a.multiply(b, dotprodsimp=False) + lenm = len(m) + elems = [None]*lenm + + for i in range(lenm): + if prevsimp[i]: + elems[i], prevsimp[i] = _dotprodsimp(m[i], withsimp=True) + else: + elems[i] = m[i] + + return m._new(m.rows, m.cols, elems) + + def _eval_scalar_mul(self, other): + return self._new(self.rows, self.cols, lambda i, j: self[i,j]*other) + + def _eval_scalar_rmul(self, other): + return self._new(self.rows, self.cols, lambda i, j: other*self[i,j]) + + def _eval_Mod(self, other): + return self._new(self.rows, self.cols, lambda i, j: Mod(self[i, j], other)) + + # Python arithmetic functions + def __abs__(self): + """Returns a new matrix with entry-wise absolute values.""" + return self._eval_Abs() + + @call_highest_priority('__radd__') + def __add__(self, other): + """Return self + other, raising ShapeError if shapes do not match.""" + + other, T = _coerce_operand(self, other) + + if T != "is_matrix": + return NotImplemented + + if self.shape != other.shape: + raise ShapeError(f"Matrix size mismatch: {self.shape} + {other.shape}.") + + # Unify matrix types + a, b = self, other + if a.__class__ != classof(a, b): + b, a = a, b + + return a._eval_add(b) + + @call_highest_priority('__rtruediv__') + def __truediv__(self, other): + return self * (self.one / other) + + @call_highest_priority('__rmatmul__') + def __matmul__(self, other): + self, other, T = _unify_with_other(self, other) + + if T != "is_matrix": + return NotImplemented + + return self.__mul__(other) + + def __mod__(self, other): + return self.applyfunc(lambda x: x % other) + + @call_highest_priority('__rmul__') + def __mul__(self, other): + """Return self*other where other is either a scalar or a matrix + of compatible dimensions. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[1, 2, 3], [4, 5, 6]]) + >>> 2*A == A*2 == Matrix([[2, 4, 6], [8, 10, 12]]) + True + >>> B = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> A*B + Matrix([ + [30, 36, 42], + [66, 81, 96]]) + >>> B*A + Traceback (most recent call last): + ... + ShapeError: Matrices size mismatch. + >>> + + See Also + ======== + + matrix_multiply_elementwise + """ + + return self.multiply(other) + + def multiply(self, other, dotprodsimp=None): + """Same as __mul__() but with optional simplification. + + Parameters + ========== + + dotprodsimp : bool, optional + Specifies whether intermediate term algebraic simplification is used + during matrix multiplications to control expression blowup and thus + speed up calculation. Default is off. + """ + + isimpbool = _get_intermediate_simp_bool(False, dotprodsimp) + + self, other, T = _unify_with_other(self, other) + + if T == "possible_scalar": + try: + return self._eval_scalar_mul(other) + except TypeError: + return NotImplemented + + elif T == "is_matrix": + + if self.shape[1] != other.shape[0]: + raise ShapeError(f"Matrix size mismatch: {self.shape} * {other.shape}.") + + m = self._eval_matrix_mul(other) + + if isimpbool: + m = m._new(m.rows, m.cols, [_dotprodsimp(e) for e in m]) + + return m + + else: + return NotImplemented + + def multiply_elementwise(self, other): + """Return the Hadamard product (elementwise product) of A and B + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[0, 1, 2], [3, 4, 5]]) + >>> B = Matrix([[1, 10, 100], [100, 10, 1]]) + >>> A.multiply_elementwise(B) + Matrix([ + [ 0, 10, 200], + [300, 40, 5]]) + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.cross + sympy.matrices.matrixbase.MatrixBase.dot + multiply + """ + if self.shape != other.shape: + raise ShapeError("Matrix shapes must agree {} != {}".format(self.shape, other.shape)) + + return self._eval_matrix_mul_elementwise(other) + + def __neg__(self): + return self._eval_scalar_mul(-1) + + @call_highest_priority('__rpow__') + def __pow__(self, exp): + """Return self**exp a scalar or symbol.""" + + return self.pow(exp) + + + def pow(self, exp, method=None): + r"""Return self**exp a scalar or symbol. + + Parameters + ========== + + method : multiply, mulsimp, jordan, cayley + If multiply then it returns exponentiation using recursion. + If jordan then Jordan form exponentiation will be used. + If cayley then the exponentiation is done using Cayley-Hamilton + theorem. + If mulsimp then the exponentiation is done using recursion + with dotprodsimp. This specifies whether intermediate term + algebraic simplification is used during naive matrix power to + control expression blowup and thus speed up calculation. + If None, then it heuristically decides which method to use. + + """ + + if method is not None and method not in ['multiply', 'mulsimp', 'jordan', 'cayley']: + raise TypeError('No such method') + if self.rows != self.cols: + raise NonSquareMatrixError() + a = self + jordan_pow = getattr(a, '_matrix_pow_by_jordan_blocks', None) + exp = sympify(exp) + + if exp.is_zero: + return a._new(a.rows, a.cols, lambda i, j: int(i == j)) + if exp == 1: + return a + + diagonal = getattr(a, 'is_diagonal', None) + if diagonal is not None and diagonal(): + return a._new(a.rows, a.cols, lambda i, j: a[i,j]**exp if i == j else 0) + + if exp.is_Number and exp % 1 == 0: + if a.rows == 1: + return a._new([[a[0]**exp]]) + if exp < 0: + exp = -exp + a = a.inv() + # When certain conditions are met, + # Jordan block algorithm is faster than + # computation by recursion. + if method == 'jordan': + try: + return jordan_pow(exp) + except MatrixError: + if method == 'jordan': + raise + + elif method == 'cayley': + if not exp.is_Number or exp % 1 != 0: + raise ValueError("cayley method is only valid for integer powers") + return a._eval_pow_by_cayley(exp) + + elif method == "mulsimp": + if not exp.is_Number or exp % 1 != 0: + raise ValueError("mulsimp method is only valid for integer powers") + return a._eval_pow_by_recursion_dotprodsimp(exp) + + elif method == "multiply": + if not exp.is_Number or exp % 1 != 0: + raise ValueError("multiply method is only valid for integer powers") + return a._eval_pow_by_recursion(exp) + + elif method is None and exp.is_Number and exp % 1 == 0: + if exp.is_Float: + exp = Integer(exp) + # Decide heuristically which method to apply + if a.rows == 2 and exp > 100000: + return jordan_pow(exp) + elif _get_intermediate_simp_bool(True, None): + return a._eval_pow_by_recursion_dotprodsimp(exp) + elif exp > 10000: + return a._eval_pow_by_cayley(exp) + else: + return a._eval_pow_by_recursion(exp) + + if jordan_pow: + try: + return jordan_pow(exp) + except NonInvertibleMatrixError: + # Raised by jordan_pow on zero determinant matrix unless exp is + # definitely known to be a non-negative integer. + # Here we raise if n is definitely not a non-negative integer + # but otherwise we can leave this as an unevaluated MatPow. + if exp.is_integer is False or exp.is_nonnegative is False: + raise + + from sympy.matrices.expressions import MatPow + return MatPow(a, exp) + + @call_highest_priority('__add__') + def __radd__(self, other): + return self.__add__(other) + + @call_highest_priority('__matmul__') + def __rmatmul__(self, other): + self, other, T = _unify_with_other(self, other) + + if T != "is_matrix": + return NotImplemented + + return self.__rmul__(other) + + @call_highest_priority('__mul__') + def __rmul__(self, other): + return self.rmultiply(other) + + def rmultiply(self, other, dotprodsimp=None): + """Same as __rmul__() but with optional simplification. + + Parameters + ========== + + dotprodsimp : bool, optional + Specifies whether intermediate term algebraic simplification is used + during matrix multiplications to control expression blowup and thus + speed up calculation. Default is off. + """ + isimpbool = _get_intermediate_simp_bool(False, dotprodsimp) + self, other, T = _unify_with_other(self, other) + + if T == "possible_scalar": + try: + return self._eval_scalar_rmul(other) + except TypeError: + return NotImplemented + + elif T == "is_matrix": + if self.shape[0] != other.shape[1]: + raise ShapeError("Matrix size mismatch.") + + m = self._eval_matrix_rmul(other) + + if isimpbool: + return m._new(m.rows, m.cols, [_dotprodsimp(e) for e in m]) + + return m + + else: + return NotImplemented + + @call_highest_priority('__sub__') + def __rsub__(self, a): + return (-self) + a + + @call_highest_priority('__rsub__') + def __sub__(self, a): + return self + (-a) + + def _eval_det_bareiss(self, iszerofunc=_is_zero_after_expand_mul): + return _det_bareiss(self, iszerofunc=iszerofunc) + + def _eval_det_berkowitz(self): + return _det_berkowitz(self) + + def _eval_det_lu(self, iszerofunc=_iszero, simpfunc=None): + return _det_LU(self, iszerofunc=iszerofunc, simpfunc=simpfunc) + + def _eval_det_bird(self): + return _det_bird(self) + + def _eval_det_laplace(self): + return _det_laplace(self) + + def _eval_determinant(self): # for expressions.determinant.Determinant + return _det(self) + + def adjugate(self, method="berkowitz"): + return _adjugate(self, method=method) + + def charpoly(self, x='lambda', simplify=_utilities_simplify): + return _charpoly(self, x=x, simplify=simplify) + + def cofactor(self, i, j, method="berkowitz"): + return _cofactor(self, i, j, method=method) + + def cofactor_matrix(self, method="berkowitz"): + return _cofactor_matrix(self, method=method) + + def det(self, method="bareiss", iszerofunc=None): + return _det(self, method=method, iszerofunc=iszerofunc) + + def per(self): + return _per(self) + + def minor(self, i, j, method="berkowitz"): + return _minor(self, i, j, method=method) + + def minor_submatrix(self, i, j): + return _minor_submatrix(self, i, j) + + _find_reasonable_pivot.__doc__ = _find_reasonable_pivot.__doc__ + _find_reasonable_pivot_naive.__doc__ = _find_reasonable_pivot_naive.__doc__ + _eval_det_bareiss.__doc__ = _det_bareiss.__doc__ + _eval_det_berkowitz.__doc__ = _det_berkowitz.__doc__ + _eval_det_bird.__doc__ = _det_bird.__doc__ + _eval_det_laplace.__doc__ = _det_laplace.__doc__ + _eval_det_lu.__doc__ = _det_LU.__doc__ + _eval_determinant.__doc__ = _det.__doc__ + adjugate.__doc__ = _adjugate.__doc__ + charpoly.__doc__ = _charpoly.__doc__ + cofactor.__doc__ = _cofactor.__doc__ + cofactor_matrix.__doc__ = _cofactor_matrix.__doc__ + det.__doc__ = _det.__doc__ + per.__doc__ = _per.__doc__ + minor.__doc__ = _minor.__doc__ + minor_submatrix.__doc__ = _minor_submatrix.__doc__ + + def echelon_form(self, iszerofunc=_iszero, simplify=False, with_pivots=False): + return _echelon_form(self, iszerofunc=iszerofunc, simplify=simplify, + with_pivots=with_pivots) + + @property + def is_echelon(self): + return _is_echelon(self) + + def rank(self, iszerofunc=_iszero, simplify=False): + return _rank(self, iszerofunc=iszerofunc, simplify=simplify) + + def rref_rhs(self, rhs): + """Return reduced row-echelon form of matrix, matrix showing + rhs after reduction steps. ``rhs`` must have the same number + of rows as ``self``. + + Examples + ======== + + >>> from sympy import Matrix, symbols + >>> r1, r2 = symbols('r1 r2') + >>> Matrix([[1, 1], [2, 1]]).rref_rhs(Matrix([r1, r2])) + (Matrix([ + [1, 0], + [0, 1]]), Matrix([ + [ -r1 + r2], + [2*r1 - r2]])) + """ + r, _ = _rref(self.hstack(self, self.eye(self.rows), rhs)) + return r[:, :self.cols], r[:, -rhs.cols:] + + def rref(self, iszerofunc=_iszero, simplify=False, pivots=True, + normalize_last=True): + return _rref(self, iszerofunc=iszerofunc, simplify=simplify, + pivots=pivots, normalize_last=normalize_last) + + echelon_form.__doc__ = _echelon_form.__doc__ + is_echelon.__doc__ = _is_echelon.__doc__ + rank.__doc__ = _rank.__doc__ + rref.__doc__ = _rref.__doc__ + + def _normalize_op_args(self, op, col, k, col1, col2, error_str="col"): + """Validate the arguments for a row/column operation. ``error_str`` + can be one of "row" or "col" depending on the arguments being parsed.""" + if op not in ["n->kn", "n<->m", "n->n+km"]: + raise ValueError("Unknown {} operation '{}'. Valid col operations " + "are 'n->kn', 'n<->m', 'n->n+km'".format(error_str, op)) + + # define self_col according to error_str + self_cols = self.cols if error_str == 'col' else self.rows + + # normalize and validate the arguments + if op == "n->kn": + col = col if col is not None else col1 + if col is None or k is None: + raise ValueError("For a {0} operation 'n->kn' you must provide the " + "kwargs `{0}` and `k`".format(error_str)) + if not 0 <= col < self_cols: + raise ValueError("This matrix does not have a {} '{}'".format(error_str, col)) + + elif op == "n<->m": + # we need two cols to swap. It does not matter + # how they were specified, so gather them together and + # remove `None` + cols = {col, k, col1, col2}.difference([None]) + if len(cols) > 2: + # maybe the user left `k` by mistake? + cols = {col, col1, col2}.difference([None]) + if len(cols) != 2: + raise ValueError("For a {0} operation 'n<->m' you must provide the " + "kwargs `{0}1` and `{0}2`".format(error_str)) + col1, col2 = cols + if not 0 <= col1 < self_cols: + raise ValueError("This matrix does not have a {} '{}'".format(error_str, col1)) + if not 0 <= col2 < self_cols: + raise ValueError("This matrix does not have a {} '{}'".format(error_str, col2)) + + elif op == "n->n+km": + col = col1 if col is None else col + col2 = col1 if col2 is None else col2 + if col is None or col2 is None or k is None: + raise ValueError("For a {0} operation 'n->n+km' you must provide the " + "kwargs `{0}`, `k`, and `{0}2`".format(error_str)) + if col == col2: + raise ValueError("For a {0} operation 'n->n+km' `{0}` and `{0}2` must " + "be different.".format(error_str)) + if not 0 <= col < self_cols: + raise ValueError("This matrix does not have a {} '{}'".format(error_str, col)) + if not 0 <= col2 < self_cols: + raise ValueError("This matrix does not have a {} '{}'".format(error_str, col2)) + + else: + raise ValueError('invalid operation %s' % repr(op)) + + return op, col, k, col1, col2 + + def _eval_col_op_multiply_col_by_const(self, col, k): + def entry(i, j): + if j == col: + return k * self[i, j] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def _eval_col_op_swap(self, col1, col2): + def entry(i, j): + if j == col1: + return self[i, col2] + elif j == col2: + return self[i, col1] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def _eval_col_op_add_multiple_to_other_col(self, col, k, col2): + def entry(i, j): + if j == col: + return self[i, j] + k * self[i, col2] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def _eval_row_op_swap(self, row1, row2): + def entry(i, j): + if i == row1: + return self[row2, j] + elif i == row2: + return self[row1, j] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def _eval_row_op_multiply_row_by_const(self, row, k): + def entry(i, j): + if i == row: + return k * self[i, j] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def _eval_row_op_add_multiple_to_other_row(self, row, k, row2): + def entry(i, j): + if i == row: + return self[i, j] + k * self[row2, j] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def elementary_col_op(self, op="n->kn", col=None, k=None, col1=None, col2=None): + """Performs the elementary column operation `op`. + + `op` may be one of + + * ``"n->kn"`` (column n goes to k*n) + * ``"n<->m"`` (swap column n and column m) + * ``"n->n+km"`` (column n goes to column n + k*column m) + + Parameters + ========== + + op : string; the elementary row operation + col : the column to apply the column operation + k : the multiple to apply in the column operation + col1 : one column of a column swap + col2 : second column of a column swap or column "m" in the column operation + "n->n+km" + """ + + op, col, k, col1, col2 = self._normalize_op_args(op, col, k, col1, col2, "col") + + # now that we've validated, we're all good to dispatch + if op == "n->kn": + return self._eval_col_op_multiply_col_by_const(col, k) + if op == "n<->m": + return self._eval_col_op_swap(col1, col2) + if op == "n->n+km": + return self._eval_col_op_add_multiple_to_other_col(col, k, col2) + + def elementary_row_op(self, op="n->kn", row=None, k=None, row1=None, row2=None): + """Performs the elementary row operation `op`. + + `op` may be one of + + * ``"n->kn"`` (row n goes to k*n) + * ``"n<->m"`` (swap row n and row m) + * ``"n->n+km"`` (row n goes to row n + k*row m) + + Parameters + ========== + + op : string; the elementary row operation + row : the row to apply the row operation + k : the multiple to apply in the row operation + row1 : one row of a row swap + row2 : second row of a row swap or row "m" in the row operation + "n->n+km" + """ + + op, row, k, row1, row2 = self._normalize_op_args(op, row, k, row1, row2, "row") + + # now that we've validated, we're all good to dispatch + if op == "n->kn": + return self._eval_row_op_multiply_row_by_const(row, k) + if op == "n<->m": + return self._eval_row_op_swap(row1, row2) + if op == "n->n+km": + return self._eval_row_op_add_multiple_to_other_row(row, k, row2) + + def columnspace(self, simplify=False): + return _columnspace(self, simplify=simplify) + + def nullspace(self, simplify=False, iszerofunc=_iszero): + return _nullspace(self, simplify=simplify, iszerofunc=iszerofunc) + + def rowspace(self, simplify=False): + return _rowspace(self, simplify=simplify) + + # This is a classmethod but is converted to such later in order to allow + # assignment of __doc__ since that does not work for already wrapped + # classmethods in Python 3.6. + def orthogonalize(cls, *vecs, **kwargs): + return _orthogonalize(cls, *vecs, **kwargs) + + columnspace.__doc__ = _columnspace.__doc__ + nullspace.__doc__ = _nullspace.__doc__ + rowspace.__doc__ = _rowspace.__doc__ + orthogonalize.__doc__ = _orthogonalize.__doc__ + + orthogonalize = classmethod(orthogonalize) # type:ignore + + def eigenvals(self, error_when_incomplete=True, **flags): + return _eigenvals(self, error_when_incomplete=error_when_incomplete, **flags) + + def eigenvects(self, error_when_incomplete=True, iszerofunc=_iszero, **flags): + return _eigenvects(self, error_when_incomplete=error_when_incomplete, + iszerofunc=iszerofunc, **flags) + + def is_diagonalizable(self, reals_only=False, **kwargs): + return _is_diagonalizable(self, reals_only=reals_only, **kwargs) + + def diagonalize(self, reals_only=False, sort=False, normalize=False): + return _diagonalize(self, reals_only=reals_only, sort=sort, + normalize=normalize) + + def bidiagonalize(self, upper=True): + return _bidiagonalize(self, upper=upper) + + def bidiagonal_decomposition(self, upper=True): + return _bidiagonal_decomposition(self, upper=upper) + + @property + def is_positive_definite(self): + return _is_positive_definite(self) + + @property + def is_positive_semidefinite(self): + return _is_positive_semidefinite(self) + + @property + def is_negative_definite(self): + return _is_negative_definite(self) + + @property + def is_negative_semidefinite(self): + return _is_negative_semidefinite(self) + + @property + def is_indefinite(self): + return _is_indefinite(self) + + def jordan_form(self, calc_transform=True, **kwargs): + return _jordan_form(self, calc_transform=calc_transform, **kwargs) + + def left_eigenvects(self, **flags): + return _left_eigenvects(self, **flags) + + def singular_values(self): + return _singular_values(self) + + eigenvals.__doc__ = _eigenvals.__doc__ + eigenvects.__doc__ = _eigenvects.__doc__ + is_diagonalizable.__doc__ = _is_diagonalizable.__doc__ + diagonalize.__doc__ = _diagonalize.__doc__ + is_positive_definite.__doc__ = _is_positive_definite.__doc__ + is_positive_semidefinite.__doc__ = _is_positive_semidefinite.__doc__ + is_negative_definite.__doc__ = _is_negative_definite.__doc__ + is_negative_semidefinite.__doc__ = _is_negative_semidefinite.__doc__ + is_indefinite.__doc__ = _is_indefinite.__doc__ + jordan_form.__doc__ = _jordan_form.__doc__ + left_eigenvects.__doc__ = _left_eigenvects.__doc__ + singular_values.__doc__ = _singular_values.__doc__ + bidiagonalize.__doc__ = _bidiagonalize.__doc__ + bidiagonal_decomposition.__doc__ = _bidiagonal_decomposition.__doc__ + + def diff(self, *args, evaluate=True, **kwargs): + """Calculate the derivative of each element in the matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x, y + >>> M = Matrix([[x, y], [1, 0]]) + >>> M.diff(x) + Matrix([ + [1, 0], + [0, 0]]) + + See Also + ======== + + integrate + limit + """ + # XXX this should be handled here rather than in Derivative + from sympy.tensor.array.array_derivatives import ArrayDerivative + deriv = ArrayDerivative(self, *args, evaluate=evaluate) + # XXX This can rather changed to always return immutable matrix + if not isinstance(self, Basic) and evaluate: + return deriv.as_mutable() + return deriv + + def _eval_derivative(self, arg): + return self.applyfunc(lambda x: x.diff(arg)) + + def integrate(self, *args, **kwargs): + """Integrate each element of the matrix. ``args`` will + be passed to the ``integrate`` function. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x, y + >>> M = Matrix([[x, y], [1, 0]]) + >>> M.integrate((x, )) + Matrix([ + [x**2/2, x*y], + [ x, 0]]) + >>> M.integrate((x, 0, 2)) + Matrix([ + [2, 2*y], + [2, 0]]) + + See Also + ======== + + limit + diff + """ + return self.applyfunc(lambda x: x.integrate(*args, **kwargs)) + + def jacobian(self, X): + """Calculates the Jacobian matrix (derivative of a vector-valued function). + + Parameters + ========== + + ``self`` : vector of expressions representing functions f_i(x_1, ..., x_n). + X : set of x_i's in order, it can be a list or a Matrix + + Both ``self`` and X can be a row or a column matrix in any order + (i.e., jacobian() should always work). + + Examples + ======== + + >>> from sympy import sin, cos, Matrix + >>> from sympy.abc import rho, phi + >>> X = Matrix([rho*cos(phi), rho*sin(phi), rho**2]) + >>> Y = Matrix([rho, phi]) + >>> X.jacobian(Y) + Matrix([ + [cos(phi), -rho*sin(phi)], + [sin(phi), rho*cos(phi)], + [ 2*rho, 0]]) + >>> X = Matrix([rho*cos(phi), rho*sin(phi)]) + >>> X.jacobian(Y) + Matrix([ + [cos(phi), -rho*sin(phi)], + [sin(phi), rho*cos(phi)]]) + + See Also + ======== + + hessian + wronskian + """ + from sympy.matrices.matrixbase import MatrixBase + if not isinstance(X, MatrixBase): + X = self._new(X) + # Both X and ``self`` can be a row or a column matrix, so we need to make + # sure all valid combinations work, but everything else fails: + if self.shape[0] == 1: + m = self.shape[1] + elif self.shape[1] == 1: + m = self.shape[0] + else: + raise TypeError("``self`` must be a row or a column matrix") + if X.shape[0] == 1: + n = X.shape[1] + elif X.shape[1] == 1: + n = X.shape[0] + else: + raise TypeError("X must be a row or a column matrix") + + # m is the number of functions and n is the number of variables + # computing the Jacobian is now easy: + return self._new(m, n, lambda j, i: self[j].diff(X[i])) + + def limit(self, *args): + """Calculate the limit of each element in the matrix. + ``args`` will be passed to the ``limit`` function. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x, y + >>> M = Matrix([[x, y], [1, 0]]) + >>> M.limit(x, 2) + Matrix([ + [2, y], + [1, 0]]) + + See Also + ======== + + integrate + diff + """ + return self.applyfunc(lambda x: x.limit(*args)) + + def berkowitz_charpoly(self, x=Dummy('lambda'), simplify=_utilities_simplify): + return self.charpoly(x=x) + + def berkowitz_det(self): + """Computes determinant using Berkowitz method. + + See Also + ======== + + det + """ + return self.det(method='berkowitz') + + def berkowitz_eigenvals(self, **flags): + """Computes eigenvalues of a Matrix using Berkowitz method.""" + return self.eigenvals(**flags) + + def berkowitz_minors(self): + """Computes principal minors using Berkowitz method.""" + sign, minors = self.one, [] + + for poly in self.berkowitz(): + minors.append(sign * poly[-1]) + sign = -sign + + return tuple(minors) + + def berkowitz(self): + from sympy.matrices import zeros + berk = ((1,),) + if not self: + return berk + + if not self.is_square: + raise NonSquareMatrixError() + + A, N = self, self.rows + transforms = [0] * (N - 1) + + for n in range(N, 1, -1): + T, k = zeros(n + 1, n), n - 1 + + R, C = -A[k, :k], A[:k, k] + A, a = A[:k, :k], -A[k, k] + + items = [C] + + for i in range(0, n - 2): + items.append(A * items[i]) + + for i, B in enumerate(items): + items[i] = (R * B)[0, 0] + + items = [self.one, a] + items + + for i in range(n): + T[i:, i] = items[:n - i + 1] + + transforms[k - 1] = T + + polys = [self._new([self.one, -A[0, 0]])] + + for i, T in enumerate(transforms): + polys.append(T * polys[i]) + + return berk + tuple(map(tuple, polys)) + + def cofactorMatrix(self, method="berkowitz"): + return self.cofactor_matrix(method=method) + + def det_bareis(self): + return _det_bareiss(self) + + def det_LU_decomposition(self): + """Compute matrix determinant using LU decomposition. + + + Note that this method fails if the LU decomposition itself + fails. In particular, if the matrix has no inverse this method + will fail. + + TODO: Implement algorithm for sparse matrices (SFF), + http://www.eecis.udel.edu/~saunders/papers/sffge/it5.ps. + + See Also + ======== + + + det + berkowitz_det + """ + return self.det(method='lu') + + def jordan_cell(self, eigenval, n): + return self.jordan_block(size=n, eigenvalue=eigenval) + + def jordan_cells(self, calc_transformation=True): + P, J = self.jordan_form() + return P, J.get_diag_blocks() + + def minorEntry(self, i, j, method="berkowitz"): + return self.minor(i, j, method=method) + + def minorMatrix(self, i, j): + return self.minor_submatrix(i, j) + + def permuteBkwd(self, perm): + """Permute the rows of the matrix with the given permutation in reverse.""" + return self.permute_rows(perm, direction='backward') + + def permuteFwd(self, perm): + """Permute the rows of the matrix with the given permutation.""" + return self.permute_rows(perm, direction='forward') + + @property + def kind(self) -> MatrixKind: + elem_kinds = {e.kind for e in self.flat()} + if len(elem_kinds) == 1: + elemkind, = elem_kinds + else: + elemkind = UndefinedKind + return MatrixKind(elemkind) + + def flat(self): + """ + Returns a flat list of all elements in the matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix([[0, 2], [3, 4]]) + >>> m.flat() + [0, 2, 3, 4] + + See Also + ======== + + tolist + values + """ + return [self[i, j] for i in range(self.rows) for j in range(self.cols)] + + def __array__(self, dtype=object, copy=None): + if copy is not None and not copy: + raise TypeError("Cannot implement copy=False when converting Matrix to ndarray") + from .dense import matrix2numpy + return matrix2numpy(self, dtype=dtype) + + def __len__(self): + """Return the number of elements of ``self``. + + Implemented mainly so bool(Matrix()) == False. + """ + return self.rows * self.cols + + def _matrix_pow_by_jordan_blocks(self, num): + from sympy.matrices import diag, MutableMatrix + + def jordan_cell_power(jc, n): + N = jc.shape[0] + l = jc[0,0] + if l.is_zero: + if N == 1 and n.is_nonnegative: + jc[0,0] = l**n + elif not (n.is_integer and n.is_nonnegative): + raise NonInvertibleMatrixError("Non-invertible matrix can only be raised to a nonnegative integer") + else: + for i in range(N): + jc[0,i] = KroneckerDelta(i, n) + else: + for i in range(N): + bn = binomial(n, i) + if isinstance(bn, binomial): + bn = bn._eval_expand_func() + jc[0,i] = l**(n-i)*bn + for i in range(N): + for j in range(1, N-i): + jc[j,i+j] = jc [j-1,i+j-1] + + P, J = self.jordan_form() + jordan_cells = J.get_diag_blocks() + # Make sure jordan_cells matrices are mutable: + jordan_cells = [MutableMatrix(j) for j in jordan_cells] + for j in jordan_cells: + jordan_cell_power(j, num) + return self._new(P.multiply(diag(*jordan_cells)) + .multiply(P.inv())) + + def __str__(self): + if S.Zero in self.shape: + return 'Matrix(%s, %s, [])' % (self.rows, self.cols) + return "Matrix(%s)" % str(self.tolist()) + + def _format_str(self, printer=None): + if not printer: + printer = StrPrinter() + # Handle zero dimensions: + if S.Zero in self.shape: + return 'Matrix(%s, %s, [])' % (self.rows, self.cols) + if self.rows == 1: + return "Matrix([%s])" % self.table(printer, rowsep=',\n') + return "Matrix([\n%s])" % self.table(printer, rowsep=',\n') + + @classmethod + def irregular(cls, ntop, *matrices, **kwargs): + """Return a matrix filled by the given matrices which + are listed in order of appearance from left to right, top to + bottom as they first appear in the matrix. They must fill the + matrix completely. + + Examples + ======== + + >>> from sympy import ones, Matrix + >>> Matrix.irregular(3, ones(2,1), ones(3,3)*2, ones(2,2)*3, + ... ones(1,1)*4, ones(2,2)*5, ones(1,2)*6, ones(1,2)*7) + Matrix([ + [1, 2, 2, 2, 3, 3], + [1, 2, 2, 2, 3, 3], + [4, 2, 2, 2, 5, 5], + [6, 6, 7, 7, 5, 5]]) + """ + ntop = as_int(ntop) + # make sure we are working with explicit matrices + b = [i.as_explicit() if hasattr(i, 'as_explicit') else i + for i in matrices] + q = list(range(len(b))) + dat = [i.rows for i in b] + active = [q.pop(0) for _ in range(ntop)] + cols = sum(b[i].cols for i in active) + rows = [] + while any(dat): + r = [] + for a, j in enumerate(active): + r.extend(b[j][-dat[j], :]) + dat[j] -= 1 + if dat[j] == 0 and q: + active[a] = q.pop(0) + if len(r) != cols: + raise ValueError(filldedent(''' + Matrices provided do not appear to fill + the space completely.''')) + rows.append(r) + return cls._new(rows) + + @classmethod + def _handle_ndarray(cls, arg): + # NumPy array or matrix or some other object that implements + # __array__. So let's first use this method to get a + # numpy.array() and then make a Python list out of it. + arr = arg.__array__() + if len(arr.shape) == 2: + rows, cols = arr.shape[0], arr.shape[1] + flat_list = [cls._sympify(i) for i in arr.ravel()] + return rows, cols, flat_list + elif len(arr.shape) == 1: + flat_list = [cls._sympify(i) for i in arr] + return arr.shape[0], 1, flat_list + else: + raise NotImplementedError( + "SymPy supports just 1D and 2D matrices") + + @classmethod + def _handle_creation_inputs(cls, *args, **kwargs): + """Return the number of rows, cols and flat matrix elements. + + Examples + ======== + + >>> from sympy import Matrix, I + + Matrix can be constructed as follows: + + * from a nested list of iterables + + >>> Matrix( ((1, 2+I), (3, 4)) ) + Matrix([ + [1, 2 + I], + [3, 4]]) + + * from un-nested iterable (interpreted as a column) + + >>> Matrix( [1, 2] ) + Matrix([ + [1], + [2]]) + + * from un-nested iterable with dimensions + + >>> Matrix(1, 2, [1, 2] ) + Matrix([[1, 2]]) + + * from no arguments (a 0 x 0 matrix) + + >>> Matrix() + Matrix(0, 0, []) + + * from a rule + + >>> Matrix(2, 2, lambda i, j: i/(j + 1) ) + Matrix([ + [0, 0], + [1, 1/2]]) + + See Also + ======== + irregular - filling a matrix with irregular blocks + """ + from sympy.matrices import SparseMatrix + from sympy.matrices.expressions.matexpr import MatrixSymbol + from sympy.matrices.expressions.blockmatrix import BlockMatrix + + flat_list = None + + if len(args) == 1: + # Matrix(SparseMatrix(...)) + if isinstance(args[0], SparseMatrix): + return args[0].rows, args[0].cols, flatten(args[0].tolist()) + + # Matrix(Matrix(...)) + elif isinstance(args[0], MatrixBase): + return args[0].rows, args[0].cols, args[0].flat() + + # Matrix(MatrixSymbol('X', 2, 2)) + elif isinstance(args[0], Basic) and args[0].is_Matrix: + return args[0].rows, args[0].cols, args[0].as_explicit().flat() + + elif isinstance(args[0], mp.matrix): + M = args[0] + flat_list = [cls._sympify(x) for x in M] + return M.rows, M.cols, flat_list + + # Matrix(numpy.ones((2, 2))) + elif hasattr(args[0], "__array__"): + return cls._handle_ndarray(args[0]) + + # Matrix([1, 2, 3]) or Matrix([[1, 2], [3, 4]]) + elif is_sequence(args[0]) \ + and not isinstance(args[0], DeferredVector): + dat = list(args[0]) + ismat = lambda i: isinstance(i, MatrixBase) and ( + evaluate or isinstance(i, (BlockMatrix, MatrixSymbol))) + raw = lambda i: is_sequence(i) and not ismat(i) + evaluate = kwargs.get('evaluate', True) + + + if evaluate: + + def make_explicit(x): + """make Block and Symbol explicit""" + if isinstance(x, BlockMatrix): + return x.as_explicit() + elif isinstance(x, MatrixSymbol) and all(_.is_Integer for _ in x.shape): + return x.as_explicit() + else: + return x + + def make_explicit_row(row): + # Could be list or could be list of lists + if isinstance(row, (list, tuple)): + return [make_explicit(x) for x in row] + else: + return make_explicit(row) + + if isinstance(dat, (list, tuple)): + dat = [make_explicit_row(row) for row in dat] + + if len(dat) == 0: + rows = cols = 0 + flat_list = [] + elif all(raw(i) for i in dat) and len(dat[0]) == 0: + if not all(len(i) == 0 for i in dat): + raise ValueError('mismatched dimensions') + rows = len(dat) + cols = 0 + flat_list = [] + elif not any(raw(i) or ismat(i) for i in dat): + # a column as a list of values + flat_list = [cls._sympify(i) for i in dat] + rows = len(flat_list) + cols = 1 if rows else 0 + elif evaluate and all(ismat(i) for i in dat): + # a column as a list of matrices + ncol = {i.cols for i in dat if any(i.shape)} + if ncol: + if len(ncol) != 1: + raise ValueError('mismatched dimensions') + flat_list = [_ for i in dat for r in i.tolist() for _ in r] + cols = ncol.pop() + rows = len(flat_list)//cols + else: + rows = cols = 0 + flat_list = [] + elif evaluate and any(ismat(i) for i in dat): + ncol = set() + flat_list = [] + for i in dat: + if ismat(i): + flat_list.extend( + [k for j in i.tolist() for k in j]) + if any(i.shape): + ncol.add(i.cols) + elif raw(i): + if i: + ncol.add(len(i)) + flat_list.extend([cls._sympify(ij) for ij in i]) + else: + ncol.add(1) + flat_list.append(i) + if len(ncol) > 1: + raise ValueError('mismatched dimensions') + cols = ncol.pop() + rows = len(flat_list)//cols + else: + # list of lists; each sublist is a logical row + # which might consist of many rows if the values in + # the row are matrices + flat_list = [] + ncol = set() + rows = cols = 0 + for row in dat: + if not is_sequence(row) and \ + not getattr(row, 'is_Matrix', False): + raise ValueError('expecting list of lists') + + if hasattr(row, '__array__'): + if 0 in row.shape: + continue + + if evaluate and all(ismat(i) for i in row): + r, c, flatT = cls._handle_creation_inputs( + [i.T for i in row]) + T = reshape(flatT, [c]) + flat = \ + [T[i][j] for j in range(c) for i in range(r)] + r, c = c, r + else: + r = 1 + if getattr(row, 'is_Matrix', False): + c = 1 + flat = [row] + else: + c = len(row) + flat = [cls._sympify(i) for i in row] + ncol.add(c) + if len(ncol) > 1: + raise ValueError('mismatched dimensions') + flat_list.extend(flat) + rows += r + cols = ncol.pop() if ncol else 0 + + elif len(args) == 3: + rows = as_int(args[0]) + cols = as_int(args[1]) + + if rows < 0 or cols < 0: + raise ValueError("Cannot create a {} x {} matrix. " + "Both dimensions must be positive".format(rows, cols)) + + # Matrix(2, 2, lambda i, j: i+j) + if len(args) == 3 and isinstance(args[2], Callable): + op = args[2] + flat_list = [] + for i in range(rows): + flat_list.extend( + [cls._sympify(op(cls._sympify(i), cls._sympify(j))) + for j in range(cols)]) + + # Matrix(2, 2, [1, 2, 3, 4]) + elif len(args) == 3 and is_sequence(args[2]): + flat_list = args[2] + if len(flat_list) != rows * cols: + raise ValueError( + 'List length should be equal to rows*columns') + flat_list = [cls._sympify(i) for i in flat_list] + + + # Matrix() + elif len(args) == 0: + # Empty Matrix + rows = cols = 0 + flat_list = [] + + if flat_list is None: + raise TypeError(filldedent(''' + Data type not understood; expecting list of lists + or lists of values.''')) + + return rows, cols, flat_list + + def _setitem(self, key, value): + """Helper to set value at location given by key. + + Examples + ======== + + >>> from sympy import Matrix, I, zeros, ones + >>> m = Matrix(((1, 2+I), (3, 4))) + >>> m + Matrix([ + [1, 2 + I], + [3, 4]]) + >>> m[1, 0] = 9 + >>> m + Matrix([ + [1, 2 + I], + [9, 4]]) + >>> m[1, 0] = [[0, 1]] + + To replace row r you assign to position r*m where m + is the number of columns: + + >>> M = zeros(4) + >>> m = M.cols + >>> M[3*m] = ones(1, m)*2; M + Matrix([ + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [2, 2, 2, 2]]) + + And to replace column c you can assign to position c: + + >>> M[2] = ones(m, 1)*4; M + Matrix([ + [0, 0, 4, 0], + [0, 0, 4, 0], + [0, 0, 4, 0], + [2, 2, 4, 2]]) + """ + from .dense import Matrix + + is_slice = isinstance(key, slice) + i, j = key = self.key2ij(key) + is_mat = isinstance(value, MatrixBase) + if isinstance(i, slice) or isinstance(j, slice): + if is_mat: + self.copyin_matrix(key, value) + return + if not isinstance(value, Expr) and is_sequence(value): + self.copyin_list(key, value) + return + raise ValueError('unexpected value: %s' % value) + else: + if (not is_mat and + not isinstance(value, Basic) and is_sequence(value)): + value = Matrix(value) + is_mat = True + if is_mat: + if is_slice: + key = (slice(*divmod(i, self.cols)), + slice(*divmod(j, self.cols))) + else: + key = (slice(i, i + value.rows), + slice(j, j + value.cols)) + self.copyin_matrix(key, value) + else: + return i, j, self._sympify(value) + return + + def add(self, b): + """Return self + b.""" + return self + b + + def condition_number(self): + """Returns the condition number of a matrix. + + This is the maximum singular value divided by the minimum singular value + + Examples + ======== + + >>> from sympy import Matrix, S + >>> A = Matrix([[1, 0, 0], [0, 10, 0], [0, 0, S.One/10]]) + >>> A.condition_number() + 100 + + See Also + ======== + + singular_values + """ + + if not self: + return self.zero + singularvalues = self.singular_values() + return Max(*singularvalues) / Min(*singularvalues) + + def copy(self): + """ + Returns the copy of a matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix(2, 2, [1, 2, 3, 4]) + >>> A.copy() + Matrix([ + [1, 2], + [3, 4]]) + + """ + return self._new(self.rows, self.cols, self.flat()) + + def cross(self, b): + r""" + Return the cross product of ``self`` and ``b`` relaxing the condition + of compatible dimensions: if each has 3 elements, a matrix of the + same type and shape as ``self`` will be returned. If ``b`` has the same + shape as ``self`` then common identities for the cross product (like + `a \times b = - b \times a`) will hold. + + Parameters + ========== + b : 3x1 or 1x3 Matrix + + See Also + ======== + + dot + hat + vee + multiply + multiply_elementwise + """ + from sympy.matrices.expressions.matexpr import MatrixExpr + + if not isinstance(b, (MatrixBase, MatrixExpr)): + raise TypeError( + "{} must be a Matrix, not {}.".format(b, type(b))) + + if not (self.rows * self.cols == b.rows * b.cols == 3): + raise ShapeError("Dimensions incorrect for cross product: %s x %s" % + ((self.rows, self.cols), (b.rows, b.cols))) + else: + return self._new(self.rows, self.cols, ( + (self[1] * b[2] - self[2] * b[1]), + (self[2] * b[0] - self[0] * b[2]), + (self[0] * b[1] - self[1] * b[0]))) + + def hat(self): + r""" + Return the skew-symmetric matrix representing the cross product, + so that ``self.hat() * b`` is equivalent to ``self.cross(b)``. + + Examples + ======== + + Calling ``hat`` creates a skew-symmetric 3x3 Matrix from a 3x1 Matrix: + + >>> from sympy import Matrix + >>> a = Matrix([1, 2, 3]) + >>> a.hat() + Matrix([ + [ 0, -3, 2], + [ 3, 0, -1], + [-2, 1, 0]]) + + Multiplying it with another 3x1 Matrix calculates the cross product: + + >>> b = Matrix([3, 2, 1]) + >>> a.hat() * b + Matrix([ + [-4], + [ 8], + [-4]]) + + Which is equivalent to calling the ``cross`` method: + + >>> a.cross(b) + Matrix([ + [-4], + [ 8], + [-4]]) + + See Also + ======== + + dot + cross + vee + multiply + multiply_elementwise + """ + + if self.shape != (3, 1): + raise ShapeError("Dimensions incorrect, expected (3, 1), got " + + str(self.shape)) + else: + x, y, z = self + return self._new(3, 3, ( + 0, -z, y, + z, 0, -x, + -y, x, 0)) + + def vee(self): + r""" + Return a 3x1 vector from a skew-symmetric matrix representing the cross product, + so that ``self * b`` is equivalent to ``self.vee().cross(b)``. + + Examples + ======== + + Calling ``vee`` creates a vector from a skew-symmetric Matrix: + + >>> from sympy import Matrix + >>> A = Matrix([[0, -3, 2], [3, 0, -1], [-2, 1, 0]]) + >>> a = A.vee() + >>> a + Matrix([ + [1], + [2], + [3]]) + + Calculating the matrix product of the original matrix with a vector + is equivalent to a cross product: + + >>> b = Matrix([3, 2, 1]) + >>> A * b + Matrix([ + [-4], + [ 8], + [-4]]) + + >>> a.cross(b) + Matrix([ + [-4], + [ 8], + [-4]]) + + ``vee`` can also be used to retrieve angular velocity expressions. + Defining a rotation matrix: + + >>> from sympy import rot_ccw_axis3, trigsimp + >>> from sympy.physics.mechanics import dynamicsymbols + >>> theta = dynamicsymbols('theta') + >>> R = rot_ccw_axis3(theta) + >>> R + Matrix([ + [cos(theta(t)), -sin(theta(t)), 0], + [sin(theta(t)), cos(theta(t)), 0], + [ 0, 0, 1]]) + + We can retrieve the angular velocity: + + >>> Omega = R.T * R.diff() + >>> Omega = trigsimp(Omega) + >>> Omega.vee() + Matrix([ + [ 0], + [ 0], + [Derivative(theta(t), t)]]) + + See Also + ======== + + dot + cross + hat + multiply + multiply_elementwise + """ + + if self.shape != (3, 3): + raise ShapeError("Dimensions incorrect, expected (3, 3), got " + + str(self.shape)) + elif not self.is_anti_symmetric(): + raise ValueError("Matrix is not skew-symmetric") + else: + return self._new(3, 1, ( + self[2, 1], + self[0, 2], + self[1, 0])) + + @property + def D(self): + """Return Dirac conjugate (if ``self.rows == 4``). + + Examples + ======== + + >>> from sympy import Matrix, I, eye + >>> m = Matrix((0, 1 + I, 2, 3)) + >>> m.D + Matrix([[0, 1 - I, -2, -3]]) + >>> m = (eye(4) + I*eye(4)) + >>> m[0, 3] = 2 + >>> m.D + Matrix([ + [1 - I, 0, 0, 0], + [ 0, 1 - I, 0, 0], + [ 0, 0, -1 + I, 0], + [ 2, 0, 0, -1 + I]]) + + If the matrix does not have 4 rows an AttributeError will be raised + because this property is only defined for matrices with 4 rows. + + >>> Matrix(eye(2)).D + Traceback (most recent call last): + ... + AttributeError: Matrix has no attribute D. + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.conjugate: By-element conjugation + sympy.matrices.matrixbase.MatrixBase.H: Hermite conjugation + """ + from sympy.physics.matrices import mgamma + if self.rows != 4: + # In Python 3.2, properties can only return an AttributeError + # so we can't raise a ShapeError -- see commit which added the + # first line of this inline comment. Also, there is no need + # for a message since MatrixBase will raise the AttributeError + raise AttributeError + return self.H * mgamma(0) + + def dot(self, b, hermitian=None, conjugate_convention=None): + """Return the dot or inner product of two vectors of equal length. + Here ``self`` must be a ``Matrix`` of size 1 x n or n x 1, and ``b`` + must be either a matrix of size 1 x n, n x 1, or a list/tuple of length n. + A scalar is returned. + + By default, ``dot`` does not conjugate ``self`` or ``b``, even if there are + complex entries. Set ``hermitian=True`` (and optionally a ``conjugate_convention``) + to compute the hermitian inner product. + + Possible kwargs are ``hermitian`` and ``conjugate_convention``. + + If ``conjugate_convention`` is ``"left"``, ``"math"`` or ``"maths"``, + the conjugate of the first vector (``self``) is used. If ``"right"`` + or ``"physics"`` is specified, the conjugate of the second vector ``b`` is used. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> v = Matrix([1, 1, 1]) + >>> M.row(0).dot(v) + 6 + >>> M.col(0).dot(v) + 12 + >>> v = [3, 2, 1] + >>> M.row(0).dot(v) + 10 + + >>> from sympy import I + >>> q = Matrix([1*I, 1*I, 1*I]) + >>> q.dot(q, hermitian=False) + -3 + + >>> q.dot(q, hermitian=True) + 3 + + >>> q1 = Matrix([1, 1, 1*I]) + >>> q.dot(q1, hermitian=True, conjugate_convention="maths") + 1 - 2*I + >>> q.dot(q1, hermitian=True, conjugate_convention="physics") + 1 + 2*I + + + See Also + ======== + + cross + multiply + multiply_elementwise + """ + from .dense import Matrix + + if not isinstance(b, MatrixBase): + if is_sequence(b): + if len(b) != self.cols and len(b) != self.rows: + raise ShapeError( + "Dimensions incorrect for dot product: %s, %s" % ( + self.shape, len(b))) + return self.dot(Matrix(b)) + else: + raise TypeError( + "`b` must be an ordered iterable or Matrix, not %s." % + type(b)) + + if (1 not in self.shape) or (1 not in b.shape): + raise ShapeError + if len(self) != len(b): + raise ShapeError( + "Dimensions incorrect for dot product: %s, %s" % (self.shape, b.shape)) + + mat = self + n = len(mat) + if mat.shape != (1, n): + mat = mat.reshape(1, n) + if b.shape != (n, 1): + b = b.reshape(n, 1) + + # Now ``mat`` is a row vector and ``b`` is a column vector. + + # If it so happens that only conjugate_convention is passed + # then automatically set hermitian to True. If only hermitian + # is true but no conjugate_convention is not passed then + # automatically set it to ``"maths"`` + + if conjugate_convention is not None and hermitian is None: + hermitian = True + if hermitian and conjugate_convention is None: + conjugate_convention = "maths" + + if hermitian == True: + if conjugate_convention in ("maths", "left", "math"): + mat = mat.conjugate() + elif conjugate_convention in ("physics", "right"): + b = b.conjugate() + else: + raise ValueError("Unknown conjugate_convention was entered." + " conjugate_convention must be one of the" + " following: math, maths, left, physics or right.") + return (mat * b)[0] + + def dual(self): + """Returns the dual of a matrix. + + A dual of a matrix is: + + ``(1/2)*levicivita(i, j, k, l)*M(k, l)`` summed over indices `k` and `l` + + Since the levicivita method is anti_symmetric for any pairwise + exchange of indices, the dual of a symmetric matrix is the zero + matrix. Strictly speaking the dual defined here assumes that the + 'matrix' `M` is a contravariant anti_symmetric second rank tensor, + so that the dual is a covariant second rank tensor. + + """ + from sympy.matrices import zeros + + M, n = self[:, :], self.rows + work = zeros(n) + if self.is_symmetric(): + return work + + for i in range(1, n): + for j in range(1, n): + acum = 0 + for k in range(1, n): + acum += LeviCivita(i, j, 0, k) * M[0, k] + work[i, j] = acum + work[j, i] = -acum + + for l in range(1, n): + acum = 0 + for a in range(1, n): + for b in range(1, n): + acum += LeviCivita(0, l, a, b) * M[a, b] + acum /= 2 + work[0, l] = -acum + work[l, 0] = acum + + return work + + def _eval_matrix_exp_jblock(self): + """A helper function to compute an exponential of a Jordan block + matrix + + Examples + ======== + + >>> from sympy import Symbol, Matrix + >>> l = Symbol('lamda') + + A trivial example of 1*1 Jordan block: + + >>> m = Matrix.jordan_block(1, l) + >>> m._eval_matrix_exp_jblock() + Matrix([[exp(lamda)]]) + + An example of 3*3 Jordan block: + + >>> m = Matrix.jordan_block(3, l) + >>> m._eval_matrix_exp_jblock() + Matrix([ + [exp(lamda), exp(lamda), exp(lamda)/2], + [ 0, exp(lamda), exp(lamda)], + [ 0, 0, exp(lamda)]]) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Matrix_function#Jordan_decomposition + """ + size = self.rows + l = self[0, 0] + exp_l = exp(l) + + bands = {i: exp_l / factorial(i) for i in range(size)} + + from .sparsetools import banded + return self.__class__(banded(size, bands)) + + + def analytic_func(self, f, x): + """ + Computes f(A) where A is a Square Matrix + and f is an analytic function. + + Examples + ======== + + >>> from sympy import Symbol, Matrix, S, log + + >>> x = Symbol('x') + >>> m = Matrix([[S(5)/4, S(3)/4], [S(3)/4, S(5)/4]]) + >>> f = log(x) + >>> m.analytic_func(f, x) + Matrix([ + [ 0, log(2)], + [log(2), 0]]) + + Parameters + ========== + + f : Expr + Analytic Function + x : Symbol + parameter of f + + """ + + f, x = _sympify(f), _sympify(x) + if not self.is_square: + raise NonSquareMatrixError + if not x.is_symbol: + raise ValueError("{} must be a symbol.".format(x)) + if x not in f.free_symbols: + raise ValueError( + "{} must be a parameter of {}.".format(x, f)) + if x in self.free_symbols: + raise ValueError( + "{} must not be a parameter of {}.".format(x, self)) + + eigen = self.eigenvals() + max_mul = max(eigen.values()) + derivative = {} + dd = f + for i in range(max_mul - 1): + dd = diff(dd, x) + derivative[i + 1] = dd + n = self.shape[0] + r = self.zeros(n) + f_val = self.zeros(n, 1) + row = 0 + + for i in eigen: + mul = eigen[i] + f_val[row] = f.subs(x, i) + if f_val[row].is_number and not f_val[row].is_complex: + raise ValueError( + "Cannot evaluate the function because the " + "function {} is not analytic at the given " + "eigenvalue {}".format(f, f_val[row])) + val = 1 + for a in range(n): + r[row, a] = val + val *= i + if mul > 1: + coe = [1 for ii in range(n)] + deri = 1 + while mul > 1: + row = row + 1 + mul -= 1 + d_i = derivative[deri].subs(x, i) + if d_i.is_number and not d_i.is_complex: + raise ValueError( + "Cannot evaluate the function because the " + "derivative {} is not analytic at the given " + "eigenvalue {}".format(derivative[deri], d_i)) + f_val[row] = d_i + for a in range(n): + if a - deri + 1 <= 0: + r[row, a] = 0 + coe[a] = 0 + continue + coe[a] = coe[a]*(a - deri + 1) + r[row, a] = coe[a]*pow(i, a - deri) + deri += 1 + row += 1 + c = r.solve(f_val) + ans = self.zeros(n) + pre = self.eye(n) + for i in range(n): + ans = ans + c[i]*pre + pre *= self + return ans + + + def exp(self): + """Return the exponential of a square matrix. + + Examples + ======== + + >>> from sympy import Symbol, Matrix + + >>> t = Symbol('t') + >>> m = Matrix([[0, 1], [-1, 0]]) * t + >>> m.exp() + Matrix([ + [ exp(I*t)/2 + exp(-I*t)/2, -I*exp(I*t)/2 + I*exp(-I*t)/2], + [I*exp(I*t)/2 - I*exp(-I*t)/2, exp(I*t)/2 + exp(-I*t)/2]]) + """ + if not self.is_square: + raise NonSquareMatrixError( + "Exponentiation is valid only for square matrices") + try: + P, J = self.jordan_form() + cells = J.get_diag_blocks() + except MatrixError: + raise NotImplementedError( + "Exponentiation is implemented only for matrices for which the Jordan normal form can be computed") + + blocks = [cell._eval_matrix_exp_jblock() for cell in cells] + from sympy.matrices import diag + eJ = diag(*blocks) + # n = self.rows + ret = P.multiply(eJ, dotprodsimp=None).multiply(P.inv(), dotprodsimp=None) + if all(value.is_real for value in self.values()): + return type(self)(re(ret)) + else: + return type(self)(ret) + + def _eval_matrix_log_jblock(self): + """Helper function to compute logarithm of a jordan block. + + Examples + ======== + + >>> from sympy import Symbol, Matrix + >>> l = Symbol('lamda') + + A trivial example of 1*1 Jordan block: + + >>> m = Matrix.jordan_block(1, l) + >>> m._eval_matrix_log_jblock() + Matrix([[log(lamda)]]) + + An example of 3*3 Jordan block: + + >>> m = Matrix.jordan_block(3, l) + >>> m._eval_matrix_log_jblock() + Matrix([ + [log(lamda), 1/lamda, -1/(2*lamda**2)], + [ 0, log(lamda), 1/lamda], + [ 0, 0, log(lamda)]]) + """ + size = self.rows + l = self[0, 0] + + if l.is_zero: + raise MatrixError( + 'Could not take logarithm or reciprocal for the given ' + 'eigenvalue {}'.format(l)) + + bands = {0: log(l)} + for i in range(1, size): + bands[i] = -((-l) ** -i) / i + + from .sparsetools import banded + return self.__class__(banded(size, bands)) + + def log(self, simplify=cancel): + """Return the logarithm of a square matrix. + + Parameters + ========== + + simplify : function, bool + The function to simplify the result with. + + Default is ``cancel``, which is effective to reduce the + expression growing for taking reciprocals and inverses for + symbolic matrices. + + Examples + ======== + + >>> from sympy import S, Matrix + + Examples for positive-definite matrices: + + >>> m = Matrix([[1, 1], [0, 1]]) + >>> m.log() + Matrix([ + [0, 1], + [0, 0]]) + + >>> m = Matrix([[S(5)/4, S(3)/4], [S(3)/4, S(5)/4]]) + >>> m.log() + Matrix([ + [ 0, log(2)], + [log(2), 0]]) + + Examples for non positive-definite matrices: + + >>> m = Matrix([[S(3)/4, S(5)/4], [S(5)/4, S(3)/4]]) + >>> m.log() + Matrix([ + [ I*pi/2, log(2) - I*pi/2], + [log(2) - I*pi/2, I*pi/2]]) + + >>> m = Matrix( + ... [[0, 0, 0, 1], + ... [0, 0, 1, 0], + ... [0, 1, 0, 0], + ... [1, 0, 0, 0]]) + >>> m.log() + Matrix([ + [ I*pi/2, 0, 0, -I*pi/2], + [ 0, I*pi/2, -I*pi/2, 0], + [ 0, -I*pi/2, I*pi/2, 0], + [-I*pi/2, 0, 0, I*pi/2]]) + """ + if not self.is_square: + raise NonSquareMatrixError( + "Logarithm is valid only for square matrices") + + try: + if simplify: + P, J = simplify(self).jordan_form() + else: + P, J = self.jordan_form() + + cells = J.get_diag_blocks() + except MatrixError: + raise NotImplementedError( + "Logarithm is implemented only for matrices for which " + "the Jordan normal form can be computed") + + blocks = [ + cell._eval_matrix_log_jblock() + for cell in cells] + from sympy.matrices import diag + eJ = diag(*blocks) + + if simplify: + ret = simplify(P * eJ * simplify(P.inv())) + ret = self.__class__(ret) + else: + ret = P * eJ * P.inv() + + return ret + + def is_nilpotent(self): + """Checks if a matrix is nilpotent. + + A matrix B is nilpotent if for some integer k, B**k is + a zero matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> a = Matrix([[0, 0, 0], [1, 0, 0], [1, 1, 0]]) + >>> a.is_nilpotent() + True + + >>> a = Matrix([[1, 0, 1], [1, 0, 0], [1, 1, 0]]) + >>> a.is_nilpotent() + False + """ + if not self: + return True + if not self.is_square: + raise NonSquareMatrixError( + "Nilpotency is valid only for square matrices") + x = uniquely_named_symbol('x', self, modify=lambda s: '_' + s) + p = self.charpoly(x) + if p.args[0] == x ** self.rows: + return True + return False + + def key2bounds(self, keys): + """Converts a key with potentially mixed types of keys (integer and slice) + into a tuple of ranges and raises an error if any index is out of ``self``'s + range. + + See Also + ======== + + key2ij + """ + islice, jslice = [isinstance(k, slice) for k in keys] + if islice: + if not self.rows: + rlo = rhi = 0 + else: + rlo, rhi = keys[0].indices(self.rows)[:2] + else: + rlo = a2idx(keys[0], self.rows) + rhi = rlo + 1 + if jslice: + if not self.cols: + clo = chi = 0 + else: + clo, chi = keys[1].indices(self.cols)[:2] + else: + clo = a2idx(keys[1], self.cols) + chi = clo + 1 + return rlo, rhi, clo, chi + + def key2ij(self, key): + """Converts key into canonical form, converting integers or indexable + items into valid integers for ``self``'s range or returning slices + unchanged. + + See Also + ======== + + key2bounds + """ + if is_sequence(key): + if not len(key) == 2: + raise TypeError('key must be a sequence of length 2') + return [a2idx(i, n) if not isinstance(i, slice) else i + for i, n in zip(key, self.shape)] + elif isinstance(key, slice): + return key.indices(len(self))[:2] + else: + return divmod(a2idx(key, len(self)), self.cols) + + def normalized(self, iszerofunc=_iszero): + """Return the normalized version of ``self``. + + Parameters + ========== + + iszerofunc : Function, optional + A function to determine whether ``self`` is a zero vector. + The default ``_iszero`` tests to see if each element is + exactly zero. + + Returns + ======= + + Matrix + Normalized vector form of ``self``. + It has the same length as a unit vector. However, a zero vector + will be returned for a vector with norm 0. + + Raises + ====== + + ShapeError + If the matrix is not in a vector form. + + See Also + ======== + + norm + """ + if self.rows != 1 and self.cols != 1: + raise ShapeError("A Matrix must be a vector to normalize.") + norm = self.norm() + if iszerofunc(norm): + out = self.zeros(self.rows, self.cols) + else: + out = self.applyfunc(lambda i: i / norm) + return out + + def norm(self, ord=None): + """Return the Norm of a Matrix or Vector. + + In the simplest case this is the geometric size of the vector + Other norms can be specified by the ord parameter + + + ===== ============================ ========================== + ord norm for matrices norm for vectors + ===== ============================ ========================== + None Frobenius norm 2-norm + 'fro' Frobenius norm - does not exist + inf maximum row sum max(abs(x)) + -inf -- min(abs(x)) + 1 maximum column sum as below + -1 -- as below + 2 2-norm (largest sing. value) as below + -2 smallest singular value as below + other - does not exist sum(abs(x)**ord)**(1./ord) + ===== ============================ ========================== + + Examples + ======== + + >>> from sympy import Matrix, Symbol, trigsimp, cos, sin, oo + >>> x = Symbol('x', real=True) + >>> v = Matrix([cos(x), sin(x)]) + >>> trigsimp( v.norm() ) + 1 + >>> v.norm(10) + (sin(x)**10 + cos(x)**10)**(1/10) + >>> A = Matrix([[1, 1], [1, 1]]) + >>> A.norm(1) # maximum sum of absolute values of A is 2 + 2 + >>> A.norm(2) # Spectral norm (max of |Ax|/|x| under 2-vector-norm) + 2 + >>> A.norm(-2) # Inverse spectral norm (smallest singular value) + 0 + >>> A.norm() # Frobenius Norm + 2 + >>> A.norm(oo) # Infinity Norm + 2 + >>> Matrix([1, -2]).norm(oo) + 2 + >>> Matrix([-1, 2]).norm(-oo) + 1 + + See Also + ======== + + normalized + """ + # Row or Column Vector Norms + vals = list(self.values()) or [0] + if S.One in self.shape: + if ord in (2, None): # Common case sqrt() + return sqrt(Add(*(abs(i) ** 2 for i in vals))) + + elif ord == 1: # sum(abs(x)) + return Add(*(abs(i) for i in vals)) + + elif ord is S.Infinity: # max(abs(x)) + return Max(*[abs(i) for i in vals]) + + elif ord is S.NegativeInfinity: # min(abs(x)) + return Min(*[abs(i) for i in vals]) + + # Otherwise generalize the 2-norm, Sum(x_i**ord)**(1/ord) + # Note that while useful this is not mathematically a norm + try: + return Pow(Add(*(abs(i) ** ord for i in vals)), S.One / ord) + except (NotImplementedError, TypeError): + raise ValueError("Expected order to be Number, Symbol, oo") + + # Matrix Norms + else: + if ord == 1: # Maximum column sum + m = self.applyfunc(abs) + return Max(*[sum(m.col(i)) for i in range(m.cols)]) + + elif ord == 2: # Spectral Norm + # Maximum singular value + return Max(*self.singular_values()) + + elif ord == -2: + # Minimum singular value + return Min(*self.singular_values()) + + elif ord is S.Infinity: # Infinity Norm - Maximum row sum + m = self.applyfunc(abs) + return Max(*[sum(m.row(i)) for i in range(m.rows)]) + + elif (ord is None or isinstance(ord, + str) and ord.lower() in + ['f', 'fro', 'frobenius', 'vector']): + # Reshape as vector and send back to norm function + return self.vec().norm(ord=2) + + else: + raise NotImplementedError("Matrix Norms under development") + + def print_nonzero(self, symb="X"): + """Shows location of non-zero entries for fast shape lookup. + + Examples + ======== + + >>> from sympy import Matrix, eye + >>> m = Matrix(2, 3, lambda i, j: i*3+j) + >>> m + Matrix([ + [0, 1, 2], + [3, 4, 5]]) + >>> m.print_nonzero() + [ XX] + [XXX] + >>> m = eye(4) + >>> m.print_nonzero("x") + [x ] + [ x ] + [ x ] + [ x] + + """ + s = [] + for i in range(self.rows): + line = [] + for j in range(self.cols): + if self[i, j] == 0: + line.append(" ") + else: + line.append(str(symb)) + s.append("[%s]" % ''.join(line)) + print('\n'.join(s)) + + def project(self, v): + """Return the projection of ``self`` onto the line containing ``v``. + + Examples + ======== + + >>> from sympy import Matrix, S, sqrt + >>> V = Matrix([sqrt(3)/2, S.Half]) + >>> x = Matrix([[1, 0]]) + >>> V.project(x) + Matrix([[sqrt(3)/2, 0]]) + >>> V.project(-x) + Matrix([[sqrt(3)/2, 0]]) + """ + return v * (self.dot(v) / v.dot(v)) + + def table(self, printer, rowstart='[', rowend=']', rowsep='\n', + colsep=', ', align='right'): + r""" + String form of Matrix as a table. + + ``printer`` is the printer to use for on the elements (generally + something like StrPrinter()) + + ``rowstart`` is the string used to start each row (by default '['). + + ``rowend`` is the string used to end each row (by default ']'). + + ``rowsep`` is the string used to separate rows (by default a newline). + + ``colsep`` is the string used to separate columns (by default ', '). + + ``align`` defines how the elements are aligned. Must be one of 'left', + 'right', or 'center'. You can also use '<', '>', and '^' to mean the + same thing, respectively. + + This is used by the string printer for Matrix. + + Examples + ======== + + >>> from sympy import Matrix, StrPrinter + >>> M = Matrix([[1, 2], [-33, 4]]) + >>> printer = StrPrinter() + >>> M.table(printer) + '[ 1, 2]\n[-33, 4]' + >>> print(M.table(printer)) + [ 1, 2] + [-33, 4] + >>> print(M.table(printer, rowsep=',\n')) + [ 1, 2], + [-33, 4] + >>> print('[%s]' % M.table(printer, rowsep=',\n')) + [[ 1, 2], + [-33, 4]] + >>> print(M.table(printer, colsep=' ')) + [ 1 2] + [-33 4] + >>> print(M.table(printer, align='center')) + [ 1 , 2] + [-33, 4] + >>> print(M.table(printer, rowstart='{', rowend='}')) + { 1, 2} + {-33, 4} + """ + # Handle zero dimensions: + if S.Zero in self.shape: + return '[]' + # Build table of string representations of the elements + res = [] + # Track per-column max lengths for pretty alignment + maxlen = [0] * self.cols + for i in range(self.rows): + res.append([]) + for j in range(self.cols): + s = printer._print(self[i, j]) + res[-1].append(s) + maxlen[j] = max(len(s), maxlen[j]) + # Patch strings together + align = { + 'left': 'ljust', + 'right': 'rjust', + 'center': 'center', + '<': 'ljust', + '>': 'rjust', + '^': 'center', + }[align] + for i, row in enumerate(res): + for j, elem in enumerate(row): + row[j] = getattr(elem, align)(maxlen[j]) + res[i] = rowstart + colsep.join(row) + rowend + return rowsep.join(res) + + def rank_decomposition(self, iszerofunc=_iszero, simplify=False): + return _rank_decomposition(self, iszerofunc=iszerofunc, + simplify=simplify) + + def cholesky(self, hermitian=True): + raise NotImplementedError('This function is implemented in DenseMatrix or SparseMatrix') + + def LDLdecomposition(self, hermitian=True): + raise NotImplementedError('This function is implemented in DenseMatrix or SparseMatrix') + + def LUdecomposition(self, iszerofunc=_iszero, simpfunc=None, + rankcheck=False): + return _LUdecomposition(self, iszerofunc=iszerofunc, simpfunc=simpfunc, + rankcheck=rankcheck) + + def LUdecomposition_Simple(self, iszerofunc=_iszero, simpfunc=None, + rankcheck=False): + return _LUdecomposition_Simple(self, iszerofunc=iszerofunc, + simpfunc=simpfunc, rankcheck=rankcheck) + + def LUdecompositionFF(self): + return _LUdecompositionFF(self) + + def singular_value_decomposition(self): + return _singular_value_decomposition(self) + + def QRdecomposition(self): + return _QRdecomposition(self) + + def upper_hessenberg_decomposition(self): + return _upper_hessenberg_decomposition(self) + + def diagonal_solve(self, rhs): + return _diagonal_solve(self, rhs) + + def lower_triangular_solve(self, rhs): + raise NotImplementedError('This function is implemented in DenseMatrix or SparseMatrix') + + def upper_triangular_solve(self, rhs): + raise NotImplementedError('This function is implemented in DenseMatrix or SparseMatrix') + + def cholesky_solve(self, rhs): + return _cholesky_solve(self, rhs) + + def LDLsolve(self, rhs): + return _LDLsolve(self, rhs) + + def LUsolve(self, rhs, iszerofunc=_iszero): + return _LUsolve(self, rhs, iszerofunc=iszerofunc) + + def QRsolve(self, b): + return _QRsolve(self, b) + + def gauss_jordan_solve(self, B, freevar=False): + return _gauss_jordan_solve(self, B, freevar=freevar) + + def pinv_solve(self, B, arbitrary_matrix=None): + return _pinv_solve(self, B, arbitrary_matrix=arbitrary_matrix) + + def cramer_solve(self, rhs, det_method="laplace"): + return _cramer_solve(self, rhs, det_method=det_method) + + def solve(self, rhs, method='GJ'): + return _solve(self, rhs, method=method) + + def solve_least_squares(self, rhs, method='CH'): + return _solve_least_squares(self, rhs, method=method) + + def pinv(self, method='RD'): + return _pinv(self, method=method) + + def inverse_ADJ(self, iszerofunc=_iszero): + return _inv_ADJ(self, iszerofunc=iszerofunc) + + def inverse_BLOCK(self, iszerofunc=_iszero): + return _inv_block(self, iszerofunc=iszerofunc) + + def inverse_GE(self, iszerofunc=_iszero): + return _inv_GE(self, iszerofunc=iszerofunc) + + def inverse_LU(self, iszerofunc=_iszero): + return _inv_LU(self, iszerofunc=iszerofunc) + + def inverse_CH(self, iszerofunc=_iszero): + return _inv_CH(self, iszerofunc=iszerofunc) + + def inverse_LDL(self, iszerofunc=_iszero): + return _inv_LDL(self, iszerofunc=iszerofunc) + + def inverse_QR(self, iszerofunc=_iszero): + return _inv_QR(self, iszerofunc=iszerofunc) + + def inv(self, method=None, iszerofunc=_iszero, try_block_diag=False): + return _inv(self, method=method, iszerofunc=iszerofunc, + try_block_diag=try_block_diag) + + def connected_components(self): + return _connected_components(self) + + def connected_components_decomposition(self): + return _connected_components_decomposition(self) + + def strongly_connected_components(self): + return _strongly_connected_components(self) + + def strongly_connected_components_decomposition(self, lower=True): + return _strongly_connected_components_decomposition(self, lower=lower) + + _sage_ = Basic._sage_ + + rank_decomposition.__doc__ = _rank_decomposition.__doc__ + cholesky.__doc__ = _cholesky.__doc__ + LDLdecomposition.__doc__ = _LDLdecomposition.__doc__ + LUdecomposition.__doc__ = _LUdecomposition.__doc__ + LUdecomposition_Simple.__doc__ = _LUdecomposition_Simple.__doc__ + LUdecompositionFF.__doc__ = _LUdecompositionFF.__doc__ + singular_value_decomposition.__doc__ = _singular_value_decomposition.__doc__ + QRdecomposition.__doc__ = _QRdecomposition.__doc__ + upper_hessenberg_decomposition.__doc__ = _upper_hessenberg_decomposition.__doc__ + + diagonal_solve.__doc__ = _diagonal_solve.__doc__ + lower_triangular_solve.__doc__ = _lower_triangular_solve.__doc__ + upper_triangular_solve.__doc__ = _upper_triangular_solve.__doc__ + cholesky_solve.__doc__ = _cholesky_solve.__doc__ + LDLsolve.__doc__ = _LDLsolve.__doc__ + LUsolve.__doc__ = _LUsolve.__doc__ + QRsolve.__doc__ = _QRsolve.__doc__ + gauss_jordan_solve.__doc__ = _gauss_jordan_solve.__doc__ + pinv_solve.__doc__ = _pinv_solve.__doc__ + cramer_solve.__doc__ = _cramer_solve.__doc__ + solve.__doc__ = _solve.__doc__ + solve_least_squares.__doc__ = _solve_least_squares.__doc__ + + pinv.__doc__ = _pinv.__doc__ + inverse_ADJ.__doc__ = _inv_ADJ.__doc__ + inverse_GE.__doc__ = _inv_GE.__doc__ + inverse_LU.__doc__ = _inv_LU.__doc__ + inverse_CH.__doc__ = _inv_CH.__doc__ + inverse_LDL.__doc__ = _inv_LDL.__doc__ + inverse_QR.__doc__ = _inv_QR.__doc__ + inverse_BLOCK.__doc__ = _inv_block.__doc__ + inv.__doc__ = _inv.__doc__ + + connected_components.__doc__ = _connected_components.__doc__ + connected_components_decomposition.__doc__ = \ + _connected_components_decomposition.__doc__ + strongly_connected_components.__doc__ = \ + _strongly_connected_components.__doc__ + strongly_connected_components_decomposition.__doc__ = \ + _strongly_connected_components_decomposition.__doc__ + + +def _convert_matrix(typ, mat): + """Convert mat to a Matrix of type typ.""" + from sympy.matrices.matrixbase import MatrixBase + if getattr(mat, "is_Matrix", False) and not isinstance(mat, MatrixBase): + # This is needed for interop between Matrix and the redundant matrix + # mixin types like _MinimalMatrix etc. If anyone should happen to be + # using those then this keeps them working. Really _MinimalMatrix etc + # should be deprecated and removed though. + return typ(*mat.shape, list(mat)) + else: + return typ(mat) + + +def _has_matrix_shape(other): + shape = getattr(other, 'shape', None) + if shape is None: + return False + return isinstance(shape, tuple) and len(shape) == 2 + + +def _has_rows_cols(other): + return hasattr(other, 'rows') and hasattr(other, 'cols') + + +def _coerce_operand(self, other): + """Convert other to a Matrix, or check for possible scalar.""" + + INVALID = None, 'invalid_type' + + # Disallow mixing Matrix and Array + if isinstance(other, NDimArray): + return INVALID + + is_Matrix = getattr(other, 'is_Matrix', None) + + # Return a Matrix as-is + if is_Matrix: + return other, 'is_matrix' + + # Try to convert numpy array, mpmath matrix etc. + if is_Matrix is None: + if _has_matrix_shape(other) or _has_rows_cols(other): + return _convert_matrix(type(self), other), 'is_matrix' + + # Could be a scalar but only if not iterable... + if not isinstance(other, Iterable): + return other, 'possible_scalar' + + return INVALID + + +def classof(A, B): + """ + Get the type of the result when combining matrices of different types. + + Currently the strategy is that immutability is contagious. + + Examples + ======== + + >>> from sympy import Matrix, ImmutableMatrix + >>> from sympy.matrices.matrixbase import classof + >>> M = Matrix([[1, 2], [3, 4]]) # a Mutable Matrix + >>> IM = ImmutableMatrix([[1, 2], [3, 4]]) + >>> classof(M, IM) + + """ + priority_A = getattr(A, '_class_priority', None) + priority_B = getattr(B, '_class_priority', None) + if None not in (priority_A, priority_B): + if A._class_priority > B._class_priority: + return A.__class__ + else: + return B.__class__ + + try: + import numpy + except ImportError: + pass + else: + if isinstance(A, numpy.ndarray): + return B.__class__ + if isinstance(B, numpy.ndarray): + return A.__class__ + + raise TypeError("Incompatible classes %s, %s" % (A.__class__, B.__class__)) + + +def _unify_with_other(self, other): + """Unify self and other into a single matrix type, or check for scalar.""" + other, T = _coerce_operand(self, other) + + if T == "is_matrix": + typ = classof(self, other) + if typ != self.__class__: + self = _convert_matrix(typ, self) + if typ != other.__class__: + other = _convert_matrix(typ, other) + + return self, other, T + + +def a2idx(j, n=None): + """Return integer after making positive and validating against n.""" + if not isinstance(j, int): + jindex = getattr(j, '__index__', None) + if jindex is not None: + j = jindex() + else: + raise IndexError("Invalid index a[%r]" % (j,)) + if n is not None: + if j < 0: + j += n + if not (j >= 0 and j < n): + raise IndexError("Index out of range: a[%s]" % (j,)) + return int(j) + + +class DeferredVector(Symbol, NotIterable): # type: ignore + """A vector whose components are deferred (e.g. for use with lambdify). + + Examples + ======== + + >>> from sympy import DeferredVector, lambdify + >>> X = DeferredVector( 'X' ) + >>> X + X + >>> expr = (X[0] + 2, X[2] + 3) + >>> func = lambdify( X, expr) + >>> func( [1, 2, 3] ) + (3, 6) + """ + + def __getitem__(self, i): + if i == -0: + i = 0 + if i < 0: + raise IndexError('DeferredVector index out of range') + component_name = '%s[%d]' % (self.name, i) + return Symbol(component_name) + + def __str__(self): + return sstr(self) + + def __repr__(self): + return "DeferredVector('%s')" % self.name diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/normalforms.py b/.venv/lib/python3.13/site-packages/sympy/matrices/normalforms.py new file mode 100644 index 0000000000000000000000000000000000000000..61a7d26bbdb8c8a3e8e3044d39b2403b2e14b7d5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/normalforms.py @@ -0,0 +1,156 @@ +'''Functions returning normal forms of matrices''' + +from sympy.polys.domains.integerring import ZZ +from sympy.polys.polytools import Poly +from sympy.polys.matrices import DomainMatrix +from sympy.polys.matrices.normalforms import ( + smith_normal_form as _snf, + is_smith_normal_form as _is_snf, + smith_normal_decomp as _snd, + invariant_factors as _invf, + hermite_normal_form as _hnf, + ) + + +def _to_domain(m, domain=None): + """Convert Matrix to DomainMatrix""" + # XXX: deprecated support for RawMatrix: + ring = getattr(m, "ring", None) + m = m.applyfunc(lambda e: e.as_expr() if isinstance(e, Poly) else e) + + dM = DomainMatrix.from_Matrix(m) + + domain = domain or ring + if domain is not None: + dM = dM.convert_to(domain) + return dM + + +def smith_normal_form(m, domain=None): + ''' + Return the Smith Normal Form of a matrix `m` over the ring `domain`. + This will only work if the ring is a principal ideal domain. + + Examples + ======== + + >>> from sympy import Matrix, ZZ + >>> from sympy.matrices.normalforms import smith_normal_form + >>> m = Matrix([[12, 6, 4], [3, 9, 6], [2, 16, 14]]) + >>> print(smith_normal_form(m, domain=ZZ)) + Matrix([[1, 0, 0], [0, 10, 0], [0, 0, 30]]) + + ''' + dM = _to_domain(m, domain) + return _snf(dM).to_Matrix() + + +def is_smith_normal_form(m, domain=None): + ''' + Checks that the matrix is in Smith Normal Form + ''' + dM = _to_domain(m, domain) + return _is_snf(dM) + + +def smith_normal_decomp(m, domain=None): + ''' + Return the Smith Normal Decomposition of a matrix `m` over the ring + `domain`. This will only work if the ring is a principal ideal domain. + + Examples + ======== + + >>> from sympy import Matrix, ZZ + >>> from sympy.matrices.normalforms import smith_normal_decomp + >>> m = Matrix([[12, 6, 4], [3, 9, 6], [2, 16, 14]]) + >>> a, s, t = smith_normal_decomp(m, domain=ZZ) + >>> assert a == s * m * t + ''' + dM = _to_domain(m, domain) + a, s, t = _snd(dM) + return a.to_Matrix(), s.to_Matrix(), t.to_Matrix() + + +def invariant_factors(m, domain=None): + ''' + Return the tuple of abelian invariants for a matrix `m` + (as in the Smith-Normal form) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Smith_normal_form#Algorithm + .. [2] https://web.archive.org/web/20200331143852/https://sierra.nmsu.edu/morandi/notes/SmithNormalForm.pdf + + ''' + dM = _to_domain(m, domain) + factors = _invf(dM) + factors = tuple(dM.domain.to_sympy(f) for f in factors) + # XXX: deprecated. + if hasattr(m, "ring"): + if m.ring.is_PolynomialRing: + K = m.ring + to_poly = lambda f: Poly(f, K.symbols, domain=K.domain) + factors = tuple(to_poly(f) for f in factors) + return factors + + +def hermite_normal_form(A, *, D=None, check_rank=False): + r""" + Compute the Hermite Normal Form of a Matrix *A* of integers. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.matrices.normalforms import hermite_normal_form + >>> m = Matrix([[12, 6, 4], [3, 9, 6], [2, 16, 14]]) + >>> print(hermite_normal_form(m)) + Matrix([[10, 0, 2], [0, 15, 3], [0, 0, 2]]) + + Parameters + ========== + + A : $m \times n$ ``Matrix`` of integers. + + D : int, optional + Let $W$ be the HNF of *A*. If known in advance, a positive integer *D* + being any multiple of $\det(W)$ may be provided. In this case, if *A* + also has rank $m$, then we may use an alternative algorithm that works + mod *D* in order to prevent coefficient explosion. + + check_rank : boolean, optional (default=False) + The basic assumption is that, if you pass a value for *D*, then + you already believe that *A* has rank $m$, so we do not waste time + checking it for you. If you do want this to be checked (and the + ordinary, non-modulo *D* algorithm to be used if the check fails), then + set *check_rank* to ``True``. + + Returns + ======= + + ``Matrix`` + The HNF of matrix *A*. + + Raises + ====== + + DMDomainError + If the domain of the matrix is not :ref:`ZZ`. + + DMShapeError + If the mod *D* algorithm is used but the matrix has more rows than + columns. + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory.* + (See Algorithms 2.4.5 and 2.4.8.) + + """ + # Accept any of Python int, SymPy Integer, and ZZ itself: + if D is not None and not ZZ.of_type(D): + D = ZZ(int(D)) + return _hnf(A._rep, D=D, check_rank=check_rank).to_Matrix() diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/reductions.py b/.venv/lib/python3.13/site-packages/sympy/matrices/reductions.py new file mode 100644 index 0000000000000000000000000000000000000000..aace8c0336358e1869a34d99f79390cc0c0163fe --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/reductions.py @@ -0,0 +1,387 @@ +from types import FunctionType + +from sympy.polys.polyerrors import CoercionFailed +from sympy.polys.domains import ZZ, QQ + +from .utilities import _get_intermediate_simp, _iszero, _dotprodsimp, _simplify +from .determinant import _find_reasonable_pivot + + +def _row_reduce_list(mat, rows, cols, one, iszerofunc, simpfunc, + normalize_last=True, normalize=True, zero_above=True): + """Row reduce a flat list representation of a matrix and return a tuple + (rref_matrix, pivot_cols, swaps) where ``rref_matrix`` is a flat list, + ``pivot_cols`` are the pivot columns and ``swaps`` are any row swaps that + were used in the process of row reduction. + + Parameters + ========== + + mat : list + list of matrix elements, must be ``rows`` * ``cols`` in length + + rows, cols : integer + number of rows and columns in flat list representation + + one : SymPy object + represents the value one, from ``Matrix.one`` + + iszerofunc : determines if an entry can be used as a pivot + + simpfunc : used to simplify elements and test if they are + zero if ``iszerofunc`` returns `None` + + normalize_last : indicates where all row reduction should + happen in a fraction-free manner and then the rows are + normalized (so that the pivots are 1), or whether + rows should be normalized along the way (like the naive + row reduction algorithm) + + normalize : whether pivot rows should be normalized so that + the pivot value is 1 + + zero_above : whether entries above the pivot should be zeroed. + If ``zero_above=False``, an echelon matrix will be returned. + """ + + def get_col(i): + return mat[i::cols] + + def row_swap(i, j): + mat[i*cols:(i + 1)*cols], mat[j*cols:(j + 1)*cols] = \ + mat[j*cols:(j + 1)*cols], mat[i*cols:(i + 1)*cols] + + def cross_cancel(a, i, b, j): + """Does the row op row[i] = a*row[i] - b*row[j]""" + q = (j - i)*cols + for p in range(i*cols, (i + 1)*cols): + mat[p] = isimp(a*mat[p] - b*mat[p + q]) + + isimp = _get_intermediate_simp(_dotprodsimp) + piv_row, piv_col = 0, 0 + pivot_cols = [] + swaps = [] + + # use a fraction free method to zero above and below each pivot + while piv_col < cols and piv_row < rows: + pivot_offset, pivot_val, \ + assumed_nonzero, newly_determined = _find_reasonable_pivot( + get_col(piv_col)[piv_row:], iszerofunc, simpfunc) + + # _find_reasonable_pivot may have simplified some things + # in the process. Let's not let them go to waste + for (offset, val) in newly_determined: + offset += piv_row + mat[offset*cols + piv_col] = val + + if pivot_offset is None: + piv_col += 1 + continue + + pivot_cols.append(piv_col) + if pivot_offset != 0: + row_swap(piv_row, pivot_offset + piv_row) + swaps.append((piv_row, pivot_offset + piv_row)) + + # if we aren't normalizing last, we normalize + # before we zero the other rows + if normalize_last is False: + i, j = piv_row, piv_col + mat[i*cols + j] = one + for p in range(i*cols + j + 1, (i + 1)*cols): + mat[p] = isimp(mat[p] / pivot_val) + # after normalizing, the pivot value is 1 + pivot_val = one + + # zero above and below the pivot + for row in range(rows): + # don't zero our current row + if row == piv_row: + continue + # don't zero above the pivot unless we're told. + if zero_above is False and row < piv_row: + continue + # if we're already a zero, don't do anything + val = mat[row*cols + piv_col] + if iszerofunc(val): + continue + + cross_cancel(pivot_val, row, val, piv_row) + piv_row += 1 + + # normalize each row + if normalize_last is True and normalize is True: + for piv_i, piv_j in enumerate(pivot_cols): + pivot_val = mat[piv_i*cols + piv_j] + mat[piv_i*cols + piv_j] = one + for p in range(piv_i*cols + piv_j + 1, (piv_i + 1)*cols): + mat[p] = isimp(mat[p] / pivot_val) + + return mat, tuple(pivot_cols), tuple(swaps) + + +# This functions is a candidate for caching if it gets implemented for matrices. +def _row_reduce(M, iszerofunc, simpfunc, normalize_last=True, + normalize=True, zero_above=True): + + mat, pivot_cols, swaps = _row_reduce_list(list(M), M.rows, M.cols, M.one, + iszerofunc, simpfunc, normalize_last=normalize_last, + normalize=normalize, zero_above=zero_above) + + return M._new(M.rows, M.cols, mat), pivot_cols, swaps + + +def _is_echelon(M, iszerofunc=_iszero): + """Returns `True` if the matrix is in echelon form. That is, all rows of + zeros are at the bottom, and below each leading non-zero in a row are + exclusively zeros.""" + + if M.rows <= 0 or M.cols <= 0: + return True + + zeros_below = all(iszerofunc(t) for t in M[1:, 0]) + + if iszerofunc(M[0, 0]): + return zeros_below and _is_echelon(M[:, 1:], iszerofunc) + + return zeros_below and _is_echelon(M[1:, 1:], iszerofunc) + + +def _echelon_form(M, iszerofunc=_iszero, simplify=False, with_pivots=False): + """Returns a matrix row-equivalent to ``M`` that is in echelon form. Note + that echelon form of a matrix is *not* unique, however, properties like the + row space and the null space are preserved. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2], [3, 4]]) + >>> M.echelon_form() + Matrix([ + [1, 2], + [0, -2]]) + """ + + simpfunc = simplify if isinstance(simplify, FunctionType) else _simplify + + mat, pivots, _ = _row_reduce(M, iszerofunc, simpfunc, + normalize_last=True, normalize=False, zero_above=False) + + if with_pivots: + return mat, pivots + + return mat + + +# This functions is a candidate for caching if it gets implemented for matrices. +def _rank(M, iszerofunc=_iszero, simplify=False): + """Returns the rank of a matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x + >>> m = Matrix([[1, 2], [x, 1 - 1/x]]) + >>> m.rank() + 2 + >>> n = Matrix(3, 3, range(1, 10)) + >>> n.rank() + 2 + """ + + def _permute_complexity_right(M, iszerofunc): + """Permute columns with complicated elements as + far right as they can go. Since the ``sympy`` row reduction + algorithms start on the left, having complexity right-shifted + speeds things up. + + Returns a tuple (mat, perm) where perm is a permutation + of the columns to perform to shift the complex columns right, and mat + is the permuted matrix.""" + + def complexity(i): + # the complexity of a column will be judged by how many + # element's zero-ness cannot be determined + return sum(1 if iszerofunc(e) is None else 0 for e in M[:, i]) + + complex = [(complexity(i), i) for i in range(M.cols)] + perm = [j for (i, j) in sorted(complex)] + + return (M.permute(perm, orientation='cols'), perm) + + simpfunc = simplify if isinstance(simplify, FunctionType) else _simplify + + # for small matrices, we compute the rank explicitly + # if is_zero on elements doesn't answer the question + # for small matrices, we fall back to the full routine. + if M.rows <= 0 or M.cols <= 0: + return 0 + + if M.rows <= 1 or M.cols <= 1: + zeros = [iszerofunc(x) for x in M] + + if False in zeros: + return 1 + + if M.rows == 2 and M.cols == 2: + zeros = [iszerofunc(x) for x in M] + + if False not in zeros and None not in zeros: + return 0 + + d = M.det() + + if iszerofunc(d) and False in zeros: + return 1 + if iszerofunc(d) is False: + return 2 + + mat, _ = _permute_complexity_right(M, iszerofunc=iszerofunc) + _, pivots, _ = _row_reduce(mat, iszerofunc, simpfunc, normalize_last=True, + normalize=False, zero_above=False) + + return len(pivots) + + +def _to_DM_ZZ_QQ(M): + # We have to test for _rep here because there are tests that otherwise fail + # with e.g. "AttributeError: 'SubspaceOnlyMatrix' object has no attribute + # '_rep'." There is almost certainly no value in such tests. The + # presumption seems to be that someone could create a new class by + # inheriting some of the Matrix classes and not the full set that is used + # by the standard Matrix class but if anyone tried that it would fail in + # many ways. + if not hasattr(M, '_rep'): + return None + + rep = M._rep + K = rep.domain + + if K.is_ZZ: + return rep + elif K.is_QQ: + try: + return rep.convert_to(ZZ) + except CoercionFailed: + return rep + else: + if not all(e.is_Rational for e in M): + return None + try: + return rep.convert_to(ZZ) + except CoercionFailed: + return rep.convert_to(QQ) + + +def _rref_dm(dM): + """Compute the reduced row echelon form of a DomainMatrix.""" + K = dM.domain + + if K.is_ZZ: + dM_rref, den, pivots = dM.rref_den(keep_domain=False) + dM_rref = dM_rref.to_field() / den + elif K.is_QQ: + dM_rref, pivots = dM.rref() + else: + assert False # pragma: no cover + + M_rref = dM_rref.to_Matrix() + + return M_rref, pivots + + +def _rref(M, iszerofunc=_iszero, simplify=False, pivots=True, + normalize_last=True): + """Return reduced row-echelon form of matrix and indices + of pivot vars. + + Parameters + ========== + + iszerofunc : Function + A function used for detecting whether an element can + act as a pivot. ``lambda x: x.is_zero`` is used by default. + + simplify : Function + A function used to simplify elements when looking for a pivot. + By default SymPy's ``simplify`` is used. + + pivots : True or False + If ``True``, a tuple containing the row-reduced matrix and a tuple + of pivot columns is returned. If ``False`` just the row-reduced + matrix is returned. + + normalize_last : True or False + If ``True``, no pivots are normalized to `1` until after all + entries above and below each pivot are zeroed. This means the row + reduction algorithm is fraction free until the very last step. + If ``False``, the naive row reduction procedure is used where + each pivot is normalized to be `1` before row operations are + used to zero above and below the pivot. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x + >>> m = Matrix([[1, 2], [x, 1 - 1/x]]) + >>> m.rref() + (Matrix([ + [1, 0], + [0, 1]]), (0, 1)) + >>> rref_matrix, rref_pivots = m.rref() + >>> rref_matrix + Matrix([ + [1, 0], + [0, 1]]) + >>> rref_pivots + (0, 1) + + ``iszerofunc`` can correct rounding errors in matrices with float + values. In the following example, calling ``rref()`` leads to + floating point errors, incorrectly row reducing the matrix. + ``iszerofunc= lambda x: abs(x) < 1e-9`` sets sufficiently small numbers + to zero, avoiding this error. + + >>> m = Matrix([[0.9, -0.1, -0.2, 0], [-0.8, 0.9, -0.4, 0], [-0.1, -0.8, 0.6, 0]]) + >>> m.rref() + (Matrix([ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0]]), (0, 1, 2)) + >>> m.rref(iszerofunc=lambda x:abs(x)<1e-9) + (Matrix([ + [1, 0, -0.301369863013699, 0], + [0, 1, -0.712328767123288, 0], + [0, 0, 0, 0]]), (0, 1)) + + Notes + ===== + + The default value of ``normalize_last=True`` can provide significant + speedup to row reduction, especially on matrices with symbols. However, + if you depend on the form row reduction algorithm leaves entries + of the matrix, set ``normalize_last=False`` + """ + # Try to use DomainMatrix for ZZ or QQ + dM = _to_DM_ZZ_QQ(M) + + if dM is not None: + # Use DomainMatrix for ZZ or QQ + mat, pivot_cols = _rref_dm(dM) + else: + # Use the generic Matrix routine. + if isinstance(simplify, FunctionType): + simpfunc = simplify + else: + simpfunc = _simplify + + mat, pivot_cols, _ = _row_reduce(M, iszerofunc, simpfunc, + normalize_last, normalize=True, zero_above=True) + + if pivots: + return mat, pivot_cols + else: + return mat diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/repmatrix.py b/.venv/lib/python3.13/site-packages/sympy/matrices/repmatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..57f32fae34786f68f579fad7de38c9e3cf43e131 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/repmatrix.py @@ -0,0 +1,1034 @@ +from collections import defaultdict + +from operator import index as index_ + +from sympy.core.expr import Expr +from sympy.core.kind import Kind, NumberKind, UndefinedKind +from sympy.core.numbers import Integer, Rational +from sympy.core.sympify import _sympify, SympifyError +from sympy.core.singleton import S +from sympy.polys.domains import ZZ, QQ, GF, EXRAW +from sympy.polys.matrices import DomainMatrix +from sympy.polys.matrices.exceptions import DMNonInvertibleMatrixError +from sympy.polys.polyerrors import CoercionFailed, NotInvertible +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import is_sequence +from sympy.utilities.misc import filldedent, as_int + +from .exceptions import ShapeError, NonSquareMatrixError, NonInvertibleMatrixError +from .matrixbase import classof, MatrixBase +from .kind import MatrixKind + + +class RepMatrix(MatrixBase): + """Matrix implementation based on DomainMatrix as an internal representation. + + The RepMatrix class is a superclass for Matrix, ImmutableMatrix, + SparseMatrix and ImmutableSparseMatrix which are the main usable matrix + classes in SymPy. Most methods on this class are simply forwarded to + DomainMatrix. + """ + + # + # MatrixBase is the common superclass for all of the usable explicit matrix + # classes in SymPy. The idea is that MatrixBase is an abstract class though + # and that subclasses will implement the lower-level methods. + # + # RepMatrix is a subclass of MatrixBase that uses DomainMatrix as an + # internal representation and delegates lower-level methods to + # DomainMatrix. All of SymPy's standard explicit matrix classes subclass + # RepMatrix and so use DomainMatrix internally. + # + # A RepMatrix uses an internal DomainMatrix with the domain set to ZZ, QQ + # or EXRAW. The EXRAW domain is equivalent to the previous implementation + # of Matrix that used Expr for the elements. The ZZ and QQ domains are used + # when applicable just because they are compatible with the previous + # implementation but are much more efficient. Other domains such as QQ[x] + # are not used because they differ from Expr in some way (e.g. automatic + # expansion of powers and products). + # + + _rep: DomainMatrix + + def __eq__(self, other): + # Skip sympify for mutable matrices... + if not isinstance(other, RepMatrix): + try: + other = _sympify(other) + except SympifyError: + return NotImplemented + if not isinstance(other, RepMatrix): + return NotImplemented + + return self._rep.unify_eq(other._rep) + + def to_DM(self, domain=None, **kwargs): + """Convert to a :class:`~.DomainMatrix`. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2], [3, 4]]) + >>> M.to_DM() + DomainMatrix({0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}}, (2, 2), ZZ) + + The :meth:`DomainMatrix.to_Matrix` method can be used to convert back: + + >>> M.to_DM().to_Matrix() == M + True + + The domain can be given explicitly or otherwise it will be chosen by + :func:`construct_domain`. Any keyword arguments (besides ``domain``) + are passed to :func:`construct_domain`: + + >>> from sympy import QQ, symbols + >>> x = symbols('x') + >>> M = Matrix([[x, 1], [1, x]]) + >>> M + Matrix([ + [x, 1], + [1, x]]) + >>> M.to_DM().domain + ZZ[x] + >>> M.to_DM(field=True).domain + ZZ(x) + >>> M.to_DM(domain=QQ[x]).domain + QQ[x] + + See Also + ======== + + DomainMatrix + DomainMatrix.to_Matrix + DomainMatrix.convert_to + DomainMatrix.choose_domain + construct_domain + """ + if domain is not None: + if kwargs: + raise TypeError("Options cannot be used with domain parameter") + return self._rep.convert_to(domain) + + rep = self._rep + dom = rep.domain + + # If the internal DomainMatrix is already ZZ or QQ then we can maybe + # bypass calling construct_domain or performing any conversions. Some + # kwargs might affect this though e.g. field=True (not sure if there + # are others). + if not kwargs: + if dom.is_ZZ: + return rep.copy() + elif dom.is_QQ: + # All elements might be integers + try: + return rep.convert_to(ZZ) + except CoercionFailed: + pass + return rep.copy() + + # Let construct_domain choose a domain + rep_dom = rep.choose_domain(**kwargs) + + # XXX: There should be an option to construct_domain to choose EXRAW + # instead of EX. At least converting to EX does not initially trigger + # EX.simplify which is what we want here but should probably be + # considered a bug in EX. Perhaps also this could be handled in + # DomainMatrix.choose_domain rather than here... + if rep_dom.domain.is_EX: + rep_dom = rep_dom.convert_to(EXRAW) + + return rep_dom + + @classmethod + def _unify_element_sympy(cls, rep, element): + domain = rep.domain + element = _sympify(element) + + if domain != EXRAW: + # The domain can only be ZZ, QQ or EXRAW + if element.is_Integer: + new_domain = domain + elif element.is_Rational: + new_domain = QQ + else: + new_domain = EXRAW + + # XXX: This converts the domain for all elements in the matrix + # which can be slow. This happens e.g. if __setitem__ changes one + # element to something that does not fit in the domain + if new_domain != domain: + rep = rep.convert_to(new_domain) + domain = new_domain + + if domain != EXRAW: + element = new_domain.from_sympy(element) + + if domain == EXRAW and not isinstance(element, Expr): + sympy_deprecation_warning( + """ + non-Expr objects in a Matrix is deprecated. Matrix represents + a mathematical matrix. To represent a container of non-numeric + entities, Use a list of lists, TableForm, NumPy array, or some + other data structure instead. + """, + deprecated_since_version="1.9", + active_deprecations_target="deprecated-non-expr-in-matrix", + stacklevel=4, + ) + + return rep, element + + @classmethod + def _dod_to_DomainMatrix(cls, rows, cols, dod, types): + + if not all(issubclass(typ, Expr) for typ in types): + sympy_deprecation_warning( + """ + non-Expr objects in a Matrix is deprecated. Matrix represents + a mathematical matrix. To represent a container of non-numeric + entities, Use a list of lists, TableForm, NumPy array, or some + other data structure instead. + """, + deprecated_since_version="1.9", + active_deprecations_target="deprecated-non-expr-in-matrix", + stacklevel=6, + ) + + rep = DomainMatrix(dod, (rows, cols), EXRAW) + + if all(issubclass(typ, Rational) for typ in types): + if all(issubclass(typ, Integer) for typ in types): + rep = rep.convert_to(ZZ) + else: + rep = rep.convert_to(QQ) + + return rep + + @classmethod + def _flat_list_to_DomainMatrix(cls, rows, cols, flat_list): + + elements_dod = defaultdict(dict) + for n, element in enumerate(flat_list): + if element != 0: + i, j = divmod(n, cols) + elements_dod[i][j] = element + + types = set(map(type, flat_list)) + + rep = cls._dod_to_DomainMatrix(rows, cols, elements_dod, types) + return rep + + @classmethod + def _smat_to_DomainMatrix(cls, rows, cols, smat): + + elements_dod = defaultdict(dict) + for (i, j), element in smat.items(): + if element != 0: + elements_dod[i][j] = element + + types = set(map(type, smat.values())) + + rep = cls._dod_to_DomainMatrix(rows, cols, elements_dod, types) + return rep + + def flat(self): + return self._rep.to_sympy().to_list_flat() + + def _eval_tolist(self): + return self._rep.to_sympy().to_list() + + def _eval_todok(self): + return self._rep.to_sympy().to_dok() + + @classmethod + def _eval_from_dok(cls, rows, cols, dok): + return cls._fromrep(cls._smat_to_DomainMatrix(rows, cols, dok)) + + def _eval_values(self): + return list(self._eval_iter_values()) + + def _eval_iter_values(self): + rep = self._rep + K = rep.domain + values = rep.iter_values() + if not K.is_EXRAW: + values = map(K.to_sympy, values) + return values + + def _eval_iter_items(self): + rep = self._rep + K = rep.domain + to_sympy = K.to_sympy + items = rep.iter_items() + if not K.is_EXRAW: + items = ((i, to_sympy(v)) for i, v in items) + return items + + def copy(self): + return self._fromrep(self._rep.copy()) + + @property + def kind(self) -> MatrixKind: + domain = self._rep.domain + element_kind: Kind + if domain in (ZZ, QQ): + element_kind = NumberKind + elif domain == EXRAW: + kinds = {e.kind for e in self.values()} + if len(kinds) == 1: + [element_kind] = kinds + else: + element_kind = UndefinedKind + else: # pragma: no cover + raise RuntimeError("Domain should only be ZZ, QQ or EXRAW") + return MatrixKind(element_kind) + + def _eval_has(self, *patterns): + # if the matrix has any zeros, see if S.Zero + # has the pattern. If _smat is full length, + # the matrix has no zeros. + zhas = False + dok = self.todok() + if len(dok) != self.rows*self.cols: + zhas = S.Zero.has(*patterns) + return zhas or any(value.has(*patterns) for value in dok.values()) + + def _eval_is_Identity(self): + if not all(self[i, i] == 1 for i in range(self.rows)): + return False + return len(self.todok()) == self.rows + + def _eval_is_symmetric(self, simpfunc): + diff = (self - self.T).applyfunc(simpfunc) + return len(diff.values()) == 0 + + def _eval_transpose(self): + """Returns the transposed SparseMatrix of this SparseMatrix. + + Examples + ======== + + >>> from sympy import SparseMatrix + >>> a = SparseMatrix(((1, 2), (3, 4))) + >>> a + Matrix([ + [1, 2], + [3, 4]]) + >>> a.T + Matrix([ + [1, 3], + [2, 4]]) + """ + return self._fromrep(self._rep.transpose()) + + def _eval_col_join(self, other): + return self._fromrep(self._rep.vstack(other._rep)) + + def _eval_row_join(self, other): + return self._fromrep(self._rep.hstack(other._rep)) + + def _eval_extract(self, rowsList, colsList): + return self._fromrep(self._rep.extract(rowsList, colsList)) + + def __getitem__(self, key): + return _getitem_RepMatrix(self, key) + + @classmethod + def _eval_zeros(cls, rows, cols): + rep = DomainMatrix.zeros((rows, cols), ZZ) + return cls._fromrep(rep) + + @classmethod + def _eval_eye(cls, rows, cols): + rep = DomainMatrix.eye((rows, cols), ZZ) + return cls._fromrep(rep) + + def _eval_add(self, other): + return classof(self, other)._fromrep(self._rep + other._rep) + + def _eval_matrix_mul(self, other): + return classof(self, other)._fromrep(self._rep * other._rep) + + def _eval_matrix_mul_elementwise(self, other): + selfrep, otherrep = self._rep.unify(other._rep) + newrep = selfrep.mul_elementwise(otherrep) + return classof(self, other)._fromrep(newrep) + + def _eval_scalar_mul(self, other): + rep, other = self._unify_element_sympy(self._rep, other) + return self._fromrep(rep.scalarmul(other)) + + def _eval_scalar_rmul(self, other): + rep, other = self._unify_element_sympy(self._rep, other) + return self._fromrep(rep.rscalarmul(other)) + + def _eval_Abs(self): + return self._fromrep(self._rep.applyfunc(abs)) + + def _eval_conjugate(self): + rep = self._rep + domain = rep.domain + if domain in (ZZ, QQ): + return self.copy() + else: + return self._fromrep(rep.applyfunc(lambda e: e.conjugate())) + + def equals(self, other, failing_expression=False): + """Applies ``equals`` to corresponding elements of the matrices, + trying to prove that the elements are equivalent, returning True + if they are, False if any pair is not, and None (or the first + failing expression if failing_expression is True) if it cannot + be decided if the expressions are equivalent or not. This is, in + general, an expensive operation. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x + >>> A = Matrix([x*(x - 1), 0]) + >>> B = Matrix([x**2 - x, 0]) + >>> A == B + False + >>> A.simplify() == B.simplify() + True + >>> A.equals(B) + True + >>> A.equals(2) + False + + See Also + ======== + sympy.core.expr.Expr.equals + """ + if self.shape != getattr(other, 'shape', None): + return False + + rv = True + for i in range(self.rows): + for j in range(self.cols): + ans = self[i, j].equals(other[i, j], failing_expression) + if ans is False: + return False + elif ans is not True and rv is True: + rv = ans + return rv + + def inv_mod(M, m): + r""" + Returns the inverse of the integer matrix ``M`` modulo ``m``. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix(2, 2, [1, 2, 3, 4]) + >>> A.inv_mod(5) + Matrix([ + [3, 1], + [4, 2]]) + >>> A.inv_mod(3) + Matrix([ + [1, 1], + [0, 1]]) + + """ + + if not M.is_square: + raise NonSquareMatrixError() + + try: + m = as_int(m) + except ValueError: + raise TypeError("inv_mod: modulus m must be an integer") + + K = GF(m, symmetric=False) + + try: + dM = M.to_DM(K) + except CoercionFailed: + raise ValueError("inv_mod: matrix entries must be integers") + + if K.is_Field: + try: + dMi = dM.inv() + except DMNonInvertibleMatrixError as exc: + msg = f'Matrix is not invertible (mod {m})' + raise NonInvertibleMatrixError(msg) from exc + else: + dMadj, det = dM.adj_det() + try: + detinv = 1 / det + except NotInvertible: + msg = f'Matrix is not invertible (mod {m})' + raise NonInvertibleMatrixError(msg) + dMi = dMadj * detinv + + return dMi.to_Matrix() + + def lll(self, delta=0.75): + """LLL-reduced basis for the rowspace of a matrix of integers. + + Performs the Lenstra–Lenstra–Lovász (LLL) basis reduction algorithm. + + The implementation is provided by :class:`~DomainMatrix`. See + :meth:`~DomainMatrix.lll` for more details. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 0, 0, 0, -20160], + ... [0, 1, 0, 0, 33768], + ... [0, 0, 1, 0, 39578], + ... [0, 0, 0, 1, 47757]]) + >>> M.lll() + Matrix([ + [ 10, -3, -2, 8, -4], + [ 3, -9, 8, 1, -11], + [ -3, 13, -9, -3, -9], + [-12, -7, -11, 9, -1]]) + + See Also + ======== + + lll_transform + sympy.polys.matrices.domainmatrix.DomainMatrix.lll + """ + delta = QQ.from_sympy(_sympify(delta)) + dM = self._rep.convert_to(ZZ) + basis = dM.lll(delta=delta) + return self._fromrep(basis) + + def lll_transform(self, delta=0.75): + """LLL-reduced basis and transformation matrix. + + Performs the Lenstra–Lenstra–Lovász (LLL) basis reduction algorithm. + + The implementation is provided by :class:`~DomainMatrix`. See + :meth:`~DomainMatrix.lll_transform` for more details. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 0, 0, 0, -20160], + ... [0, 1, 0, 0, 33768], + ... [0, 0, 1, 0, 39578], + ... [0, 0, 0, 1, 47757]]) + >>> B, T = M.lll_transform() + >>> B + Matrix([ + [ 10, -3, -2, 8, -4], + [ 3, -9, 8, 1, -11], + [ -3, 13, -9, -3, -9], + [-12, -7, -11, 9, -1]]) + >>> T + Matrix([ + [ 10, -3, -2, 8], + [ 3, -9, 8, 1], + [ -3, 13, -9, -3], + [-12, -7, -11, 9]]) + + The transformation matrix maps the original basis to the LLL-reduced + basis: + + >>> T * M == B + True + + See Also + ======== + + lll + sympy.polys.matrices.domainmatrix.DomainMatrix.lll_transform + """ + delta = QQ.from_sympy(_sympify(delta)) + dM = self._rep.convert_to(ZZ) + basis, transform = dM.lll_transform(delta=delta) + B = self._fromrep(basis) + T = self._fromrep(transform) + return B, T + + +class MutableRepMatrix(RepMatrix): + """Mutable matrix based on DomainMatrix as the internal representation""" + + # + # MutableRepMatrix is a subclass of RepMatrix that adds/overrides methods + # to make the instances mutable. MutableRepMatrix is a superclass for both + # MutableDenseMatrix and MutableSparseMatrix. + # + + is_zero = False + + def __new__(cls, *args, **kwargs): + return cls._new(*args, **kwargs) + + @classmethod + def _new(cls, *args, copy=True, **kwargs): + if copy is False: + # The input was rows, cols, [list]. + # It should be used directly without creating a copy. + if len(args) != 3: + raise TypeError("'copy=False' requires a matrix be initialized as rows,cols,[list]") + rows, cols, flat_list = args + else: + rows, cols, flat_list = cls._handle_creation_inputs(*args, **kwargs) + flat_list = list(flat_list) # create a shallow copy + + rep = cls._flat_list_to_DomainMatrix(rows, cols, flat_list) + + return cls._fromrep(rep) + + @classmethod + def _fromrep(cls, rep): + obj = super().__new__(cls) + obj.rows, obj.cols = rep.shape + obj._rep = rep + return obj + + def copy(self): + return self._fromrep(self._rep.copy()) + + def as_mutable(self): + return self.copy() + + def __setitem__(self, key, value): + """ + + Examples + ======== + + >>> from sympy import Matrix, I, zeros, ones + >>> m = Matrix(((1, 2+I), (3, 4))) + >>> m + Matrix([ + [1, 2 + I], + [3, 4]]) + >>> m[1, 0] = 9 + >>> m + Matrix([ + [1, 2 + I], + [9, 4]]) + >>> m[1, 0] = [[0, 1]] + + To replace row r you assign to position r*m where m + is the number of columns: + + >>> M = zeros(4) + >>> m = M.cols + >>> M[3*m] = ones(1, m)*2; M + Matrix([ + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [2, 2, 2, 2]]) + + And to replace column c you can assign to position c: + + >>> M[2] = ones(m, 1)*4; M + Matrix([ + [0, 0, 4, 0], + [0, 0, 4, 0], + [0, 0, 4, 0], + [2, 2, 4, 2]]) + """ + rv = self._setitem(key, value) + if rv is not None: + i, j, value = rv + self._rep, value = self._unify_element_sympy(self._rep, value) + self._rep.rep.setitem(i, j, value) + + def _eval_col_del(self, col): + self._rep = DomainMatrix.hstack(self._rep[:,:col], self._rep[:,col+1:]) + self.cols -= 1 + + def _eval_row_del(self, row): + self._rep = DomainMatrix.vstack(self._rep[:row,:], self._rep[row+1:, :]) + self.rows -= 1 + + def _eval_col_insert(self, col, other): + other = self._new(other) + return self.hstack(self[:,:col], other, self[:,col:]) + + def _eval_row_insert(self, row, other): + other = self._new(other) + return self.vstack(self[:row,:], other, self[row:,:]) + + def col_op(self, j, f): + """In-place operation on col j using two-arg functor whose args are + interpreted as (self[i, j], i). + + Examples + ======== + + >>> from sympy import eye + >>> M = eye(3) + >>> M.col_op(1, lambda v, i: v + 2*M[i, 0]); M + Matrix([ + [1, 2, 0], + [0, 1, 0], + [0, 0, 1]]) + + See Also + ======== + col + row_op + """ + for i in range(self.rows): + self[i, j] = f(self[i, j], i) + + def col_swap(self, i, j): + """Swap the two given columns of the matrix in-place. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 0], [1, 0]]) + >>> M + Matrix([ + [1, 0], + [1, 0]]) + >>> M.col_swap(0, 1) + >>> M + Matrix([ + [0, 1], + [0, 1]]) + + See Also + ======== + + col + row_swap + """ + for k in range(0, self.rows): + self[k, i], self[k, j] = self[k, j], self[k, i] + + def row_op(self, i, f): + """In-place operation on row ``i`` using two-arg functor whose args are + interpreted as ``(self[i, j], j)``. + + Examples + ======== + + >>> from sympy import eye + >>> M = eye(3) + >>> M.row_op(1, lambda v, j: v + 2*M[0, j]); M + Matrix([ + [1, 0, 0], + [2, 1, 0], + [0, 0, 1]]) + + See Also + ======== + row + zip_row_op + col_op + + """ + for j in range(self.cols): + self[i, j] = f(self[i, j], j) + + #The next three methods give direct support for the most common row operations inplace. + def row_mult(self,i,factor): + """Multiply the given row by the given factor in-place. + + Examples + ======== + + >>> from sympy import eye + >>> M = eye(3) + >>> M.row_mult(1,7); M + Matrix([ + [1, 0, 0], + [0, 7, 0], + [0, 0, 1]]) + + """ + for j in range(self.cols): + self[i,j] *= factor + + def row_add(self,s,t,k): + """Add k times row s (source) to row t (target) in place. + + Examples + ======== + + >>> from sympy import eye + >>> M = eye(3) + >>> M.row_add(0, 2,3); M + Matrix([ + [1, 0, 0], + [0, 1, 0], + [3, 0, 1]]) + """ + + for j in range(self.cols): + self[t,j] += k*self[s,j] + + def row_swap(self, i, j): + """Swap the two given rows of the matrix in-place. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[0, 1], [1, 0]]) + >>> M + Matrix([ + [0, 1], + [1, 0]]) + >>> M.row_swap(0, 1) + >>> M + Matrix([ + [1, 0], + [0, 1]]) + + See Also + ======== + + row + col_swap + """ + for k in range(0, self.cols): + self[i, k], self[j, k] = self[j, k], self[i, k] + + def zip_row_op(self, i, k, f): + """In-place operation on row ``i`` using two-arg functor whose args are + interpreted as ``(self[i, j], self[k, j])``. + + Examples + ======== + + >>> from sympy import eye + >>> M = eye(3) + >>> M.zip_row_op(1, 0, lambda v, u: v + 2*u); M + Matrix([ + [1, 0, 0], + [2, 1, 0], + [0, 0, 1]]) + + See Also + ======== + row + row_op + col_op + + """ + for j in range(self.cols): + self[i, j] = f(self[i, j], self[k, j]) + + def copyin_list(self, key, value): + """Copy in elements from a list. + + Parameters + ========== + + key : slice + The section of this matrix to replace. + value : iterable + The iterable to copy values from. + + Examples + ======== + + >>> from sympy import eye + >>> I = eye(3) + >>> I[:2, 0] = [1, 2] # col + >>> I + Matrix([ + [1, 0, 0], + [2, 1, 0], + [0, 0, 1]]) + >>> I[1, :2] = [[3, 4]] + >>> I + Matrix([ + [1, 0, 0], + [3, 4, 0], + [0, 0, 1]]) + + See Also + ======== + + copyin_matrix + """ + if not is_sequence(value): + raise TypeError("`value` must be an ordered iterable, not %s." % type(value)) + return self.copyin_matrix(key, type(self)(value)) + + def copyin_matrix(self, key, value): + """Copy in values from a matrix into the given bounds. + + Parameters + ========== + + key : slice + The section of this matrix to replace. + value : Matrix + The matrix to copy values from. + + Examples + ======== + + >>> from sympy import Matrix, eye + >>> M = Matrix([[0, 1], [2, 3], [4, 5]]) + >>> I = eye(3) + >>> I[:3, :2] = M + >>> I + Matrix([ + [0, 1, 0], + [2, 3, 0], + [4, 5, 1]]) + >>> I[0, 1] = M + >>> I + Matrix([ + [0, 0, 1], + [2, 2, 3], + [4, 4, 5]]) + + See Also + ======== + + copyin_list + """ + rlo, rhi, clo, chi = self.key2bounds(key) + shape = value.shape + dr, dc = rhi - rlo, chi - clo + if shape != (dr, dc): + raise ShapeError(filldedent("The Matrix `value` doesn't have the " + "same dimensions " + "as the in sub-Matrix given by `key`.")) + + for i in range(value.rows): + for j in range(value.cols): + self[i + rlo, j + clo] = value[i, j] + + def fill(self, value): + """Fill self with the given value. + + Notes + ===== + + Unless many values are going to be deleted (i.e. set to zero) + this will create a matrix that is slower than a dense matrix in + operations. + + Examples + ======== + + >>> from sympy import SparseMatrix + >>> M = SparseMatrix.zeros(3); M + Matrix([ + [0, 0, 0], + [0, 0, 0], + [0, 0, 0]]) + >>> M.fill(1); M + Matrix([ + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]]) + + See Also + ======== + + zeros + ones + """ + value = _sympify(value) + if not value: + self._rep = DomainMatrix.zeros(self.shape, EXRAW) + else: + elements_dod = {i: dict.fromkeys(range(self.cols), value) for i in range(self.rows)} + self._rep = DomainMatrix(elements_dod, self.shape, EXRAW) + + +def _getitem_RepMatrix(self, key): + """Return portion of self defined by key. If the key involves a slice + then a list will be returned (if key is a single slice) or a matrix + (if key was a tuple involving a slice). + + Examples + ======== + + >>> from sympy import Matrix, I + >>> m = Matrix([ + ... [1, 2 + I], + ... [3, 4 ]]) + + If the key is a tuple that does not involve a slice then that element + is returned: + + >>> m[1, 0] + 3 + + When a tuple key involves a slice, a matrix is returned. Here, the + first column is selected (all rows, column 0): + + >>> m[:, 0] + Matrix([ + [1], + [3]]) + + If the slice is not a tuple then it selects from the underlying + list of elements that are arranged in row order and a list is + returned if a slice is involved: + + >>> m[0] + 1 + >>> m[::2] + [1, 3] + """ + if isinstance(key, tuple): + i, j = key + try: + return self._rep.getitem_sympy(index_(i), index_(j)) + except (TypeError, IndexError): + if (isinstance(i, Expr) and not i.is_number) or (isinstance(j, Expr) and not j.is_number): + if ((j < 0) is True) or ((j >= self.shape[1]) is True) or\ + ((i < 0) is True) or ((i >= self.shape[0]) is True): + raise ValueError("index out of boundary") + from sympy.matrices.expressions.matexpr import MatrixElement + return MatrixElement(self, i, j) + + if isinstance(i, slice): + i = range(self.rows)[i] + elif is_sequence(i): + pass + else: + i = [i] + if isinstance(j, slice): + j = range(self.cols)[j] + elif is_sequence(j): + pass + else: + j = [j] + return self.extract(i, j) + + else: + # Index/slice like a flattened list + rows, cols = self.shape + + # Raise the appropriate exception: + if not rows * cols: + return [][key] + + rep = self._rep.rep + domain = rep.domain + is_slice = isinstance(key, slice) + + if is_slice: + values = [rep.getitem(*divmod(n, cols)) for n in range(rows * cols)[key]] + else: + values = [rep.getitem(*divmod(index_(key), cols))] + + if domain != EXRAW: + to_sympy = domain.to_sympy + values = [to_sympy(val) for val in values] + + if is_slice: + return values + else: + return values[0] diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/solvers.py b/.venv/lib/python3.13/site-packages/sympy/matrices/solvers.py new file mode 100644 index 0000000000000000000000000000000000000000..1fba990df80dcf46304ecb1412f5382f60948c51 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/solvers.py @@ -0,0 +1,942 @@ +from sympy.core.function import expand_mul +from sympy.core.symbol import Dummy, uniquely_named_symbol, symbols +from sympy.utilities.iterables import numbered_symbols + +from .exceptions import ShapeError, NonSquareMatrixError, NonInvertibleMatrixError +from .eigen import _fuzzy_positive_definite +from .utilities import _get_intermediate_simp, _iszero + + +def _diagonal_solve(M, rhs): + """Solves ``Ax = B`` efficiently, where A is a diagonal Matrix, + with non-zero diagonal entries. + + Examples + ======== + + >>> from sympy import Matrix, eye + >>> A = eye(2)*2 + >>> B = Matrix([[1, 2], [3, 4]]) + >>> A.diagonal_solve(B) == B/2 + True + + See Also + ======== + + sympy.matrices.dense.DenseMatrix.lower_triangular_solve + sympy.matrices.dense.DenseMatrix.upper_triangular_solve + gauss_jordan_solve + cholesky_solve + LDLsolve + LUsolve + QRsolve + pinv_solve + cramer_solve + """ + + if not M.is_diagonal(): + raise TypeError("Matrix should be diagonal") + if rhs.rows != M.rows: + raise TypeError("Size mismatch") + + return M._new( + rhs.rows, rhs.cols, lambda i, j: rhs[i, j] / M[i, i]) + + +def _lower_triangular_solve(M, rhs): + """Solves ``Ax = B``, where A is a lower triangular matrix. + + See Also + ======== + + upper_triangular_solve + gauss_jordan_solve + cholesky_solve + diagonal_solve + LDLsolve + LUsolve + QRsolve + pinv_solve + cramer_solve + """ + + from .dense import MutableDenseMatrix + + if not M.is_square: + raise NonSquareMatrixError("Matrix must be square.") + if rhs.rows != M.rows: + raise ShapeError("Matrices size mismatch.") + if not M.is_lower: + raise ValueError("Matrix must be lower triangular.") + + dps = _get_intermediate_simp() + X = MutableDenseMatrix.zeros(M.rows, rhs.cols) + + for j in range(rhs.cols): + for i in range(M.rows): + if M[i, i] == 0: + raise TypeError("Matrix must be non-singular.") + + X[i, j] = dps((rhs[i, j] - sum(M[i, k]*X[k, j] + for k in range(i))) / M[i, i]) + + return M._new(X) + +def _lower_triangular_solve_sparse(M, rhs): + """Solves ``Ax = B``, where A is a lower triangular matrix. + + See Also + ======== + + upper_triangular_solve + gauss_jordan_solve + cholesky_solve + diagonal_solve + LDLsolve + LUsolve + QRsolve + pinv_solve + cramer_solve + """ + + if not M.is_square: + raise NonSquareMatrixError("Matrix must be square.") + if rhs.rows != M.rows: + raise ShapeError("Matrices size mismatch.") + if not M.is_lower: + raise ValueError("Matrix must be lower triangular.") + + dps = _get_intermediate_simp() + rows = [[] for i in range(M.rows)] + + for i, j, v in M.row_list(): + if i > j: + rows[i].append((j, v)) + + X = rhs.as_mutable() + + for j in range(rhs.cols): + for i in range(rhs.rows): + for u, v in rows[i]: + X[i, j] -= v*X[u, j] + + X[i, j] = dps(X[i, j] / M[i, i]) + + return M._new(X) + + +def _upper_triangular_solve(M, rhs): + """Solves ``Ax = B``, where A is an upper triangular matrix. + + See Also + ======== + + lower_triangular_solve + gauss_jordan_solve + cholesky_solve + diagonal_solve + LDLsolve + LUsolve + QRsolve + pinv_solve + cramer_solve + """ + + from .dense import MutableDenseMatrix + + if not M.is_square: + raise NonSquareMatrixError("Matrix must be square.") + if rhs.rows != M.rows: + raise ShapeError("Matrix size mismatch.") + if not M.is_upper: + raise TypeError("Matrix is not upper triangular.") + + dps = _get_intermediate_simp() + X = MutableDenseMatrix.zeros(M.rows, rhs.cols) + + for j in range(rhs.cols): + for i in reversed(range(M.rows)): + if M[i, i] == 0: + raise ValueError("Matrix must be non-singular.") + + X[i, j] = dps((rhs[i, j] - sum(M[i, k]*X[k, j] + for k in range(i + 1, M.rows))) / M[i, i]) + + return M._new(X) + +def _upper_triangular_solve_sparse(M, rhs): + """Solves ``Ax = B``, where A is an upper triangular matrix. + + See Also + ======== + + lower_triangular_solve + gauss_jordan_solve + cholesky_solve + diagonal_solve + LDLsolve + LUsolve + QRsolve + pinv_solve + cramer_solve + """ + + if not M.is_square: + raise NonSquareMatrixError("Matrix must be square.") + if rhs.rows != M.rows: + raise ShapeError("Matrix size mismatch.") + if not M.is_upper: + raise TypeError("Matrix is not upper triangular.") + + dps = _get_intermediate_simp() + rows = [[] for i in range(M.rows)] + + for i, j, v in M.row_list(): + if i < j: + rows[i].append((j, v)) + + X = rhs.as_mutable() + + for j in range(rhs.cols): + for i in reversed(range(rhs.rows)): + for u, v in reversed(rows[i]): + X[i, j] -= v*X[u, j] + + X[i, j] = dps(X[i, j] / M[i, i]) + + return M._new(X) + + +def _cholesky_solve(M, rhs): + """Solves ``Ax = B`` using Cholesky decomposition, + for a general square non-singular matrix. + For a non-square matrix with rows > cols, + the least squares solution is returned. + + See Also + ======== + + sympy.matrices.dense.DenseMatrix.lower_triangular_solve + sympy.matrices.dense.DenseMatrix.upper_triangular_solve + gauss_jordan_solve + diagonal_solve + LDLsolve + LUsolve + QRsolve + pinv_solve + cramer_solve + """ + + if M.rows < M.cols: + raise NotImplementedError( + 'Under-determined System. Try M.gauss_jordan_solve(rhs)') + + hermitian = True + reform = False + + if M.is_symmetric(): + hermitian = False + elif not M.is_hermitian: + reform = True + + if reform or _fuzzy_positive_definite(M) is False: + H = M.H + M = H.multiply(M) + rhs = H.multiply(rhs) + hermitian = not M.is_symmetric() + + L = M.cholesky(hermitian=hermitian) + Y = L.lower_triangular_solve(rhs) + + if hermitian: + return (L.H).upper_triangular_solve(Y) + else: + return (L.T).upper_triangular_solve(Y) + + +def _LDLsolve(M, rhs): + """Solves ``Ax = B`` using LDL decomposition, + for a general square and non-singular matrix. + + For a non-square matrix with rows > cols, + the least squares solution is returned. + + Examples + ======== + + >>> from sympy import Matrix, eye + >>> A = eye(2)*2 + >>> B = Matrix([[1, 2], [3, 4]]) + >>> A.LDLsolve(B) == B/2 + True + + See Also + ======== + + sympy.matrices.dense.DenseMatrix.LDLdecomposition + sympy.matrices.dense.DenseMatrix.lower_triangular_solve + sympy.matrices.dense.DenseMatrix.upper_triangular_solve + gauss_jordan_solve + cholesky_solve + diagonal_solve + LUsolve + QRsolve + pinv_solve + cramer_solve + """ + + if M.rows < M.cols: + raise NotImplementedError( + 'Under-determined System. Try M.gauss_jordan_solve(rhs)') + + hermitian = True + reform = False + + if M.is_symmetric(): + hermitian = False + elif not M.is_hermitian: + reform = True + + if reform or _fuzzy_positive_definite(M) is False: + H = M.H + M = H.multiply(M) + rhs = H.multiply(rhs) + hermitian = not M.is_symmetric() + + L, D = M.LDLdecomposition(hermitian=hermitian) + Y = L.lower_triangular_solve(rhs) + Z = D.diagonal_solve(Y) + + if hermitian: + return (L.H).upper_triangular_solve(Z) + else: + return (L.T).upper_triangular_solve(Z) + + +def _LUsolve(M, rhs, iszerofunc=_iszero): + """Solve the linear system ``Ax = rhs`` for ``x`` where ``A = M``. + + This is for symbolic matrices, for real or complex ones use + mpmath.lu_solve or mpmath.qr_solve. + + See Also + ======== + + sympy.matrices.dense.DenseMatrix.lower_triangular_solve + sympy.matrices.dense.DenseMatrix.upper_triangular_solve + gauss_jordan_solve + cholesky_solve + diagonal_solve + LDLsolve + QRsolve + pinv_solve + LUdecomposition + cramer_solve + """ + + if rhs.rows != M.rows: + raise ShapeError( + "``M`` and ``rhs`` must have the same number of rows.") + + m = M.rows + n = M.cols + + if m < n: + raise NotImplementedError("Underdetermined systems not supported.") + + try: + A, perm = M.LUdecomposition_Simple( + iszerofunc=iszerofunc, rankcheck=True) + except ValueError: + raise NonInvertibleMatrixError("Matrix det == 0; not invertible.") + + dps = _get_intermediate_simp() + b = rhs.permute_rows(perm).as_mutable() + + # forward substitution, all diag entries are scaled to 1 + for i in range(m): + for j in range(min(i, n)): + scale = A[i, j] + b.zip_row_op(i, j, lambda x, y: dps(x - scale * y)) + + # consistency check for overdetermined systems + if m > n: + for i in range(n, m): + for j in range(b.cols): + if not iszerofunc(b[i, j]): + raise ValueError("The system is inconsistent.") + + b = b[0:n, :] # truncate zero rows if consistent + + # backward substitution + for i in range(n - 1, -1, -1): + for j in range(i + 1, n): + scale = A[i, j] + b.zip_row_op(i, j, lambda x, y: dps(x - scale * y)) + + scale = A[i, i] + b.row_op(i, lambda x, _: dps(scale**-1 * x)) + + return rhs.__class__(b) + + +def _QRsolve(M, b): + """Solve the linear system ``Ax = b``. + + ``M`` is the matrix ``A``, the method argument is the vector + ``b``. The method returns the solution vector ``x``. If ``b`` is a + matrix, the system is solved for each column of ``b`` and the + return value is a matrix of the same shape as ``b``. + + This method is slower (approximately by a factor of 2) but + more stable for floating-point arithmetic than the LUsolve method. + However, LUsolve usually uses an exact arithmetic, so you do not need + to use QRsolve. + + This is mainly for educational purposes and symbolic matrices, for real + (or complex) matrices use mpmath.qr_solve. + + See Also + ======== + + sympy.matrices.dense.DenseMatrix.lower_triangular_solve + sympy.matrices.dense.DenseMatrix.upper_triangular_solve + gauss_jordan_solve + cholesky_solve + diagonal_solve + LDLsolve + LUsolve + pinv_solve + QRdecomposition + cramer_solve + """ + + dps = _get_intermediate_simp(expand_mul, expand_mul) + Q, R = M.QRdecomposition() + y = Q.T * b + + # back substitution to solve R*x = y: + # We build up the result "backwards" in the vector 'x' and reverse it + # only in the end. + x = [] + n = R.rows + + for j in range(n - 1, -1, -1): + tmp = y[j, :] + + for k in range(j + 1, n): + tmp -= R[j, k] * x[n - 1 - k] + + tmp = dps(tmp) + + x.append(tmp / R[j, j]) + + return M.vstack(*x[::-1]) + + +def _gauss_jordan_solve(M, B, freevar=False): + """ + Solves ``Ax = B`` using Gauss Jordan elimination. + + There may be zero, one, or infinite solutions. If one solution + exists, it will be returned. If infinite solutions exist, it will + be returned parametrically. If no solutions exist, It will throw + ValueError. + + Parameters + ========== + + B : Matrix + The right hand side of the equation to be solved for. Must have + the same number of rows as matrix A. + + freevar : boolean, optional + Flag, when set to `True` will return the indices of the free + variables in the solutions (column Matrix), for a system that is + undetermined (e.g. A has more columns than rows), for which + infinite solutions are possible, in terms of arbitrary + values of free variables. Default `False`. + + Returns + ======= + + x : Matrix + The matrix that will satisfy ``Ax = B``. Will have as many rows as + matrix A has columns, and as many columns as matrix B. + + params : Matrix + If the system is underdetermined (e.g. A has more columns than + rows), infinite solutions are possible, in terms of arbitrary + parameters. These arbitrary parameters are returned as params + Matrix. + + free_var_index : List, optional + If the system is underdetermined (e.g. A has more columns than + rows), infinite solutions are possible, in terms of arbitrary + values of free variables. Then the indices of the free variables + in the solutions (column Matrix) are returned by free_var_index, + if the flag `freevar` is set to `True`. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[1, 2, 1, 1], [1, 2, 2, -1], [2, 4, 0, 6]]) + >>> B = Matrix([7, 12, 4]) + >>> sol, params = A.gauss_jordan_solve(B) + >>> sol + Matrix([ + [-2*tau0 - 3*tau1 + 2], + [ tau0], + [ 2*tau1 + 5], + [ tau1]]) + >>> params + Matrix([ + [tau0], + [tau1]]) + >>> taus_zeroes = { tau:0 for tau in params } + >>> sol_unique = sol.xreplace(taus_zeroes) + >>> sol_unique + Matrix([ + [2], + [0], + [5], + [0]]) + + + >>> A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 10]]) + >>> B = Matrix([3, 6, 9]) + >>> sol, params = A.gauss_jordan_solve(B) + >>> sol + Matrix([ + [-1], + [ 2], + [ 0]]) + >>> params + Matrix(0, 1, []) + + >>> A = Matrix([[2, -7], [-1, 4]]) + >>> B = Matrix([[-21, 3], [12, -2]]) + >>> sol, params = A.gauss_jordan_solve(B) + >>> sol + Matrix([ + [0, -2], + [3, -1]]) + >>> params + Matrix(0, 2, []) + + + >>> from sympy import Matrix + >>> A = Matrix([[1, 2, 1, 1], [1, 2, 2, -1], [2, 4, 0, 6]]) + >>> B = Matrix([7, 12, 4]) + >>> sol, params, freevars = A.gauss_jordan_solve(B, freevar=True) + >>> sol + Matrix([ + [-2*tau0 - 3*tau1 + 2], + [ tau0], + [ 2*tau1 + 5], + [ tau1]]) + >>> params + Matrix([ + [tau0], + [tau1]]) + >>> freevars + [1, 3] + + + See Also + ======== + + sympy.matrices.dense.DenseMatrix.lower_triangular_solve + sympy.matrices.dense.DenseMatrix.upper_triangular_solve + cholesky_solve + diagonal_solve + LDLsolve + LUsolve + QRsolve + pinv + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gaussian_elimination + + """ + + from sympy.matrices import Matrix, zeros + + cls = M.__class__ + aug = M.hstack(M.copy(), B.copy()) + B_cols = B.cols + row, col = aug[:, :-B_cols].shape + + # solve by reduced row echelon form + A, pivots = aug.rref(simplify=True) + A, v = A[:, :-B_cols], A[:, -B_cols:] + pivots = list(filter(lambda p: p < col, pivots)) + rank = len(pivots) + + # Get index of free symbols (free parameters) + # non-pivots columns are free variables + free_var_index = [c for c in range(A.cols) if c not in pivots] + + # Bring to block form + permutation = Matrix(pivots + free_var_index).T + + # check for existence of solutions + # rank of aug Matrix should be equal to rank of coefficient matrix + if not v[rank:, :].is_zero_matrix: + raise ValueError("Linear system has no solution") + + # Free parameters + # what are current unnumbered free symbol names? + name = uniquely_named_symbol('tau', [aug], + compare=lambda i: str(i).rstrip('1234567890'), + modify=lambda s: '_' + s).name + gen = numbered_symbols(name) + tau = Matrix([next(gen) for k in range((col - rank)*B_cols)]).reshape( + col - rank, B_cols) + + # Full parametric solution + V = A[:rank, free_var_index] + vt = v[:rank, :] + free_sol = tau.vstack(vt - V * tau, tau) + + # Undo permutation + sol = zeros(col, B_cols) + + for k in range(col): + sol[permutation[k], :] = free_sol[k,:] + + sol, tau = cls(sol), cls(tau) + + if freevar: + return sol, tau, free_var_index + else: + return sol, tau + + +def _pinv_solve(M, B, arbitrary_matrix=None): + """Solve ``Ax = B`` using the Moore-Penrose pseudoinverse. + + There may be zero, one, or infinite solutions. If one solution + exists, it will be returned. If infinite solutions exist, one will + be returned based on the value of arbitrary_matrix. If no solutions + exist, the least-squares solution is returned. + + Parameters + ========== + + B : Matrix + The right hand side of the equation to be solved for. Must have + the same number of rows as matrix A. + arbitrary_matrix : Matrix + If the system is underdetermined (e.g. A has more columns than + rows), infinite solutions are possible, in terms of an arbitrary + matrix. This parameter may be set to a specific matrix to use + for that purpose; if so, it must be the same shape as x, with as + many rows as matrix A has columns, and as many columns as matrix + B. If left as None, an appropriate matrix containing dummy + symbols in the form of ``wn_m`` will be used, with n and m being + row and column position of each symbol. + + Returns + ======= + + x : Matrix + The matrix that will satisfy ``Ax = B``. Will have as many rows as + matrix A has columns, and as many columns as matrix B. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[1, 2, 3], [4, 5, 6]]) + >>> B = Matrix([7, 8]) + >>> A.pinv_solve(B) + Matrix([ + [ _w0_0/6 - _w1_0/3 + _w2_0/6 - 55/18], + [-_w0_0/3 + 2*_w1_0/3 - _w2_0/3 + 1/9], + [ _w0_0/6 - _w1_0/3 + _w2_0/6 + 59/18]]) + >>> A.pinv_solve(B, arbitrary_matrix=Matrix([0, 0, 0])) + Matrix([ + [-55/18], + [ 1/9], + [ 59/18]]) + + See Also + ======== + + sympy.matrices.dense.DenseMatrix.lower_triangular_solve + sympy.matrices.dense.DenseMatrix.upper_triangular_solve + gauss_jordan_solve + cholesky_solve + diagonal_solve + LDLsolve + LUsolve + QRsolve + pinv + + Notes + ===== + + This may return either exact solutions or least squares solutions. + To determine which, check ``A * A.pinv() * B == B``. It will be + True if exact solutions exist, and False if only a least-squares + solution exists. Be aware that the left hand side of that equation + may need to be simplified to correctly compare to the right hand + side. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Moore-Penrose_pseudoinverse#Obtaining_all_solutions_of_a_linear_system + + """ + + from sympy.matrices import eye + + A = M + A_pinv = M.pinv() + + if arbitrary_matrix is None: + rows, cols = A.cols, B.cols + w = symbols('w:{}_:{}'.format(rows, cols), cls=Dummy) + arbitrary_matrix = M.__class__(cols, rows, w).T + + return A_pinv.multiply(B) + (eye(A.cols) - + A_pinv.multiply(A)).multiply(arbitrary_matrix) + + +def _cramer_solve(M, rhs, det_method="laplace"): + """Solves system of linear equations using Cramer's rule. + + This method is relatively inefficient compared to other methods. + However it only uses a single division, assuming a division-free determinant + method is provided. This is helpful to minimize the chance of divide-by-zero + cases in symbolic solutions to linear systems. + + Parameters + ========== + M : Matrix + The matrix representing the left hand side of the equation. + rhs : Matrix + The matrix representing the right hand side of the equation. + det_method : str or callable + The method to use to calculate the determinant of the matrix. + The default is ``'laplace'``. If a callable is passed, it should take a + single argument, the matrix, and return the determinant of the matrix. + + Returns + ======= + x : Matrix + The matrix that will satisfy ``Ax = B``. Will have as many rows as + matrix A has columns, and as many columns as matrix B. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[0, -6, 1], [0, -6, -1], [-5, -2, 3]]) + >>> B = Matrix([[-30, -9], [-18, -27], [-26, 46]]) + >>> x = A.cramer_solve(B) + >>> x + Matrix([ + [ 0, -5], + [ 4, 3], + [-6, 9]]) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Cramer%27s_rule#Explicit_formulas_for_small_systems + + """ + from .dense import zeros + + def entry(i, j): + return rhs[i, sol] if j == col else M[i, j] + + if det_method == "bird": + from .determinant import _det_bird + det = _det_bird + elif det_method == "laplace": + from .determinant import _det_laplace + det = _det_laplace + elif isinstance(det_method, str): + det = lambda matrix: matrix.det(method=det_method) + else: + det = det_method + det_M = det(M) + x = zeros(*rhs.shape) + for sol in range(rhs.shape[1]): + for col in range(rhs.shape[0]): + x[col, sol] = det(M.__class__(*M.shape, entry)) / det_M + return M.__class__(x) + + +def _solve(M, rhs, method='GJ'): + """Solves linear equation where the unique solution exists. + + Parameters + ========== + + rhs : Matrix + Vector representing the right hand side of the linear equation. + + method : string, optional + If set to ``'GJ'`` or ``'GE'``, the Gauss-Jordan elimination will be + used, which is implemented in the routine ``gauss_jordan_solve``. + + If set to ``'LU'``, ``LUsolve`` routine will be used. + + If set to ``'QR'``, ``QRsolve`` routine will be used. + + If set to ``'PINV'``, ``pinv_solve`` routine will be used. + + If set to ``'CRAMER'``, ``cramer_solve`` routine will be used. + + It also supports the methods available for special linear systems + + For positive definite systems: + + If set to ``'CH'``, ``cholesky_solve`` routine will be used. + + If set to ``'LDL'``, ``LDLsolve`` routine will be used. + + To use a different method and to compute the solution via the + inverse, use a method defined in the .inv() docstring. + + Returns + ======= + + solutions : Matrix + Vector representing the solution. + + Raises + ====== + + ValueError + If there is not a unique solution then a ``ValueError`` will be + raised. + + If ``M`` is not square, a ``ValueError`` and a different routine + for solving the system will be suggested. + """ + + if method in ('GJ', 'GE'): + try: + soln, param = M.gauss_jordan_solve(rhs) + + if param: + raise NonInvertibleMatrixError("Matrix det == 0; not invertible. " + "Try ``M.gauss_jordan_solve(rhs)`` to obtain a parametric solution.") + + except ValueError: + raise NonInvertibleMatrixError("Matrix det == 0; not invertible.") + + return soln + + elif method == 'LU': + return M.LUsolve(rhs) + elif method == 'CH': + return M.cholesky_solve(rhs) + elif method == 'QR': + return M.QRsolve(rhs) + elif method == 'LDL': + return M.LDLsolve(rhs) + elif method == 'PINV': + return M.pinv_solve(rhs) + elif method == 'CRAMER': + return M.cramer_solve(rhs) + else: + return M.inv(method=method).multiply(rhs) + + +def _solve_least_squares(M, rhs, method='CH'): + """Return the least-square fit to the data. + + Parameters + ========== + + rhs : Matrix + Vector representing the right hand side of the linear equation. + + method : string or boolean, optional + If set to ``'CH'``, ``cholesky_solve`` routine will be used. + + If set to ``'LDL'``, ``LDLsolve`` routine will be used. + + If set to ``'QR'``, ``QRsolve`` routine will be used. + + If set to ``'PINV'``, ``pinv_solve`` routine will be used. + + Otherwise, the conjugate of ``M`` will be used to create a system + of equations that is passed to ``solve`` along with the hint + defined by ``method``. + + Returns + ======= + + solutions : Matrix + Vector representing the solution. + + Examples + ======== + + >>> from sympy import Matrix, ones + >>> A = Matrix([1, 2, 3]) + >>> B = Matrix([2, 3, 4]) + >>> S = Matrix(A.row_join(B)) + >>> S + Matrix([ + [1, 2], + [2, 3], + [3, 4]]) + + If each line of S represent coefficients of Ax + By + and x and y are [2, 3] then S*xy is: + + >>> r = S*Matrix([2, 3]); r + Matrix([ + [ 8], + [13], + [18]]) + + But let's add 1 to the middle value and then solve for the + least-squares value of xy: + + >>> xy = S.solve_least_squares(Matrix([8, 14, 18])); xy + Matrix([ + [ 5/3], + [10/3]]) + + The error is given by S*xy - r: + + >>> S*xy - r + Matrix([ + [1/3], + [1/3], + [1/3]]) + >>> _.norm().n(2) + 0.58 + + If a different xy is used, the norm will be higher: + + >>> xy += ones(2, 1)/10 + >>> (S*xy - r).norm().n(2) + 1.5 + + """ + + if method == 'CH': + return M.cholesky_solve(rhs) + elif method == 'QR': + return M.QRsolve(rhs) + elif method == 'LDL': + return M.LDLsolve(rhs) + elif method == 'PINV': + return M.pinv_solve(rhs) + else: + t = M.H + return (t * M).solve(t * rhs, method=method) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/sparse.py b/.venv/lib/python3.13/site-packages/sympy/matrices/sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..95a7b3ca0ac29cf4409ec1eeecd059f9643e9bbc --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/sparse.py @@ -0,0 +1,473 @@ +from collections.abc import Callable + +from sympy.core.containers import Dict +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import is_sequence +from sympy.utilities.misc import as_int + +from .matrixbase import MatrixBase +from .repmatrix import MutableRepMatrix, RepMatrix + +from .utilities import _iszero + +from .decompositions import ( + _liupc, _row_structure_symbolic_cholesky, _cholesky_sparse, + _LDLdecomposition_sparse) + +from .solvers import ( + _lower_triangular_solve_sparse, _upper_triangular_solve_sparse) + + +class SparseRepMatrix(RepMatrix): + """ + A sparse matrix (a matrix with a large number of zero elements). + + Examples + ======== + + >>> from sympy import SparseMatrix, ones + >>> SparseMatrix(2, 2, range(4)) + Matrix([ + [0, 1], + [2, 3]]) + >>> SparseMatrix(2, 2, {(1, 1): 2}) + Matrix([ + [0, 0], + [0, 2]]) + + A SparseMatrix can be instantiated from a ragged list of lists: + + >>> SparseMatrix([[1, 2, 3], [1, 2], [1]]) + Matrix([ + [1, 2, 3], + [1, 2, 0], + [1, 0, 0]]) + + For safety, one may include the expected size and then an error + will be raised if the indices of any element are out of range or + (for a flat list) if the total number of elements does not match + the expected shape: + + >>> SparseMatrix(2, 2, [1, 2]) + Traceback (most recent call last): + ... + ValueError: List length (2) != rows*columns (4) + + Here, an error is not raised because the list is not flat and no + element is out of range: + + >>> SparseMatrix(2, 2, [[1, 2]]) + Matrix([ + [1, 2], + [0, 0]]) + + But adding another element to the first (and only) row will cause + an error to be raised: + + >>> SparseMatrix(2, 2, [[1, 2, 3]]) + Traceback (most recent call last): + ... + ValueError: The location (0, 2) is out of designated range: (1, 1) + + To autosize the matrix, pass None for rows: + + >>> SparseMatrix(None, [[1, 2, 3]]) + Matrix([[1, 2, 3]]) + >>> SparseMatrix(None, {(1, 1): 1, (3, 3): 3}) + Matrix([ + [0, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 3]]) + + Values that are themselves a Matrix are automatically expanded: + + >>> SparseMatrix(4, 4, {(1, 1): ones(2)}) + Matrix([ + [0, 0, 0, 0], + [0, 1, 1, 0], + [0, 1, 1, 0], + [0, 0, 0, 0]]) + + A ValueError is raised if the expanding matrix tries to overwrite + a different element already present: + + >>> SparseMatrix(3, 3, {(0, 0): ones(2), (1, 1): 2}) + Traceback (most recent call last): + ... + ValueError: collision at (1, 1) + + See Also + ======== + DenseMatrix + MutableSparseMatrix + ImmutableSparseMatrix + """ + + @classmethod + def _handle_creation_inputs(cls, *args, **kwargs): + if len(args) == 1 and isinstance(args[0], MatrixBase): + rows = args[0].rows + cols = args[0].cols + smat = args[0].todok() + return rows, cols, smat + + smat = {} + # autosizing + if len(args) == 2 and args[0] is None: + args = [None, None, args[1]] + + if len(args) == 3: + r, c = args[:2] + if r is c is None: + rows = cols = None + elif None in (r, c): + raise ValueError( + 'Pass rows=None and no cols for autosizing.') + else: + rows, cols = as_int(args[0]), as_int(args[1]) + + if isinstance(args[2], Callable): + op = args[2] + + if None in (rows, cols): + raise ValueError( + "{} and {} must be integers for this " + "specification.".format(rows, cols)) + + row_indices = [cls._sympify(i) for i in range(rows)] + col_indices = [cls._sympify(j) for j in range(cols)] + + for i in row_indices: + for j in col_indices: + value = cls._sympify(op(i, j)) + if value != cls.zero: + smat[i, j] = value + + return rows, cols, smat + + elif isinstance(args[2], (dict, Dict)): + def update(i, j, v): + # update smat and make sure there are no collisions + if v: + if (i, j) in smat and v != smat[i, j]: + raise ValueError( + "There is a collision at {} for {} and {}." + .format((i, j), v, smat[i, j]) + ) + smat[i, j] = v + + # manual copy, copy.deepcopy() doesn't work + for (r, c), v in args[2].items(): + if isinstance(v, MatrixBase): + for (i, j), vv in v.todok().items(): + update(r + i, c + j, vv) + elif isinstance(v, (list, tuple)): + _, _, smat = cls._handle_creation_inputs(v, **kwargs) + for i, j in smat: + update(r + i, c + j, smat[i, j]) + else: + v = cls._sympify(v) + update(r, c, cls._sympify(v)) + + elif is_sequence(args[2]): + flat = not any(is_sequence(i) for i in args[2]) + if not flat: + _, _, smat = \ + cls._handle_creation_inputs(args[2], **kwargs) + else: + flat_list = args[2] + if len(flat_list) != rows * cols: + raise ValueError( + "The length of the flat list ({}) does not " + "match the specified size ({} * {})." + .format(len(flat_list), rows, cols) + ) + + for i in range(rows): + for j in range(cols): + value = flat_list[i*cols + j] + value = cls._sympify(value) + if value != cls.zero: + smat[i, j] = value + + if rows is None: # autosizing + keys = smat.keys() + rows = max(r for r, _ in keys) + 1 if keys else 0 + cols = max(c for _, c in keys) + 1 if keys else 0 + + else: + for i, j in smat.keys(): + if i and i >= rows or j and j >= cols: + raise ValueError( + "The location {} is out of the designated range" + "[{}, {}]x[{}, {}]" + .format((i, j), 0, rows - 1, 0, cols - 1) + ) + + return rows, cols, smat + + elif len(args) == 1 and isinstance(args[0], (list, tuple)): + # list of values or lists + v = args[0] + c = 0 + for i, row in enumerate(v): + if not isinstance(row, (list, tuple)): + row = [row] + for j, vv in enumerate(row): + if vv != cls.zero: + smat[i, j] = cls._sympify(vv) + c = max(c, len(row)) + rows = len(v) if c else 0 + cols = c + return rows, cols, smat + + else: + # handle full matrix forms with _handle_creation_inputs + rows, cols, mat = super()._handle_creation_inputs(*args) + for i in range(rows): + for j in range(cols): + value = mat[cols*i + j] + if value != cls.zero: + smat[i, j] = value + + return rows, cols, smat + + @property + def _smat(self): + + sympy_deprecation_warning( + """ + The private _smat attribute of SparseMatrix is deprecated. Use the + .todok() method instead. + """, + deprecated_since_version="1.9", + active_deprecations_target="deprecated-private-matrix-attributes" + ) + + return self.todok() + + def _eval_inverse(self, **kwargs): + return self.inv(method=kwargs.get('method', 'LDL'), + iszerofunc=kwargs.get('iszerofunc', _iszero), + try_block_diag=kwargs.get('try_block_diag', False)) + + def applyfunc(self, f): + """Apply a function to each element of the matrix. + + Examples + ======== + + >>> from sympy import SparseMatrix + >>> m = SparseMatrix(2, 2, lambda i, j: i*2+j) + >>> m + Matrix([ + [0, 1], + [2, 3]]) + >>> m.applyfunc(lambda i: 2*i) + Matrix([ + [0, 2], + [4, 6]]) + + """ + if not callable(f): + raise TypeError("`f` must be callable.") + + # XXX: This only applies the function to the nonzero elements of the + # matrix so is inconsistent with DenseMatrix.applyfunc e.g. + # zeros(2, 2).applyfunc(lambda x: x + 1) + dok = {} + for k, v in self.todok().items(): + fv = f(v) + if fv != 0: + dok[k] = fv + + return self._new(self.rows, self.cols, dok) + + def as_immutable(self): + """Returns an Immutable version of this Matrix.""" + from .immutable import ImmutableSparseMatrix + return ImmutableSparseMatrix(self) + + def as_mutable(self): + """Returns a mutable version of this matrix. + + Examples + ======== + + >>> from sympy import ImmutableMatrix + >>> X = ImmutableMatrix([[1, 2], [3, 4]]) + >>> Y = X.as_mutable() + >>> Y[1, 1] = 5 # Can set values in Y + >>> Y + Matrix([ + [1, 2], + [3, 5]]) + """ + return MutableSparseMatrix(self) + + def col_list(self): + """Returns a column-sorted list of non-zero elements of the matrix. + + Examples + ======== + + >>> from sympy import SparseMatrix + >>> a=SparseMatrix(((1, 2), (3, 4))) + >>> a + Matrix([ + [1, 2], + [3, 4]]) + >>> a.CL + [(0, 0, 1), (1, 0, 3), (0, 1, 2), (1, 1, 4)] + + See Also + ======== + + sympy.matrices.sparse.SparseMatrix.row_list + """ + return [tuple(k + (self[k],)) for k in sorted(self.todok().keys(), key=lambda k: list(reversed(k)))] + + def nnz(self): + """Returns the number of non-zero elements in Matrix.""" + return len(self.todok()) + + def row_list(self): + """Returns a row-sorted list of non-zero elements of the matrix. + + Examples + ======== + + >>> from sympy import SparseMatrix + >>> a = SparseMatrix(((1, 2), (3, 4))) + >>> a + Matrix([ + [1, 2], + [3, 4]]) + >>> a.RL + [(0, 0, 1), (0, 1, 2), (1, 0, 3), (1, 1, 4)] + + See Also + ======== + + sympy.matrices.sparse.SparseMatrix.col_list + """ + return [tuple(k + (self[k],)) for k in + sorted(self.todok().keys(), key=list)] + + def scalar_multiply(self, scalar): + "Scalar element-wise multiplication" + return scalar * self + + def solve_least_squares(self, rhs, method='LDL'): + """Return the least-square fit to the data. + + By default the cholesky_solve routine is used (method='CH'); other + methods of matrix inversion can be used. To find out which are + available, see the docstring of the .inv() method. + + Examples + ======== + + >>> from sympy import SparseMatrix, Matrix, ones + >>> A = Matrix([1, 2, 3]) + >>> B = Matrix([2, 3, 4]) + >>> S = SparseMatrix(A.row_join(B)) + >>> S + Matrix([ + [1, 2], + [2, 3], + [3, 4]]) + + If each line of S represent coefficients of Ax + By + and x and y are [2, 3] then S*xy is: + + >>> r = S*Matrix([2, 3]); r + Matrix([ + [ 8], + [13], + [18]]) + + But let's add 1 to the middle value and then solve for the + least-squares value of xy: + + >>> xy = S.solve_least_squares(Matrix([8, 14, 18])); xy + Matrix([ + [ 5/3], + [10/3]]) + + The error is given by S*xy - r: + + >>> S*xy - r + Matrix([ + [1/3], + [1/3], + [1/3]]) + >>> _.norm().n(2) + 0.58 + + If a different xy is used, the norm will be higher: + + >>> xy += ones(2, 1)/10 + >>> (S*xy - r).norm().n(2) + 1.5 + + """ + t = self.T + return (t*self).inv(method=method)*t*rhs + + def solve(self, rhs, method='LDL'): + """Return solution to self*soln = rhs using given inversion method. + + For a list of possible inversion methods, see the .inv() docstring. + """ + if not self.is_square: + if self.rows < self.cols: + raise ValueError('Under-determined system.') + elif self.rows > self.cols: + raise ValueError('For over-determined system, M, having ' + 'more rows than columns, try M.solve_least_squares(rhs).') + else: + return self.inv(method=method).multiply(rhs) + + RL = property(row_list, None, None, "Alternate faster representation") + CL = property(col_list, None, None, "Alternate faster representation") + + def liupc(self): + return _liupc(self) + + def row_structure_symbolic_cholesky(self): + return _row_structure_symbolic_cholesky(self) + + def cholesky(self, hermitian=True): + return _cholesky_sparse(self, hermitian=hermitian) + + def LDLdecomposition(self, hermitian=True): + return _LDLdecomposition_sparse(self, hermitian=hermitian) + + def lower_triangular_solve(self, rhs): + return _lower_triangular_solve_sparse(self, rhs) + + def upper_triangular_solve(self, rhs): + return _upper_triangular_solve_sparse(self, rhs) + + liupc.__doc__ = _liupc.__doc__ + row_structure_symbolic_cholesky.__doc__ = _row_structure_symbolic_cholesky.__doc__ + cholesky.__doc__ = _cholesky_sparse.__doc__ + LDLdecomposition.__doc__ = _LDLdecomposition_sparse.__doc__ + lower_triangular_solve.__doc__ = lower_triangular_solve.__doc__ + upper_triangular_solve.__doc__ = upper_triangular_solve.__doc__ + + +class MutableSparseMatrix(SparseRepMatrix, MutableRepMatrix): + + @classmethod + def _new(cls, *args, **kwargs): + rows, cols, smat = cls._handle_creation_inputs(*args, **kwargs) + + rep = cls._smat_to_DomainMatrix(rows, cols, smat) + + return cls._fromrep(rep) + + +SparseMatrix = MutableSparseMatrix diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/sparsetools.py b/.venv/lib/python3.13/site-packages/sympy/matrices/sparsetools.py new file mode 100644 index 0000000000000000000000000000000000000000..50048f6dc7e5cf160366963d16427987616ddce7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/sparsetools.py @@ -0,0 +1,300 @@ +from sympy.core.containers import Dict +from sympy.core.symbol import Dummy +from sympy.utilities.iterables import is_sequence +from sympy.utilities.misc import as_int, filldedent + +from .sparse import MutableSparseMatrix as SparseMatrix + + +def _doktocsr(dok): + """Converts a sparse matrix to Compressed Sparse Row (CSR) format. + + Parameters + ========== + + A : contains non-zero elements sorted by key (row, column) + JA : JA[i] is the column corresponding to A[i] + IA : IA[i] contains the index in A for the first non-zero element + of row[i]. Thus IA[i+1] - IA[i] gives number of non-zero + elements row[i]. The length of IA is always 1 more than the + number of rows in the matrix. + + Examples + ======== + + >>> from sympy.matrices.sparsetools import _doktocsr + >>> from sympy import SparseMatrix, diag + >>> m = SparseMatrix(diag(1, 2, 3)) + >>> m[2, 0] = -1 + >>> _doktocsr(m) + [[1, 2, -1, 3], [0, 1, 0, 2], [0, 1, 2, 4], [3, 3]] + + """ + row, JA, A = [list(i) for i in zip(*dok.row_list())] + IA = [0]*((row[0] if row else 0) + 1) + for i, r in enumerate(row): + IA.extend([i]*(r - row[i - 1])) # if i = 0 nothing is extended + IA.extend([len(A)]*(dok.rows - len(IA) + 1)) + shape = [dok.rows, dok.cols] + return [A, JA, IA, shape] + + +def _csrtodok(csr): + """Converts a CSR representation to DOK representation. + + Examples + ======== + + >>> from sympy.matrices.sparsetools import _csrtodok + >>> _csrtodok([[5, 8, 3, 6], [0, 1, 2, 1], [0, 0, 2, 3, 4], [4, 3]]) + Matrix([ + [0, 0, 0], + [5, 8, 0], + [0, 0, 3], + [0, 6, 0]]) + + """ + smat = {} + A, JA, IA, shape = csr + for i in range(len(IA) - 1): + indices = slice(IA[i], IA[i + 1]) + for l, m in zip(A[indices], JA[indices]): + smat[i, m] = l + return SparseMatrix(*shape, smat) + + +def banded(*args, **kwargs): + """Returns a SparseMatrix from the given dictionary describing + the diagonals of the matrix. The keys are positive for upper + diagonals and negative for those below the main diagonal. The + values may be: + + * expressions or single-argument functions, + + * lists or tuples of values, + + * matrices + + Unless dimensions are given, the size of the returned matrix will + be large enough to contain the largest non-zero value provided. + + kwargs + ====== + + rows : rows of the resulting matrix; computed if + not given. + + cols : columns of the resulting matrix; computed if + not given. + + Examples + ======== + + >>> from sympy import banded, ones, Matrix + >>> from sympy.abc import x + + If explicit values are given in tuples, + the matrix will autosize to contain all values, otherwise + a single value is filled onto the entire diagonal: + + >>> banded({1: (1, 2, 3), -1: (4, 5, 6), 0: x}) + Matrix([ + [x, 1, 0, 0], + [4, x, 2, 0], + [0, 5, x, 3], + [0, 0, 6, x]]) + + A function accepting a single argument can be used to fill the + diagonal as a function of diagonal index (which starts at 0). + The size (or shape) of the matrix must be given to obtain more + than a 1x1 matrix: + + >>> s = lambda d: (1 + d)**2 + >>> banded(5, {0: s, 2: s, -2: 2}) + Matrix([ + [1, 0, 1, 0, 0], + [0, 4, 0, 4, 0], + [2, 0, 9, 0, 9], + [0, 2, 0, 16, 0], + [0, 0, 2, 0, 25]]) + + The diagonal of matrices placed on a diagonal will coincide + with the indicated diagonal: + + >>> vert = Matrix([1, 2, 3]) + >>> banded({0: vert}, cols=3) + Matrix([ + [1, 0, 0], + [2, 1, 0], + [3, 2, 1], + [0, 3, 2], + [0, 0, 3]]) + + >>> banded(4, {0: ones(2)}) + Matrix([ + [1, 1, 0, 0], + [1, 1, 0, 0], + [0, 0, 1, 1], + [0, 0, 1, 1]]) + + Errors are raised if the designated size will not hold + all values an integral number of times. Here, the rows + are designated as odd (but an even number is required to + hold the off-diagonal 2x2 ones): + + >>> banded({0: 2, 1: ones(2)}, rows=5) + Traceback (most recent call last): + ... + ValueError: + sequence does not fit an integral number of times in the matrix + + And here, an even number of rows is given...but the square + matrix has an even number of columns, too. As we saw + in the previous example, an odd number is required: + + >>> banded(4, {0: 2, 1: ones(2)}) # trying to make 4x4 and cols must be odd + Traceback (most recent call last): + ... + ValueError: + sequence does not fit an integral number of times in the matrix + + A way around having to count rows is to enclosing matrix elements + in a tuple and indicate the desired number of them to the right: + + >>> banded({0: 2, 2: (ones(2),)*3}) + Matrix([ + [2, 0, 1, 1, 0, 0, 0, 0], + [0, 2, 1, 1, 0, 0, 0, 0], + [0, 0, 2, 0, 1, 1, 0, 0], + [0, 0, 0, 2, 1, 1, 0, 0], + [0, 0, 0, 0, 2, 0, 1, 1], + [0, 0, 0, 0, 0, 2, 1, 1]]) + + An error will be raised if more than one value + is written to a given entry. Here, the ones overlap + with the main diagonal if they are placed on the + first diagonal: + + >>> banded({0: (2,)*5, 1: (ones(2),)*3}) + Traceback (most recent call last): + ... + ValueError: collision at (1, 1) + + By placing a 0 at the bottom left of the 2x2 matrix of + ones, the collision is avoided: + + >>> u2 = Matrix([ + ... [1, 1], + ... [0, 1]]) + >>> banded({0: [2]*5, 1: [u2]*3}) + Matrix([ + [2, 1, 1, 0, 0, 0, 0], + [0, 2, 1, 0, 0, 0, 0], + [0, 0, 2, 1, 1, 0, 0], + [0, 0, 0, 2, 1, 0, 0], + [0, 0, 0, 0, 2, 1, 1], + [0, 0, 0, 0, 0, 0, 1]]) + """ + try: + if len(args) not in (1, 2, 3): + raise TypeError + if not isinstance(args[-1], (dict, Dict)): + raise TypeError + if len(args) == 1: + rows = kwargs.get('rows', None) + cols = kwargs.get('cols', None) + if rows is not None: + rows = as_int(rows) + if cols is not None: + cols = as_int(cols) + elif len(args) == 2: + rows = cols = as_int(args[0]) + else: + rows, cols = map(as_int, args[:2]) + # fails with ValueError if any keys are not ints + _ = all(as_int(k) for k in args[-1]) + except (ValueError, TypeError): + raise TypeError(filldedent( + '''unrecognized input to banded: + expecting [[row,] col,] {int: value}''')) + def rc(d): + # return row,col coord of diagonal start + r = -d if d < 0 else 0 + c = 0 if r else d + return r, c + smat = {} + undone = [] + tba = Dummy() + # first handle objects with size + for d, v in args[-1].items(): + r, c = rc(d) + # note: only list and tuple are recognized since this + # will allow other Basic objects like Tuple + # into the matrix if so desired + if isinstance(v, (list, tuple)): + extra = 0 + for i, vi in enumerate(v): + i += extra + if is_sequence(vi): + vi = SparseMatrix(vi) + smat[r + i, c + i] = vi + extra += min(vi.shape) - 1 + else: + smat[r + i, c + i] = vi + elif is_sequence(v): + v = SparseMatrix(v) + rv, cv = v.shape + if rows and cols: + nr, xr = divmod(rows - r, rv) + nc, xc = divmod(cols - c, cv) + x = xr or xc + do = min(nr, nc) + elif rows: + do, x = divmod(rows - r, rv) + elif cols: + do, x = divmod(cols - c, cv) + else: + do = 1 + x = 0 + if x: + raise ValueError(filldedent(''' + sequence does not fit an integral number of times + in the matrix''')) + j = min(v.shape) + for i in range(do): + smat[r, c] = v + r += j + c += j + elif v: + smat[r, c] = tba + undone.append((d, v)) + s = SparseMatrix(None, smat) # to expand matrices + smat = s.todok() + # check for dim errors here + if rows is not None and rows < s.rows: + raise ValueError('Designated rows %s < needed %s' % (rows, s.rows)) + if cols is not None and cols < s.cols: + raise ValueError('Designated cols %s < needed %s' % (cols, s.cols)) + if rows is cols is None: + rows = s.rows + cols = s.cols + elif rows is not None and cols is None: + cols = max(rows, s.cols) + elif cols is not None and rows is None: + rows = max(cols, s.rows) + def update(i, j, v): + # update smat and make sure there are + # no collisions + if v: + if (i, j) in smat and smat[i, j] not in (tba, v): + raise ValueError('collision at %s' % ((i, j),)) + smat[i, j] = v + if undone: + for d, vi in undone: + r, c = rc(d) + v = vi if callable(vi) else lambda _: vi + i = 0 + while r + i < rows and c + i < cols: + update(r + i, c + i, v(i)) + i += 1 + return SparseMatrix(rows, cols, smat) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/subspaces.py b/.venv/lib/python3.13/site-packages/sympy/matrices/subspaces.py new file mode 100644 index 0000000000000000000000000000000000000000..1ab0b71b4289ebaeb6394059c6a7cd49d3a148a1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/subspaces.py @@ -0,0 +1,174 @@ +from .utilities import _iszero + + +def _columnspace(M, simplify=False): + """Returns a list of vectors (Matrix objects) that span columnspace of ``M`` + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix(3, 3, [1, 3, 0, -2, -6, 0, 3, 9, 6]) + >>> M + Matrix([ + [ 1, 3, 0], + [-2, -6, 0], + [ 3, 9, 6]]) + >>> M.columnspace() + [Matrix([ + [ 1], + [-2], + [ 3]]), Matrix([ + [0], + [0], + [6]])] + + See Also + ======== + + nullspace + rowspace + """ + + reduced, pivots = M.echelon_form(simplify=simplify, with_pivots=True) + + return [M.col(i) for i in pivots] + + +def _nullspace(M, simplify=False, iszerofunc=_iszero): + """Returns list of vectors (Matrix objects) that span nullspace of ``M`` + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix(3, 3, [1, 3, 0, -2, -6, 0, 3, 9, 6]) + >>> M + Matrix([ + [ 1, 3, 0], + [-2, -6, 0], + [ 3, 9, 6]]) + >>> M.nullspace() + [Matrix([ + [-3], + [ 1], + [ 0]])] + + See Also + ======== + + columnspace + rowspace + """ + + reduced, pivots = M.rref(iszerofunc=iszerofunc, simplify=simplify) + + free_vars = [i for i in range(M.cols) if i not in pivots] + basis = [] + + for free_var in free_vars: + # for each free variable, we will set it to 1 and all others + # to 0. Then, we will use back substitution to solve the system + vec = [M.zero] * M.cols + vec[free_var] = M.one + + for piv_row, piv_col in enumerate(pivots): + vec[piv_col] -= reduced[piv_row, free_var] + + basis.append(vec) + + return [M._new(M.cols, 1, b) for b in basis] + + +def _rowspace(M, simplify=False): + """Returns a list of vectors that span the row space of ``M``. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix(3, 3, [1, 3, 0, -2, -6, 0, 3, 9, 6]) + >>> M + Matrix([ + [ 1, 3, 0], + [-2, -6, 0], + [ 3, 9, 6]]) + >>> M.rowspace() + [Matrix([[1, 3, 0]]), Matrix([[0, 0, 6]])] + """ + + reduced, pivots = M.echelon_form(simplify=simplify, with_pivots=True) + + return [reduced.row(i) for i in range(len(pivots))] + + +def _orthogonalize(cls, *vecs, normalize=False, rankcheck=False): + """Apply the Gram-Schmidt orthogonalization procedure + to vectors supplied in ``vecs``. + + Parameters + ========== + + vecs + vectors to be made orthogonal + + normalize : bool + If ``True``, return an orthonormal basis. + + rankcheck : bool + If ``True``, the computation does not stop when encountering + linearly dependent vectors. + + If ``False``, it will raise ``ValueError`` when any zero + or linearly dependent vectors are found. + + Returns + ======= + + list + List of orthogonal (or orthonormal) basis vectors. + + Examples + ======== + + >>> from sympy import I, Matrix + >>> v = [Matrix([1, I]), Matrix([1, -I])] + >>> Matrix.orthogonalize(*v) + [Matrix([ + [1], + [I]]), Matrix([ + [ 1], + [-I]])] + + See Also + ======== + + MatrixBase.QRdecomposition + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process + """ + from .decompositions import _QRdecomposition_optional + + if not vecs: + return [] + + all_row_vecs = (vecs[0].rows == 1) + + vecs = [x.vec() for x in vecs] + M = cls.hstack(*vecs) + Q, R = _QRdecomposition_optional(M, normalize=normalize) + + if rankcheck and Q.cols < len(vecs): + raise ValueError("GramSchmidt: vector set not linearly independent") + + ret = [] + for i in range(Q.cols): + if all_row_vecs: + col = cls(Q[:, i].T) + else: + col = cls(Q[:, i]) + ret.append(col) + return ret diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_commonmatrix.py b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_commonmatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..6735adc1a9d4f9934a55c7ee70b087a19d3a48b4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_commonmatrix.py @@ -0,0 +1,1266 @@ +# +# Code for testing deprecated matrix classes. New test code should not be added +# here. Instead, add it to test_matrixbase.py. +# +# This entire test module and the corresponding sympy/matrices/common.py +# module will be removed in a future release. +# +from sympy.testing.pytest import raises, XFAIL, warns_deprecated_sympy + +from sympy.assumptions import Q +from sympy.core.expr import Expr +from sympy.core.add import Add +from sympy.core.function import Function +from sympy.core.kind import NumberKind, UndefinedKind +from sympy.core.numbers import I, Integer, oo, pi, Rational +from sympy.core.singleton import S +from sympy.core.symbol import Symbol, symbols +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.matrices.exceptions import ShapeError, NonSquareMatrixError +from sympy.matrices.kind import MatrixKind +from sympy.matrices.common import ( + _MinimalMatrix, _CastableMatrix, MatrixShaping, MatrixProperties, + MatrixOperations, MatrixArithmetic, MatrixSpecial) +from sympy.matrices.matrices import MatrixCalculus +from sympy.matrices import (Matrix, diag, eye, + matrix_multiply_elementwise, ones, zeros, SparseMatrix, banded, + MutableDenseMatrix, MutableSparseMatrix, ImmutableDenseMatrix, + ImmutableSparseMatrix) +from sympy.polys.polytools import Poly +from sympy.utilities.iterables import flatten +from sympy.tensor.array.dense_ndim_array import ImmutableDenseNDimArray as Array + +from sympy.abc import x, y, z + + +def test_matrix_deprecated_isinstance(): + + # Test that e.g. isinstance(M, MatrixCommon) still gives True when M is a + # Matrix for each of the deprecated matrix classes. + + from sympy.matrices.common import ( + MatrixRequired, + MatrixShaping, + MatrixSpecial, + MatrixProperties, + MatrixOperations, + MatrixArithmetic, + MatrixCommon + ) + from sympy.matrices.matrices import ( + MatrixDeterminant, + MatrixReductions, + MatrixSubspaces, + MatrixEigen, + MatrixCalculus, + MatrixDeprecated + ) + from sympy import ( + Matrix, + ImmutableMatrix, + SparseMatrix, + ImmutableSparseMatrix + ) + all_mixins = ( + MatrixRequired, + MatrixShaping, + MatrixSpecial, + MatrixProperties, + MatrixOperations, + MatrixArithmetic, + MatrixCommon, + MatrixDeterminant, + MatrixReductions, + MatrixSubspaces, + MatrixEigen, + MatrixCalculus, + MatrixDeprecated + ) + all_matrices = ( + Matrix, + ImmutableMatrix, + SparseMatrix, + ImmutableSparseMatrix + ) + + Ms = [M([[1, 2], [3, 4]]) for M in all_matrices] + t = () + + for mixin in all_mixins: + for M in Ms: + with warns_deprecated_sympy(): + assert isinstance(M, mixin) is True + with warns_deprecated_sympy(): + assert isinstance(t, mixin) is False + + +# classes to test the deprecated matrix classes. We use warns_deprecated_sympy +# to suppress the deprecation warnings because subclassing the deprecated +# classes causes a warning to be raised. + +with warns_deprecated_sympy(): + class ShapingOnlyMatrix(_MinimalMatrix, _CastableMatrix, MatrixShaping): + pass + + +def eye_Shaping(n): + return ShapingOnlyMatrix(n, n, lambda i, j: int(i == j)) + + +def zeros_Shaping(n): + return ShapingOnlyMatrix(n, n, lambda i, j: 0) + + +with warns_deprecated_sympy(): + class PropertiesOnlyMatrix(_MinimalMatrix, _CastableMatrix, MatrixProperties): + pass + + +def eye_Properties(n): + return PropertiesOnlyMatrix(n, n, lambda i, j: int(i == j)) + + +def zeros_Properties(n): + return PropertiesOnlyMatrix(n, n, lambda i, j: 0) + + +with warns_deprecated_sympy(): + class OperationsOnlyMatrix(_MinimalMatrix, _CastableMatrix, MatrixOperations): + pass + + +def eye_Operations(n): + return OperationsOnlyMatrix(n, n, lambda i, j: int(i == j)) + + +def zeros_Operations(n): + return OperationsOnlyMatrix(n, n, lambda i, j: 0) + + +with warns_deprecated_sympy(): + class ArithmeticOnlyMatrix(_MinimalMatrix, _CastableMatrix, MatrixArithmetic): + pass + + +def eye_Arithmetic(n): + return ArithmeticOnlyMatrix(n, n, lambda i, j: int(i == j)) + + +def zeros_Arithmetic(n): + return ArithmeticOnlyMatrix(n, n, lambda i, j: 0) + + +with warns_deprecated_sympy(): + class SpecialOnlyMatrix(_MinimalMatrix, _CastableMatrix, MatrixSpecial): + pass + + +with warns_deprecated_sympy(): + class CalculusOnlyMatrix(_MinimalMatrix, _CastableMatrix, MatrixCalculus): + pass + + +def test__MinimalMatrix(): + x = _MinimalMatrix(2, 3, [1, 2, 3, 4, 5, 6]) + assert x.rows == 2 + assert x.cols == 3 + assert x[2] == 3 + assert x[1, 1] == 5 + assert list(x) == [1, 2, 3, 4, 5, 6] + assert list(x[1, :]) == [4, 5, 6] + assert list(x[:, 1]) == [2, 5] + assert list(x[:, :]) == list(x) + assert x[:, :] == x + assert _MinimalMatrix(x) == x + assert _MinimalMatrix([[1, 2, 3], [4, 5, 6]]) == x + assert _MinimalMatrix(([1, 2, 3], [4, 5, 6])) == x + assert _MinimalMatrix([(1, 2, 3), (4, 5, 6)]) == x + assert _MinimalMatrix(((1, 2, 3), (4, 5, 6))) == x + assert not (_MinimalMatrix([[1, 2], [3, 4], [5, 6]]) == x) + + +def test_kind(): + assert Matrix([[1, 2], [3, 4]]).kind == MatrixKind(NumberKind) + assert Matrix([[0, 0], [0, 0]]).kind == MatrixKind(NumberKind) + assert Matrix(0, 0, []).kind == MatrixKind(NumberKind) + assert Matrix([[x]]).kind == MatrixKind(NumberKind) + assert Matrix([[1, Matrix([[1]])]]).kind == MatrixKind(UndefinedKind) + assert SparseMatrix([[1]]).kind == MatrixKind(NumberKind) + assert SparseMatrix([[1, Matrix([[1]])]]).kind == MatrixKind(UndefinedKind) + + +# ShapingOnlyMatrix tests +def test_vec(): + m = ShapingOnlyMatrix(2, 2, [1, 3, 2, 4]) + m_vec = m.vec() + assert m_vec.cols == 1 + for i in range(4): + assert m_vec[i] == i + 1 + + +def test_todok(): + a, b, c, d = symbols('a:d') + m1 = MutableDenseMatrix([[a, b], [c, d]]) + m2 = ImmutableDenseMatrix([[a, b], [c, d]]) + m3 = MutableSparseMatrix([[a, b], [c, d]]) + m4 = ImmutableSparseMatrix([[a, b], [c, d]]) + assert m1.todok() == m2.todok() == m3.todok() == m4.todok() == \ + {(0, 0): a, (0, 1): b, (1, 0): c, (1, 1): d} + + +def test_tolist(): + lst = [[S.One, S.Half, x*y, S.Zero], [x, y, z, x**2], [y, -S.One, z*x, 3]] + flat_lst = [S.One, S.Half, x*y, S.Zero, x, y, z, x**2, y, -S.One, z*x, 3] + m = ShapingOnlyMatrix(3, 4, flat_lst) + assert m.tolist() == lst + +def test_todod(): + m = ShapingOnlyMatrix(3, 2, [[S.One, 0], [0, S.Half], [x, 0]]) + dict = {0: {0: S.One}, 1: {1: S.Half}, 2: {0: x}} + assert m.todod() == dict + +def test_row_col_del(): + e = ShapingOnlyMatrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9]) + raises(IndexError, lambda: e.row_del(5)) + raises(IndexError, lambda: e.row_del(-5)) + raises(IndexError, lambda: e.col_del(5)) + raises(IndexError, lambda: e.col_del(-5)) + + assert e.row_del(2) == e.row_del(-1) == Matrix([[1, 2, 3], [4, 5, 6]]) + assert e.col_del(2) == e.col_del(-1) == Matrix([[1, 2], [4, 5], [7, 8]]) + + assert e.row_del(1) == e.row_del(-2) == Matrix([[1, 2, 3], [7, 8, 9]]) + assert e.col_del(1) == e.col_del(-2) == Matrix([[1, 3], [4, 6], [7, 9]]) + + +def test_get_diag_blocks1(): + a = Matrix([[1, 2], [2, 3]]) + b = Matrix([[3, x], [y, 3]]) + c = Matrix([[3, x, 3], [y, 3, z], [x, y, z]]) + assert a.get_diag_blocks() == [a] + assert b.get_diag_blocks() == [b] + assert c.get_diag_blocks() == [c] + + +def test_get_diag_blocks2(): + a = Matrix([[1, 2], [2, 3]]) + b = Matrix([[3, x], [y, 3]]) + c = Matrix([[3, x, 3], [y, 3, z], [x, y, z]]) + A, B, C, D = diag(a, b, b), diag(a, b, c), diag(a, c, b), diag(c, c, b) + A = ShapingOnlyMatrix(A.rows, A.cols, A) + B = ShapingOnlyMatrix(B.rows, B.cols, B) + C = ShapingOnlyMatrix(C.rows, C.cols, C) + D = ShapingOnlyMatrix(D.rows, D.cols, D) + + assert A.get_diag_blocks() == [a, b, b] + assert B.get_diag_blocks() == [a, b, c] + assert C.get_diag_blocks() == [a, c, b] + assert D.get_diag_blocks() == [c, c, b] + + +def test_shape(): + m = ShapingOnlyMatrix(1, 2, [0, 0]) + assert m.shape == (1, 2) + + +def test_reshape(): + m0 = eye_Shaping(3) + assert m0.reshape(1, 9) == Matrix(1, 9, (1, 0, 0, 0, 1, 0, 0, 0, 1)) + m1 = ShapingOnlyMatrix(3, 4, lambda i, j: i + j) + assert m1.reshape( + 4, 3) == Matrix(((0, 1, 2), (3, 1, 2), (3, 4, 2), (3, 4, 5))) + assert m1.reshape(2, 6) == Matrix(((0, 1, 2, 3, 1, 2), (3, 4, 2, 3, 4, 5))) + + +def test_row_col(): + m = ShapingOnlyMatrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9]) + assert m.row(0) == Matrix(1, 3, [1, 2, 3]) + assert m.col(0) == Matrix(3, 1, [1, 4, 7]) + + +def test_row_join(): + assert eye_Shaping(3).row_join(Matrix([7, 7, 7])) == \ + Matrix([[1, 0, 0, 7], + [0, 1, 0, 7], + [0, 0, 1, 7]]) + + +def test_col_join(): + assert eye_Shaping(3).col_join(Matrix([[7, 7, 7]])) == \ + Matrix([[1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [7, 7, 7]]) + + +def test_row_insert(): + r4 = Matrix([[4, 4, 4]]) + for i in range(-4, 5): + l = [1, 0, 0] + l.insert(i, 4) + assert flatten(eye_Shaping(3).row_insert(i, r4).col(0).tolist()) == l + + +def test_col_insert(): + c4 = Matrix([4, 4, 4]) + for i in range(-4, 5): + l = [0, 0, 0] + l.insert(i, 4) + assert flatten(zeros_Shaping(3).col_insert(i, c4).row(0).tolist()) == l + # issue 13643 + assert eye_Shaping(6).col_insert(3, Matrix([[2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]])) == \ + Matrix([[1, 0, 0, 2, 2, 0, 0, 0], + [0, 1, 0, 2, 2, 0, 0, 0], + [0, 0, 1, 2, 2, 0, 0, 0], + [0, 0, 0, 2, 2, 1, 0, 0], + [0, 0, 0, 2, 2, 0, 1, 0], + [0, 0, 0, 2, 2, 0, 0, 1]]) + + +def test_extract(): + m = ShapingOnlyMatrix(4, 3, lambda i, j: i*3 + j) + assert m.extract([0, 1, 3], [0, 1]) == Matrix(3, 2, [0, 1, 3, 4, 9, 10]) + assert m.extract([0, 3], [0, 0, 2]) == Matrix(2, 3, [0, 0, 2, 9, 9, 11]) + assert m.extract(range(4), range(3)) == m + raises(IndexError, lambda: m.extract([4], [0])) + raises(IndexError, lambda: m.extract([0], [3])) + + +def test_hstack(): + m = ShapingOnlyMatrix(4, 3, lambda i, j: i*3 + j) + m2 = ShapingOnlyMatrix(3, 4, lambda i, j: i*3 + j) + assert m == m.hstack(m) + assert m.hstack(m, m, m) == ShapingOnlyMatrix.hstack(m, m, m) == Matrix([ + [0, 1, 2, 0, 1, 2, 0, 1, 2], + [3, 4, 5, 3, 4, 5, 3, 4, 5], + [6, 7, 8, 6, 7, 8, 6, 7, 8], + [9, 10, 11, 9, 10, 11, 9, 10, 11]]) + raises(ShapeError, lambda: m.hstack(m, m2)) + assert Matrix.hstack() == Matrix() + + # test regression #12938 + M1 = Matrix.zeros(0, 0) + M2 = Matrix.zeros(0, 1) + M3 = Matrix.zeros(0, 2) + M4 = Matrix.zeros(0, 3) + m = ShapingOnlyMatrix.hstack(M1, M2, M3, M4) + assert m.rows == 0 and m.cols == 6 + + +def test_vstack(): + m = ShapingOnlyMatrix(4, 3, lambda i, j: i*3 + j) + m2 = ShapingOnlyMatrix(3, 4, lambda i, j: i*3 + j) + assert m == m.vstack(m) + assert m.vstack(m, m, m) == ShapingOnlyMatrix.vstack(m, m, m) == Matrix([ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + [9, 10, 11], + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + [9, 10, 11], + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + [9, 10, 11]]) + raises(ShapeError, lambda: m.vstack(m, m2)) + assert Matrix.vstack() == Matrix() + + +# PropertiesOnlyMatrix tests +def test_atoms(): + m = PropertiesOnlyMatrix(2, 2, [1, 2, x, 1 - 1/x]) + assert m.atoms() == {S.One, S(2), S.NegativeOne, x} + assert m.atoms(Symbol) == {x} + + +def test_free_symbols(): + assert PropertiesOnlyMatrix([[x], [0]]).free_symbols == {x} + + +def test_has(): + A = PropertiesOnlyMatrix(((x, y), (2, 3))) + assert A.has(x) + assert not A.has(z) + assert A.has(Symbol) + + A = PropertiesOnlyMatrix(((2, y), (2, 3))) + assert not A.has(x) + + +def test_is_anti_symmetric(): + x = symbols('x') + assert PropertiesOnlyMatrix(2, 1, [1, 2]).is_anti_symmetric() is False + m = PropertiesOnlyMatrix(3, 3, [0, x**2 + 2*x + 1, y, -(x + 1)**2, 0, x*y, -y, -x*y, 0]) + assert m.is_anti_symmetric() is True + assert m.is_anti_symmetric(simplify=False) is False + assert m.is_anti_symmetric(simplify=lambda x: x) is False + + m = PropertiesOnlyMatrix(3, 3, [x.expand() for x in m]) + assert m.is_anti_symmetric(simplify=False) is True + m = PropertiesOnlyMatrix(3, 3, [x.expand() for x in [S.One] + list(m)[1:]]) + assert m.is_anti_symmetric() is False + + +def test_diagonal_symmetrical(): + m = PropertiesOnlyMatrix(2, 2, [0, 1, 1, 0]) + assert not m.is_diagonal() + assert m.is_symmetric() + assert m.is_symmetric(simplify=False) + + m = PropertiesOnlyMatrix(2, 2, [1, 0, 0, 1]) + assert m.is_diagonal() + + m = PropertiesOnlyMatrix(3, 3, diag(1, 2, 3)) + assert m.is_diagonal() + assert m.is_symmetric() + + m = PropertiesOnlyMatrix(3, 3, [1, 0, 0, 0, 2, 0, 0, 0, 3]) + assert m == diag(1, 2, 3) + + m = PropertiesOnlyMatrix(2, 3, zeros(2, 3)) + assert not m.is_symmetric() + assert m.is_diagonal() + + m = PropertiesOnlyMatrix(((5, 0), (0, 6), (0, 0))) + assert m.is_diagonal() + + m = PropertiesOnlyMatrix(((5, 0, 0), (0, 6, 0))) + assert m.is_diagonal() + + m = Matrix(3, 3, [1, x**2 + 2*x + 1, y, (x + 1)**2, 2, 0, y, 0, 3]) + assert m.is_symmetric() + assert not m.is_symmetric(simplify=False) + assert m.expand().is_symmetric(simplify=False) + + +def test_is_hermitian(): + a = PropertiesOnlyMatrix([[1, I], [-I, 1]]) + assert a.is_hermitian + a = PropertiesOnlyMatrix([[2*I, I], [-I, 1]]) + assert a.is_hermitian is False + a = PropertiesOnlyMatrix([[x, I], [-I, 1]]) + assert a.is_hermitian is None + a = PropertiesOnlyMatrix([[x, 1], [-I, 1]]) + assert a.is_hermitian is False + + +def test_is_Identity(): + assert eye_Properties(3).is_Identity + assert not PropertiesOnlyMatrix(zeros(3)).is_Identity + assert not PropertiesOnlyMatrix(ones(3)).is_Identity + # issue 6242 + assert not PropertiesOnlyMatrix([[1, 0, 0]]).is_Identity + + +def test_is_symbolic(): + a = PropertiesOnlyMatrix([[x, x], [x, x]]) + assert a.is_symbolic() is True + a = PropertiesOnlyMatrix([[1, 2, 3, 4], [5, 6, 7, 8]]) + assert a.is_symbolic() is False + a = PropertiesOnlyMatrix([[1, 2, 3, 4], [5, 6, x, 8]]) + assert a.is_symbolic() is True + a = PropertiesOnlyMatrix([[1, x, 3]]) + assert a.is_symbolic() is True + a = PropertiesOnlyMatrix([[1, 2, 3]]) + assert a.is_symbolic() is False + a = PropertiesOnlyMatrix([[1], [x], [3]]) + assert a.is_symbolic() is True + a = PropertiesOnlyMatrix([[1], [2], [3]]) + assert a.is_symbolic() is False + + +def test_is_upper(): + a = PropertiesOnlyMatrix([[1, 2, 3]]) + assert a.is_upper is True + a = PropertiesOnlyMatrix([[1], [2], [3]]) + assert a.is_upper is False + + +def test_is_lower(): + a = PropertiesOnlyMatrix([[1, 2, 3]]) + assert a.is_lower is False + a = PropertiesOnlyMatrix([[1], [2], [3]]) + assert a.is_lower is True + + +def test_is_square(): + m = PropertiesOnlyMatrix([[1], [1]]) + m2 = PropertiesOnlyMatrix([[2, 2], [2, 2]]) + assert not m.is_square + assert m2.is_square + + +def test_is_symmetric(): + m = PropertiesOnlyMatrix(2, 2, [0, 1, 1, 0]) + assert m.is_symmetric() + m = PropertiesOnlyMatrix(2, 2, [0, 1, 0, 1]) + assert not m.is_symmetric() + + +def test_is_hessenberg(): + A = PropertiesOnlyMatrix([[3, 4, 1], [2, 4, 5], [0, 1, 2]]) + assert A.is_upper_hessenberg + A = PropertiesOnlyMatrix(3, 3, [3, 2, 0, 4, 4, 1, 1, 5, 2]) + assert A.is_lower_hessenberg + A = PropertiesOnlyMatrix(3, 3, [3, 2, -1, 4, 4, 1, 1, 5, 2]) + assert A.is_lower_hessenberg is False + assert A.is_upper_hessenberg is False + + A = PropertiesOnlyMatrix([[3, 4, 1], [2, 4, 5], [3, 1, 2]]) + assert not A.is_upper_hessenberg + + +def test_is_zero(): + assert PropertiesOnlyMatrix(0, 0, []).is_zero_matrix + assert PropertiesOnlyMatrix([[0, 0], [0, 0]]).is_zero_matrix + assert PropertiesOnlyMatrix(zeros(3, 4)).is_zero_matrix + assert not PropertiesOnlyMatrix(eye(3)).is_zero_matrix + assert PropertiesOnlyMatrix([[x, 0], [0, 0]]).is_zero_matrix == None + assert PropertiesOnlyMatrix([[x, 1], [0, 0]]).is_zero_matrix == False + a = Symbol('a', nonzero=True) + assert PropertiesOnlyMatrix([[a, 0], [0, 0]]).is_zero_matrix == False + + +def test_values(): + assert set(PropertiesOnlyMatrix(2, 2, [0, 1, 2, 3] + ).values()) == {1, 2, 3} + x = Symbol('x', real=True) + assert set(PropertiesOnlyMatrix(2, 2, [x, 0, 0, 1] + ).values()) == {x, 1} + + +# OperationsOnlyMatrix tests +def test_applyfunc(): + m0 = OperationsOnlyMatrix(eye(3)) + assert m0.applyfunc(lambda x: 2*x) == eye(3)*2 + assert m0.applyfunc(lambda x: 0) == zeros(3) + assert m0.applyfunc(lambda x: 1) == ones(3) + + +def test_adjoint(): + dat = [[0, I], [1, 0]] + ans = OperationsOnlyMatrix([[0, 1], [-I, 0]]) + assert ans.adjoint() == Matrix(dat) + + +def test_as_real_imag(): + m1 = OperationsOnlyMatrix(2, 2, [1, 2, 3, 4]) + m3 = OperationsOnlyMatrix(2, 2, + [1 + S.ImaginaryUnit, 2 + 2*S.ImaginaryUnit, + 3 + 3*S.ImaginaryUnit, 4 + 4*S.ImaginaryUnit]) + + a, b = m3.as_real_imag() + assert a == m1 + assert b == m1 + + +def test_conjugate(): + M = OperationsOnlyMatrix([[0, I, 5], + [1, 2, 0]]) + + assert M.T == Matrix([[0, 1], + [I, 2], + [5, 0]]) + + assert M.C == Matrix([[0, -I, 5], + [1, 2, 0]]) + assert M.C == M.conjugate() + + assert M.H == M.T.C + assert M.H == Matrix([[ 0, 1], + [-I, 2], + [ 5, 0]]) + + +def test_doit(): + a = OperationsOnlyMatrix([[Add(x, x, evaluate=False)]]) + assert a[0] != 2*x + assert a.doit() == Matrix([[2*x]]) + + +def test_evalf(): + a = OperationsOnlyMatrix(2, 1, [sqrt(5), 6]) + assert all(a.evalf()[i] == a[i].evalf() for i in range(2)) + assert all(a.evalf(2)[i] == a[i].evalf(2) for i in range(2)) + assert all(a.n(2)[i] == a[i].n(2) for i in range(2)) + + +def test_expand(): + m0 = OperationsOnlyMatrix([[x*(x + y), 2], [((x + y)*y)*x, x*(y + x*(x + y))]]) + # Test if expand() returns a matrix + m1 = m0.expand() + assert m1 == Matrix( + [[x*y + x**2, 2], [x*y**2 + y*x**2, x*y + y*x**2 + x**3]]) + + a = Symbol('a', real=True) + + assert OperationsOnlyMatrix(1, 1, [exp(I*a)]).expand(complex=True) == \ + Matrix([cos(a) + I*sin(a)]) + + +def test_refine(): + m0 = OperationsOnlyMatrix([[Abs(x)**2, sqrt(x**2)], + [sqrt(x**2)*Abs(y)**2, sqrt(y**2)*Abs(x)**2]]) + m1 = m0.refine(Q.real(x) & Q.real(y)) + assert m1 == Matrix([[x**2, Abs(x)], [y**2*Abs(x), x**2*Abs(y)]]) + + m1 = m0.refine(Q.positive(x) & Q.positive(y)) + assert m1 == Matrix([[x**2, x], [x*y**2, x**2*y]]) + + m1 = m0.refine(Q.negative(x) & Q.negative(y)) + assert m1 == Matrix([[x**2, -x], [-x*y**2, -x**2*y]]) + + +def test_replace(): + F, G = symbols('F, G', cls=Function) + K = OperationsOnlyMatrix(2, 2, lambda i, j: G(i+j)) + M = OperationsOnlyMatrix(2, 2, lambda i, j: F(i+j)) + N = M.replace(F, G) + assert N == K + + +def test_replace_map(): + F, G = symbols('F, G', cls=Function) + K = OperationsOnlyMatrix(2, 2, [(G(0), {F(0): G(0)}), (G(1), {F(1): G(1)}), (G(1), {F(1) \ + : G(1)}), (G(2), {F(2): G(2)})]) + M = OperationsOnlyMatrix(2, 2, lambda i, j: F(i+j)) + N = M.replace(F, G, True) + assert N == K + + +def test_rot90(): + A = Matrix([[1, 2], [3, 4]]) + assert A == A.rot90(0) == A.rot90(4) + assert A.rot90(2) == A.rot90(-2) == A.rot90(6) == Matrix(((4, 3), (2, 1))) + assert A.rot90(3) == A.rot90(-1) == A.rot90(7) == Matrix(((2, 4), (1, 3))) + assert A.rot90() == A.rot90(-7) == A.rot90(-3) == Matrix(((3, 1), (4, 2))) + +def test_simplify(): + n = Symbol('n') + f = Function('f') + + M = OperationsOnlyMatrix([[ 1/x + 1/y, (x + x*y) / x ], + [ (f(x) + y*f(x))/f(x), 2 * (1/n - cos(n * pi)/n) / pi ]]) + assert M.simplify() == Matrix([[ (x + y)/(x * y), 1 + y ], + [ 1 + y, 2*((1 - 1*cos(pi*n))/(pi*n)) ]]) + eq = (1 + x)**2 + M = OperationsOnlyMatrix([[eq]]) + assert M.simplify() == Matrix([[eq]]) + assert M.simplify(ratio=oo) == Matrix([[eq.simplify(ratio=oo)]]) + + # https://github.com/sympy/sympy/issues/19353 + m = Matrix([[30, 2], [3, 4]]) + assert (1/(m.trace())).simplify() == Rational(1, 34) + + +def test_subs(): + assert OperationsOnlyMatrix([[1, x], [x, 4]]).subs(x, 5) == Matrix([[1, 5], [5, 4]]) + assert OperationsOnlyMatrix([[x, 2], [x + y, 4]]).subs([[x, -1], [y, -2]]) == \ + Matrix([[-1, 2], [-3, 4]]) + assert OperationsOnlyMatrix([[x, 2], [x + y, 4]]).subs([(x, -1), (y, -2)]) == \ + Matrix([[-1, 2], [-3, 4]]) + assert OperationsOnlyMatrix([[x, 2], [x + y, 4]]).subs({x: -1, y: -2}) == \ + Matrix([[-1, 2], [-3, 4]]) + assert OperationsOnlyMatrix([[x*y]]).subs({x: y - 1, y: x - 1}, simultaneous=True) == \ + Matrix([[(x - 1)*(y - 1)]]) + + +def test_trace(): + M = OperationsOnlyMatrix([[1, 0, 0], + [0, 5, 0], + [0, 0, 8]]) + assert M.trace() == 14 + + +def test_xreplace(): + assert OperationsOnlyMatrix([[1, x], [x, 4]]).xreplace({x: 5}) == \ + Matrix([[1, 5], [5, 4]]) + assert OperationsOnlyMatrix([[x, 2], [x + y, 4]]).xreplace({x: -1, y: -2}) == \ + Matrix([[-1, 2], [-3, 4]]) + + +def test_permute(): + a = OperationsOnlyMatrix(3, 4, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) + + raises(IndexError, lambda: a.permute([[0, 5]])) + raises(ValueError, lambda: a.permute(Symbol('x'))) + b = a.permute_rows([[0, 2], [0, 1]]) + assert a.permute([[0, 2], [0, 1]]) == b == Matrix([ + [5, 6, 7, 8], + [9, 10, 11, 12], + [1, 2, 3, 4]]) + + b = a.permute_cols([[0, 2], [0, 1]]) + assert a.permute([[0, 2], [0, 1]], orientation='cols') == b ==\ + Matrix([ + [ 2, 3, 1, 4], + [ 6, 7, 5, 8], + [10, 11, 9, 12]]) + + b = a.permute_cols([[0, 2], [0, 1]], direction='backward') + assert a.permute([[0, 2], [0, 1]], orientation='cols', direction='backward') == b ==\ + Matrix([ + [ 3, 1, 2, 4], + [ 7, 5, 6, 8], + [11, 9, 10, 12]]) + + assert a.permute([1, 2, 0, 3]) == Matrix([ + [5, 6, 7, 8], + [9, 10, 11, 12], + [1, 2, 3, 4]]) + + from sympy.combinatorics import Permutation + assert a.permute(Permutation([1, 2, 0, 3])) == Matrix([ + [5, 6, 7, 8], + [9, 10, 11, 12], + [1, 2, 3, 4]]) + +def test_upper_triangular(): + + A = OperationsOnlyMatrix([ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] + ]) + + R = A.upper_triangular(2) + assert R == OperationsOnlyMatrix([ + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0], + [0, 0, 0, 0] + ]) + + R = A.upper_triangular(-2) + assert R == OperationsOnlyMatrix([ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [0, 1, 1, 1] + ]) + + R = A.upper_triangular() + assert R == OperationsOnlyMatrix([ + [1, 1, 1, 1], + [0, 1, 1, 1], + [0, 0, 1, 1], + [0, 0, 0, 1] + ]) + +def test_lower_triangular(): + A = OperationsOnlyMatrix([ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] + ]) + + L = A.lower_triangular() + assert L == ArithmeticOnlyMatrix([ + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1]]) + + L = A.lower_triangular(2) + assert L == ArithmeticOnlyMatrix([ + [1, 1, 1, 0], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] + ]) + + L = A.lower_triangular(-2) + assert L == ArithmeticOnlyMatrix([ + [0, 0, 0, 0], + [0, 0, 0, 0], + [1, 0, 0, 0], + [1, 1, 0, 0] + ]) + + +# ArithmeticOnlyMatrix tests +def test_abs(): + m = ArithmeticOnlyMatrix([[1, -2], [x, y]]) + assert abs(m) == ArithmeticOnlyMatrix([[1, 2], [Abs(x), Abs(y)]]) + + +def test_add(): + m = ArithmeticOnlyMatrix([[1, 2, 3], [x, y, x], [2*y, -50, z*x]]) + assert m + m == ArithmeticOnlyMatrix([[2, 4, 6], [2*x, 2*y, 2*x], [4*y, -100, 2*z*x]]) + n = ArithmeticOnlyMatrix(1, 2, [1, 2]) + raises(ShapeError, lambda: m + n) + + +def test_multiplication(): + a = ArithmeticOnlyMatrix(( + (1, 2), + (3, 1), + (0, 6), + )) + + b = ArithmeticOnlyMatrix(( + (1, 2), + (3, 0), + )) + + raises(ShapeError, lambda: b*a) + raises(TypeError, lambda: a*{}) + + c = a*b + assert c[0, 0] == 7 + assert c[0, 1] == 2 + assert c[1, 0] == 6 + assert c[1, 1] == 6 + assert c[2, 0] == 18 + assert c[2, 1] == 0 + + try: + eval('c = a @ b') + except SyntaxError: + pass + else: + assert c[0, 0] == 7 + assert c[0, 1] == 2 + assert c[1, 0] == 6 + assert c[1, 1] == 6 + assert c[2, 0] == 18 + assert c[2, 1] == 0 + + h = a.multiply_elementwise(c) + assert h == matrix_multiply_elementwise(a, c) + assert h[0, 0] == 7 + assert h[0, 1] == 4 + assert h[1, 0] == 18 + assert h[1, 1] == 6 + assert h[2, 0] == 0 + assert h[2, 1] == 0 + raises(ShapeError, lambda: a.multiply_elementwise(b)) + + c = b * Symbol("x") + assert isinstance(c, ArithmeticOnlyMatrix) + assert c[0, 0] == x + assert c[0, 1] == 2*x + assert c[1, 0] == 3*x + assert c[1, 1] == 0 + + c2 = x * b + assert c == c2 + + c = 5 * b + assert isinstance(c, ArithmeticOnlyMatrix) + assert c[0, 0] == 5 + assert c[0, 1] == 2*5 + assert c[1, 0] == 3*5 + assert c[1, 1] == 0 + + try: + eval('c = 5 @ b') + except SyntaxError: + pass + else: + assert isinstance(c, ArithmeticOnlyMatrix) + assert c[0, 0] == 5 + assert c[0, 1] == 2*5 + assert c[1, 0] == 3*5 + assert c[1, 1] == 0 + + # https://github.com/sympy/sympy/issues/22353 + A = Matrix(ones(3, 1)) + _h = -Rational(1, 2) + B = Matrix([_h, _h, _h]) + assert A.multiply_elementwise(B) == Matrix([ + [_h], + [_h], + [_h]]) + + +def test_matmul(): + a = Matrix([[1, 2], [3, 4]]) + + assert a.__matmul__(2) == NotImplemented + + assert a.__rmatmul__(2) == NotImplemented + + #This is done this way because @ is only supported in Python 3.5+ + #To check 2@a case + try: + eval('2 @ a') + except SyntaxError: + pass + except TypeError: #TypeError is raised in case of NotImplemented is returned + pass + + #Check a@2 case + try: + eval('a @ 2') + except SyntaxError: + pass + except TypeError: #TypeError is raised in case of NotImplemented is returned + pass + + +def test_non_matmul(): + """ + Test that if explicitly specified as non-matrix, mul reverts + to scalar multiplication. + """ + class foo(Expr): + is_Matrix=False + is_MatrixLike=False + shape = (1, 1) + + A = Matrix([[1, 2], [3, 4]]) + b = foo() + assert b*A == Matrix([[b, 2*b], [3*b, 4*b]]) + assert A*b == Matrix([[b, 2*b], [3*b, 4*b]]) + + +def test_power(): + raises(NonSquareMatrixError, lambda: Matrix((1, 2))**2) + + A = ArithmeticOnlyMatrix([[2, 3], [4, 5]]) + assert (A**5)[:] == (6140, 8097, 10796, 14237) + A = ArithmeticOnlyMatrix([[2, 1, 3], [4, 2, 4], [6, 12, 1]]) + assert (A**3)[:] == (290, 262, 251, 448, 440, 368, 702, 954, 433) + assert A**0 == eye(3) + assert A**1 == A + assert (ArithmeticOnlyMatrix([[2]]) ** 100)[0, 0] == 2**100 + assert ArithmeticOnlyMatrix([[1, 2], [3, 4]])**Integer(2) == ArithmeticOnlyMatrix([[7, 10], [15, 22]]) + A = Matrix([[1,2],[4,5]]) + assert A.pow(20, method='cayley') == A.pow(20, method='multiply') + +def test_neg(): + n = ArithmeticOnlyMatrix(1, 2, [1, 2]) + assert -n == ArithmeticOnlyMatrix(1, 2, [-1, -2]) + + +def test_sub(): + n = ArithmeticOnlyMatrix(1, 2, [1, 2]) + assert n - n == ArithmeticOnlyMatrix(1, 2, [0, 0]) + + +def test_div(): + n = ArithmeticOnlyMatrix(1, 2, [1, 2]) + assert n/2 == ArithmeticOnlyMatrix(1, 2, [S.Half, S(2)/2]) + +# SpecialOnlyMatrix tests +def test_eye(): + assert list(SpecialOnlyMatrix.eye(2, 2)) == [1, 0, 0, 1] + assert list(SpecialOnlyMatrix.eye(2)) == [1, 0, 0, 1] + assert type(SpecialOnlyMatrix.eye(2)) == SpecialOnlyMatrix + assert type(SpecialOnlyMatrix.eye(2, cls=Matrix)) == Matrix + + +def test_ones(): + assert list(SpecialOnlyMatrix.ones(2, 2)) == [1, 1, 1, 1] + assert list(SpecialOnlyMatrix.ones(2)) == [1, 1, 1, 1] + assert SpecialOnlyMatrix.ones(2, 3) == Matrix([[1, 1, 1], [1, 1, 1]]) + assert type(SpecialOnlyMatrix.ones(2)) == SpecialOnlyMatrix + assert type(SpecialOnlyMatrix.ones(2, cls=Matrix)) == Matrix + + +def test_zeros(): + assert list(SpecialOnlyMatrix.zeros(2, 2)) == [0, 0, 0, 0] + assert list(SpecialOnlyMatrix.zeros(2)) == [0, 0, 0, 0] + assert SpecialOnlyMatrix.zeros(2, 3) == Matrix([[0, 0, 0], [0, 0, 0]]) + assert type(SpecialOnlyMatrix.zeros(2)) == SpecialOnlyMatrix + assert type(SpecialOnlyMatrix.zeros(2, cls=Matrix)) == Matrix + + +def test_diag_make(): + diag = SpecialOnlyMatrix.diag + a = Matrix([[1, 2], [2, 3]]) + b = Matrix([[3, x], [y, 3]]) + c = Matrix([[3, x, 3], [y, 3, z], [x, y, z]]) + assert diag(a, b, b) == Matrix([ + [1, 2, 0, 0, 0, 0], + [2, 3, 0, 0, 0, 0], + [0, 0, 3, x, 0, 0], + [0, 0, y, 3, 0, 0], + [0, 0, 0, 0, 3, x], + [0, 0, 0, 0, y, 3], + ]) + assert diag(a, b, c) == Matrix([ + [1, 2, 0, 0, 0, 0, 0], + [2, 3, 0, 0, 0, 0, 0], + [0, 0, 3, x, 0, 0, 0], + [0, 0, y, 3, 0, 0, 0], + [0, 0, 0, 0, 3, x, 3], + [0, 0, 0, 0, y, 3, z], + [0, 0, 0, 0, x, y, z], + ]) + assert diag(a, c, b) == Matrix([ + [1, 2, 0, 0, 0, 0, 0], + [2, 3, 0, 0, 0, 0, 0], + [0, 0, 3, x, 3, 0, 0], + [0, 0, y, 3, z, 0, 0], + [0, 0, x, y, z, 0, 0], + [0, 0, 0, 0, 0, 3, x], + [0, 0, 0, 0, 0, y, 3], + ]) + a = Matrix([x, y, z]) + b = Matrix([[1, 2], [3, 4]]) + c = Matrix([[5, 6]]) + # this "wandering diagonal" is what makes this + # a block diagonal where each block is independent + # of the others + assert diag(a, 7, b, c) == Matrix([ + [x, 0, 0, 0, 0, 0], + [y, 0, 0, 0, 0, 0], + [z, 0, 0, 0, 0, 0], + [0, 7, 0, 0, 0, 0], + [0, 0, 1, 2, 0, 0], + [0, 0, 3, 4, 0, 0], + [0, 0, 0, 0, 5, 6]]) + raises(ValueError, lambda: diag(a, 7, b, c, rows=5)) + assert diag(1) == Matrix([[1]]) + assert diag(1, rows=2) == Matrix([[1, 0], [0, 0]]) + assert diag(1, cols=2) == Matrix([[1, 0], [0, 0]]) + assert diag(1, rows=3, cols=2) == Matrix([[1, 0], [0, 0], [0, 0]]) + assert diag(*[2, 3]) == Matrix([ + [2, 0], + [0, 3]]) + assert diag(Matrix([2, 3])) == Matrix([ + [2], + [3]]) + assert diag([1, [2, 3], 4], unpack=False) == \ + diag([[1], [2, 3], [4]], unpack=False) == Matrix([ + [1, 0], + [2, 3], + [4, 0]]) + assert type(diag(1)) == SpecialOnlyMatrix + assert type(diag(1, cls=Matrix)) == Matrix + assert Matrix.diag([1, 2, 3]) == Matrix.diag(1, 2, 3) + assert Matrix.diag([1, 2, 3], unpack=False).shape == (3, 1) + assert Matrix.diag([[1, 2, 3]]).shape == (3, 1) + assert Matrix.diag([[1, 2, 3]], unpack=False).shape == (1, 3) + assert Matrix.diag([[[1, 2, 3]]]).shape == (1, 3) + # kerning can be used to move the starting point + assert Matrix.diag(ones(0, 2), 1, 2) == Matrix([ + [0, 0, 1, 0], + [0, 0, 0, 2]]) + assert Matrix.diag(ones(2, 0), 1, 2) == Matrix([ + [0, 0], + [0, 0], + [1, 0], + [0, 2]]) + + +def test_diagonal(): + m = Matrix(3, 3, range(9)) + d = m.diagonal() + assert d == m.diagonal(0) + assert tuple(d) == (0, 4, 8) + assert tuple(m.diagonal(1)) == (1, 5) + assert tuple(m.diagonal(-1)) == (3, 7) + assert tuple(m.diagonal(2)) == (2,) + assert type(m.diagonal()) == type(m) + s = SparseMatrix(3, 3, {(1, 1): 1}) + assert type(s.diagonal()) == type(s) + assert type(m) != type(s) + raises(ValueError, lambda: m.diagonal(3)) + raises(ValueError, lambda: m.diagonal(-3)) + raises(ValueError, lambda: m.diagonal(pi)) + M = ones(2, 3) + assert banded({i: list(M.diagonal(i)) + for i in range(1-M.rows, M.cols)}) == M + + +def test_jordan_block(): + assert SpecialOnlyMatrix.jordan_block(3, 2) == SpecialOnlyMatrix.jordan_block(3, eigenvalue=2) \ + == SpecialOnlyMatrix.jordan_block(size=3, eigenvalue=2) \ + == SpecialOnlyMatrix.jordan_block(3, 2, band='upper') \ + == SpecialOnlyMatrix.jordan_block( + size=3, eigenval=2, eigenvalue=2) \ + == Matrix([ + [2, 1, 0], + [0, 2, 1], + [0, 0, 2]]) + + assert SpecialOnlyMatrix.jordan_block(3, 2, band='lower') == Matrix([ + [2, 0, 0], + [1, 2, 0], + [0, 1, 2]]) + # missing eigenvalue + raises(ValueError, lambda: SpecialOnlyMatrix.jordan_block(2)) + # non-integral size + raises(ValueError, lambda: SpecialOnlyMatrix.jordan_block(3.5, 2)) + # size not specified + raises(ValueError, lambda: SpecialOnlyMatrix.jordan_block(eigenvalue=2)) + # inconsistent eigenvalue + raises(ValueError, + lambda: SpecialOnlyMatrix.jordan_block( + eigenvalue=2, eigenval=4)) + + # Using alias keyword + assert SpecialOnlyMatrix.jordan_block(size=3, eigenvalue=2) == \ + SpecialOnlyMatrix.jordan_block(size=3, eigenval=2) + + +def test_orthogonalize(): + m = Matrix([[1, 2], [3, 4]]) + assert m.orthogonalize(Matrix([[2], [1]])) == [Matrix([[2], [1]])] + assert m.orthogonalize(Matrix([[2], [1]]), normalize=True) == \ + [Matrix([[2*sqrt(5)/5], [sqrt(5)/5]])] + assert m.orthogonalize(Matrix([[1], [2]]), Matrix([[-1], [4]])) == \ + [Matrix([[1], [2]]), Matrix([[Rational(-12, 5)], [Rational(6, 5)]])] + assert m.orthogonalize(Matrix([[0], [0]]), Matrix([[-1], [4]])) == \ + [Matrix([[-1], [4]])] + assert m.orthogonalize(Matrix([[0], [0]])) == [] + + n = Matrix([[9, 1, 9], [3, 6, 10], [8, 5, 2]]) + vecs = [Matrix([[-5], [1]]), Matrix([[-5], [2]]), Matrix([[-5], [-2]])] + assert n.orthogonalize(*vecs) == \ + [Matrix([[-5], [1]]), Matrix([[Rational(5, 26)], [Rational(25, 26)]])] + + vecs = [Matrix([0, 0, 0]), Matrix([1, 2, 3]), Matrix([1, 4, 5])] + raises(ValueError, lambda: Matrix.orthogonalize(*vecs, rankcheck=True)) + + vecs = [Matrix([1, 2, 3]), Matrix([4, 5, 6]), Matrix([7, 8, 9])] + raises(ValueError, lambda: Matrix.orthogonalize(*vecs, rankcheck=True)) + +def test_wilkinson(): + + wminus, wplus = Matrix.wilkinson(1) + assert wminus == Matrix([ + [-1, 1, 0], + [1, 0, 1], + [0, 1, 1]]) + assert wplus == Matrix([ + [1, 1, 0], + [1, 0, 1], + [0, 1, 1]]) + + wminus, wplus = Matrix.wilkinson(3) + assert wminus == Matrix([ + [-3, 1, 0, 0, 0, 0, 0], + [1, -2, 1, 0, 0, 0, 0], + [0, 1, -1, 1, 0, 0, 0], + [0, 0, 1, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 1, 2, 1], + + [0, 0, 0, 0, 0, 1, 3]]) + + assert wplus == Matrix([ + [3, 1, 0, 0, 0, 0, 0], + [1, 2, 1, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0], + [0, 0, 1, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 1, 2, 1], + [0, 0, 0, 0, 0, 1, 3]]) + + +# CalculusOnlyMatrix tests +@XFAIL +def test_diff(): + x, y = symbols('x y') + m = CalculusOnlyMatrix(2, 1, [x, y]) + # TODO: currently not working as ``_MinimalMatrix`` cannot be sympified: + assert m.diff(x) == Matrix(2, 1, [1, 0]) + + +def test_integrate(): + x, y = symbols('x y') + m = CalculusOnlyMatrix(2, 1, [x, y]) + assert m.integrate(x) == Matrix(2, 1, [x**2/2, y*x]) + + +def test_jacobian2(): + rho, phi = symbols("rho,phi") + X = CalculusOnlyMatrix(3, 1, [rho*cos(phi), rho*sin(phi), rho**2]) + Y = CalculusOnlyMatrix(2, 1, [rho, phi]) + J = Matrix([ + [cos(phi), -rho*sin(phi)], + [sin(phi), rho*cos(phi)], + [ 2*rho, 0], + ]) + assert X.jacobian(Y) == J + + m = CalculusOnlyMatrix(2, 2, [1, 2, 3, 4]) + m2 = CalculusOnlyMatrix(4, 1, [1, 2, 3, 4]) + raises(TypeError, lambda: m.jacobian(Matrix([1, 2]))) + raises(TypeError, lambda: m2.jacobian(m)) + + +def test_limit(): + x, y = symbols('x y') + m = CalculusOnlyMatrix(2, 1, [1/x, y]) + assert m.limit(x, 5) == Matrix(2, 1, [Rational(1, 5), y]) + + +def test_issue_13774(): + M = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + v = [1, 1, 1] + raises(TypeError, lambda: M*v) + raises(TypeError, lambda: v*M) + +def test_companion(): + x = Symbol('x') + y = Symbol('y') + raises(ValueError, lambda: Matrix.companion(1)) + raises(ValueError, lambda: Matrix.companion(Poly([1], x))) + raises(ValueError, lambda: Matrix.companion(Poly([2, 1], x))) + raises(ValueError, lambda: Matrix.companion(Poly(x*y, [x, y]))) + + c0, c1, c2 = symbols('c0:3') + assert Matrix.companion(Poly([1, c0], x)) == Matrix([-c0]) + assert Matrix.companion(Poly([1, c1, c0], x)) == \ + Matrix([[0, -c0], [1, -c1]]) + assert Matrix.companion(Poly([1, c2, c1, c0], x)) == \ + Matrix([[0, 0, -c0], [1, 0, -c1], [0, 1, -c2]]) + +def test_issue_10589(): + x, y, z = symbols("x, y z") + M1 = Matrix([x, y, z]) + M1 = M1.subs(zip([x, y, z], [1, 2, 3])) + assert M1 == Matrix([[1], [2], [3]]) + + M2 = Matrix([[x, x, x, x, x], [x, x, x, x, x], [x, x, x, x, x]]) + M2 = M2.subs(zip([x], [1])) + assert M2 == Matrix([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]) + +def test_rmul_pr19860(): + class Foo(ImmutableDenseMatrix): + _op_priority = MutableDenseMatrix._op_priority + 0.01 + + a = Matrix(2, 2, [1, 2, 3, 4]) + b = Foo(2, 2, [1, 2, 3, 4]) + + # This would throw a RecursionError: maximum recursion depth + # since b always has higher priority even after a.as_mutable() + c = a*b + + assert isinstance(c, Foo) + assert c == Matrix([[7, 10], [15, 22]]) + + +def test_issue_18956(): + A = Array([[1, 2], [3, 4]]) + B = Matrix([[1,2],[3,4]]) + raises(TypeError, lambda: B + A) + raises(TypeError, lambda: A + B) + + +def test__eq__(): + class My(object): + def __iter__(self): + yield 1 + yield 2 + return + def __getitem__(self, i): + return list(self)[i] + a = Matrix(2, 1, [1, 2]) + assert a != My() + class My_sympy(My): + def _sympy_(self): + return Matrix(self) + assert a == My_sympy() diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_decompositions.py b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_decompositions.py new file mode 100644 index 0000000000000000000000000000000000000000..d169ec3a8846fed786981e62d932fd860b6d4951 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_decompositions.py @@ -0,0 +1,474 @@ +from sympy.core.function import expand_mul +from sympy.core.numbers import I, Rational +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.complexes import Abs +from sympy.simplify.simplify import simplify +from sympy.matrices.exceptions import NonSquareMatrixError +from sympy.matrices import Matrix, zeros, eye, SparseMatrix +from sympy.abc import x, y, z +from sympy.testing.pytest import raises, slow +from sympy.testing.matrices import allclose + + +def test_LUdecomp(): + testmat = Matrix([[0, 2, 5, 3], + [3, 3, 7, 4], + [8, 4, 0, 2], + [-2, 6, 3, 4]]) + L, U, p = testmat.LUdecomposition() + assert L.is_lower + assert U.is_upper + assert (L*U).permute_rows(p, 'backward') - testmat == zeros(4) + + testmat = Matrix([[6, -2, 7, 4], + [0, 3, 6, 7], + [1, -2, 7, 4], + [-9, 2, 6, 3]]) + L, U, p = testmat.LUdecomposition() + assert L.is_lower + assert U.is_upper + assert (L*U).permute_rows(p, 'backward') - testmat == zeros(4) + + # non-square + testmat = Matrix([[1, 2, 3], + [4, 5, 6], + [7, 8, 9], + [10, 11, 12]]) + L, U, p = testmat.LUdecomposition(rankcheck=False) + assert L.is_lower + assert U.is_upper + assert (L*U).permute_rows(p, 'backward') - testmat == zeros(4, 3) + + # square and singular + testmat = Matrix([[1, 2, 3], + [2, 4, 6], + [4, 5, 6]]) + L, U, p = testmat.LUdecomposition(rankcheck=False) + assert L.is_lower + assert U.is_upper + assert (L*U).permute_rows(p, 'backward') - testmat == zeros(3) + + M = Matrix(((1, x, 1), (2, y, 0), (y, 0, z))) + L, U, p = M.LUdecomposition() + assert L.is_lower + assert U.is_upper + assert (L*U).permute_rows(p, 'backward') - M == zeros(3) + + mL = Matrix(( + (1, 0, 0), + (2, 3, 0), + )) + assert mL.is_lower is True + assert mL.is_upper is False + mU = Matrix(( + (1, 2, 3), + (0, 4, 5), + )) + assert mU.is_lower is False + assert mU.is_upper is True + + # test FF LUdecomp + M = Matrix([[1, 3, 3], + [3, 2, 6], + [3, 2, 2]]) + P, L, Dee, U = M.LUdecompositionFF() + assert P*M == L*Dee.inv()*U + + M = Matrix([[1, 2, 3, 4], + [3, -1, 2, 3], + [3, 1, 3, -2], + [6, -1, 0, 2]]) + P, L, Dee, U = M.LUdecompositionFF() + assert P*M == L*Dee.inv()*U + + M = Matrix([[0, 0, 1], + [2, 3, 0], + [3, 1, 4]]) + P, L, Dee, U = M.LUdecompositionFF() + assert P*M == L*Dee.inv()*U + + # issue 15794 + M = Matrix( + [[1, 2, 3], + [4, 5, 6], + [7, 8, 9]] + ) + raises(ValueError, lambda : M.LUdecomposition_Simple(rankcheck=True)) + +def test_singular_value_decompositionD(): + A = Matrix([[1, 2], [2, 1]]) + U, S, V = A.singular_value_decomposition() + assert U * S * V.T == A + assert U.T * U == eye(U.cols) + assert V.T * V == eye(V.cols) + + B = Matrix([[1, 2]]) + U, S, V = B.singular_value_decomposition() + + assert U * S * V.T == B + assert U.T * U == eye(U.cols) + assert V.T * V == eye(V.cols) + + C = Matrix([ + [1, 0, 0, 0, 2], + [0, 0, 3, 0, 0], + [0, 0, 0, 0, 0], + [0, 2, 0, 0, 0], + ]) + + U, S, V = C.singular_value_decomposition() + + assert U * S * V.T == C + assert U.T * U == eye(U.cols) + assert V.T * V == eye(V.cols) + + D = Matrix([[Rational(1, 3), sqrt(2)], [0, Rational(1, 4)]]) + U, S, V = D.singular_value_decomposition() + assert simplify(U.T * U) == eye(U.cols) + assert simplify(V.T * V) == eye(V.cols) + assert simplify(U * S * V.T) == D + + +def test_QR(): + A = Matrix([[1, 2], [2, 3]]) + Q, S = A.QRdecomposition() + R = Rational + assert Q == Matrix([ + [ 5**R(-1, 2), (R(2)/5)*(R(1)/5)**R(-1, 2)], + [2*5**R(-1, 2), (-R(1)/5)*(R(1)/5)**R(-1, 2)]]) + assert S == Matrix([[5**R(1, 2), 8*5**R(-1, 2)], [0, (R(1)/5)**R(1, 2)]]) + assert Q*S == A + assert Q.T * Q == eye(2) + + A = Matrix([[1, 1, 1], [1, 1, 3], [2, 3, 4]]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix([[12, 0, -51], [6, 0, 167], [-4, 0, 24]]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + x = Symbol('x') + A = Matrix([x]) + Q, R = A.QRdecomposition() + assert Q == Matrix([x / Abs(x)]) + assert R == Matrix([Abs(x)]) + + A = Matrix([[x, 0], [0, x]]) + Q, R = A.QRdecomposition() + assert Q == x / Abs(x) * Matrix([[1, 0], [0, 1]]) + assert R == Abs(x) * Matrix([[1, 0], [0, 1]]) + + +def test_QR_non_square(): + # Narrow (cols < rows) matrices + A = Matrix([[9, 0, 26], [12, 0, -7], [0, 4, 4], [0, -3, -3]]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix([[1, -1, 4], [1, 4, -2], [1, 4, 2], [1, -1, 0]]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix(2, 1, [1, 2]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + # Wide (cols > rows) matrices + A = Matrix([[1, 2, 3], [4, 5, 6]]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix([[1, 2, 3, 4], [1, 4, 9, 16], [1, 8, 27, 64]]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix(1, 2, [1, 2]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + +def test_QR_trivial(): + # Rank deficient matrices + A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix([[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix([[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]).T + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + # Zero rank matrices + A = Matrix([[0, 0, 0]]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix([[0, 0, 0]]).T + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix([[0, 0, 0], [0, 0, 0]]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix([[0, 0, 0], [0, 0, 0]]).T + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + # Rank deficient matrices with zero norm from beginning columns + A = Matrix([[0, 0, 0], [1, 2, 3]]).T + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix([[0, 0, 0, 0], [1, 2, 3, 4], [0, 0, 0, 0]]).T + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix([[0, 0, 0, 0], [1, 2, 3, 4], [0, 0, 0, 0], [2, 4, 6, 8]]).T + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix([[0, 0, 0], [0, 0, 0], [0, 0, 0], [1, 2, 3]]).T + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + +def test_QR_float(): + A = Matrix([[1, 1], [1, 1.01]]) + Q, R = A.QRdecomposition() + assert allclose(Q * R, A) + assert allclose(Q * Q.T, Matrix.eye(2)) + assert allclose(Q.T * Q, Matrix.eye(2)) + + A = Matrix([[1, 1], [1, 1.001]]) + Q, R = A.QRdecomposition() + assert allclose(Q * R, A) + assert allclose(Q * Q.T, Matrix.eye(2)) + assert allclose(Q.T * Q, Matrix.eye(2)) + + +def test_LUdecomposition_Simple_iszerofunc(): + # Test if callable passed to matrices.LUdecomposition_Simple() as iszerofunc keyword argument is used inside + # matrices.LUdecomposition_Simple() + magic_string = "I got passed in!" + def goofyiszero(value): + raise ValueError(magic_string) + + try: + lu, p = Matrix([[1, 0], [0, 1]]).LUdecomposition_Simple(iszerofunc=goofyiszero) + except ValueError as err: + assert magic_string == err.args[0] + return + + assert False + +def test_LUdecomposition_iszerofunc(): + # Test if callable passed to matrices.LUdecomposition() as iszerofunc keyword argument is used inside + # matrices.LUdecomposition_Simple() + magic_string = "I got passed in!" + def goofyiszero(value): + raise ValueError(magic_string) + + try: + l, u, p = Matrix([[1, 0], [0, 1]]).LUdecomposition(iszerofunc=goofyiszero) + except ValueError as err: + assert magic_string == err.args[0] + return + + assert False + +def test_LDLdecomposition(): + raises(NonSquareMatrixError, lambda: Matrix((1, 2)).LDLdecomposition()) + raises(ValueError, lambda: Matrix(((1, 2), (3, 4))).LDLdecomposition()) + raises(ValueError, lambda: Matrix(((5 + I, 0), (0, 1))).LDLdecomposition()) + raises(ValueError, lambda: Matrix(((1, 5), (5, 1))).LDLdecomposition()) + raises(ValueError, lambda: Matrix(((1, 2), (3, 4))).LDLdecomposition(hermitian=False)) + A = Matrix(((1, 5), (5, 1))) + L, D = A.LDLdecomposition(hermitian=False) + assert L * D * L.T == A + A = Matrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11))) + L, D = A.LDLdecomposition() + assert L * D * L.T == A + assert L.is_lower + assert L == Matrix([[1, 0, 0], [ Rational(3, 5), 1, 0], [Rational(-1, 5), Rational(1, 3), 1]]) + assert D.is_diagonal() + assert D == Matrix([[25, 0, 0], [0, 9, 0], [0, 0, 9]]) + A = Matrix(((4, -2*I, 2 + 2*I), (2*I, 2, -1 + I), (2 - 2*I, -1 - I, 11))) + L, D = A.LDLdecomposition() + assert expand_mul(L * D * L.H) == A + assert L.expand() == Matrix([[1, 0, 0], [I/2, 1, 0], [S.Half - I/2, 0, 1]]) + assert D.expand() == Matrix(((4, 0, 0), (0, 1, 0), (0, 0, 9))) + + raises(NonSquareMatrixError, lambda: SparseMatrix((1, 2)).LDLdecomposition()) + raises(ValueError, lambda: SparseMatrix(((1, 2), (3, 4))).LDLdecomposition()) + raises(ValueError, lambda: SparseMatrix(((5 + I, 0), (0, 1))).LDLdecomposition()) + raises(ValueError, lambda: SparseMatrix(((1, 5), (5, 1))).LDLdecomposition()) + raises(ValueError, lambda: SparseMatrix(((1, 2), (3, 4))).LDLdecomposition(hermitian=False)) + A = SparseMatrix(((1, 5), (5, 1))) + L, D = A.LDLdecomposition(hermitian=False) + assert L * D * L.T == A + A = SparseMatrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11))) + L, D = A.LDLdecomposition() + assert L * D * L.T == A + assert L.is_lower + assert L == Matrix([[1, 0, 0], [ Rational(3, 5), 1, 0], [Rational(-1, 5), Rational(1, 3), 1]]) + assert D.is_diagonal() + assert D == Matrix([[25, 0, 0], [0, 9, 0], [0, 0, 9]]) + A = SparseMatrix(((4, -2*I, 2 + 2*I), (2*I, 2, -1 + I), (2 - 2*I, -1 - I, 11))) + L, D = A.LDLdecomposition() + assert expand_mul(L * D * L.H) == A + assert L == Matrix(((1, 0, 0), (I/2, 1, 0), (S.Half - I/2, 0, 1))) + assert D == Matrix(((4, 0, 0), (0, 1, 0), (0, 0, 9))) + +def test_pinv_succeeds_with_rank_decomposition_method(): + # Test rank decomposition method of pseudoinverse succeeding + As = [Matrix([ + [61, 89, 55, 20, 71, 0], + [62, 96, 85, 85, 16, 0], + [69, 56, 17, 4, 54, 0], + [10, 54, 91, 41, 71, 0], + [ 7, 30, 10, 48, 90, 0], + [0,0,0,0,0,0]])] + for A in As: + A_pinv = A.pinv(method="RD") + AAp = A * A_pinv + ApA = A_pinv * A + assert simplify(AAp * A) == A + assert simplify(ApA * A_pinv) == A_pinv + assert AAp.H == AAp + assert ApA.H == ApA + +def test_rank_decomposition(): + a = Matrix(0, 0, []) + c, f = a.rank_decomposition() + assert f.is_echelon + assert c.cols == f.rows == a.rank() + assert c * f == a + + a = Matrix(1, 1, [5]) + c, f = a.rank_decomposition() + assert f.is_echelon + assert c.cols == f.rows == a.rank() + assert c * f == a + + a = Matrix(3, 3, [1, 2, 3, 1, 2, 3, 1, 2, 3]) + c, f = a.rank_decomposition() + assert f.is_echelon + assert c.cols == f.rows == a.rank() + assert c * f == a + + a = Matrix([ + [0, 0, 1, 2, 2, -5, 3], + [-1, 5, 2, 2, 1, -7, 5], + [0, 0, -2, -3, -3, 8, -5], + [-1, 5, 0, -1, -2, 1, 0]]) + c, f = a.rank_decomposition() + assert f.is_echelon + assert c.cols == f.rows == a.rank() + assert c * f == a + + +@slow +def test_upper_hessenberg_decomposition(): + A = Matrix([ + [1, 0, sqrt(3)], + [sqrt(2), Rational(1, 2), 2], + [1, Rational(1, 4), 3], + ]) + H, P = A.upper_hessenberg_decomposition() + assert simplify(P * P.H) == eye(P.cols) + assert simplify(P.H * P) == eye(P.cols) + assert H.is_upper_hessenberg + assert (simplify(P * H * P.H)) == A + + + B = Matrix([ + [1, 2, 10], + [8, 2, 5], + [3, 12, 34], + ]) + H, P = B.upper_hessenberg_decomposition() + assert simplify(P * P.H) == eye(P.cols) + assert simplify(P.H * P) == eye(P.cols) + assert H.is_upper_hessenberg + assert simplify(P * H * P.H) == B + + C = Matrix([ + [1, sqrt(2), 2, 3], + [0, 5, 3, 4], + [1, 1, 4, sqrt(5)], + [0, 2, 2, 3] + ]) + + H, P = C.upper_hessenberg_decomposition() + assert simplify(P * P.H) == eye(P.cols) + assert simplify(P.H * P) == eye(P.cols) + assert H.is_upper_hessenberg + assert simplify(P * H * P.H) == C + + D = Matrix([ + [1, 2, 3], + [-3, 5, 6], + [4, -8, 9], + ]) + H, P = D.upper_hessenberg_decomposition() + assert simplify(P * P.H) == eye(P.cols) + assert simplify(P.H * P) == eye(P.cols) + assert H.is_upper_hessenberg + assert simplify(P * H * P.H) == D + + E = Matrix([ + [1, 0, 0, 0], + [0, 1, 0, 0], + [1, 1, 0, 1], + [1, 1, 1, 0] + ]) + + H, P = E.upper_hessenberg_decomposition() + assert simplify(P * P.H) == eye(P.cols) + assert simplify(P.H * P) == eye(P.cols) + assert H.is_upper_hessenberg + assert simplify(P * H * P.H) == E diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_determinant.py b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_determinant.py new file mode 100644 index 0000000000000000000000000000000000000000..82b42ccf67efa4757bf270782bdf1d65e0efa306 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_determinant.py @@ -0,0 +1,280 @@ +import random +import pytest +from sympy.core.numbers import I +from sympy.core.numbers import Rational +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.polys.polytools import Poly +from sympy.matrices import Matrix, eye, ones +from sympy.abc import x, y, z +from sympy.testing.pytest import raises +from sympy.matrices.exceptions import NonSquareMatrixError +from sympy.functions.combinatorial.factorials import factorial, subfactorial + + +@pytest.mark.parametrize("method", [ + # Evaluating these directly because they are never reached via M.det() + Matrix._eval_det_bareiss, Matrix._eval_det_berkowitz, + Matrix._eval_det_bird, Matrix._eval_det_laplace, Matrix._eval_det_lu +]) +@pytest.mark.parametrize("M, sol", [ + (Matrix(), 1), + (Matrix([[0]]), 0), + (Matrix([[5]]), 5), +]) +def test_eval_determinant(method, M, sol): + assert method(M) == sol + + +@pytest.mark.parametrize("method", [ + "domain-ge", "bareiss", "berkowitz", "bird", "laplace", "lu"]) +@pytest.mark.parametrize("M, sol", [ + (Matrix(( (-3, 2), + ( 8, -5) )), -1), + (Matrix(( (x, 1), + (y, 2*y) )), 2*x*y - y), + (Matrix(( (1, 1, 1), + (1, 2, 3), + (1, 3, 6) )), 1), + (Matrix(( ( 3, -2, 0, 5), + (-2, 1, -2, 2), + ( 0, -2, 5, 0), + ( 5, 0, 3, 4) )), -289), + (Matrix(( ( 1, 2, 3, 4), + ( 5, 6, 7, 8), + ( 9, 10, 11, 12), + (13, 14, 15, 16) )), 0), + (Matrix(( (3, 2, 0, 0, 0), + (0, 3, 2, 0, 0), + (0, 0, 3, 2, 0), + (0, 0, 0, 3, 2), + (2, 0, 0, 0, 3) )), 275), + (Matrix(( ( 3, 0, 0, 0), + (-2, 1, 0, 0), + ( 0, -2, 5, 0), + ( 5, 0, 3, 4) )), 60), + (Matrix(( ( 1, 0, 0, 0), + ( 5, 0, 0, 0), + ( 9, 10, 11, 0), + (13, 14, 15, 16) )), 0), + (Matrix(( (3, 2, 0, 0, 0), + (0, 3, 2, 0, 0), + (0, 0, 3, 2, 0), + (0, 0, 0, 3, 2), + (0, 0, 0, 0, 3) )), 243), + (Matrix(( (1, 0, 1, 2, 12), + (2, 0, 1, 1, 4), + (2, 1, 1, -1, 3), + (3, 2, -1, 1, 8), + (1, 1, 1, 0, 6) )), -55), + (Matrix(( (-5, 2, 3, 4, 5), + ( 1, -4, 3, 4, 5), + ( 1, 2, -3, 4, 5), + ( 1, 2, 3, -2, 5), + ( 1, 2, 3, 4, -1) )), 11664), + (Matrix(( ( 2, 7, -1, 3, 2), + ( 0, 0, 1, 0, 1), + (-2, 0, 7, 0, 2), + (-3, -2, 4, 5, 3), + ( 1, 0, 0, 0, 1) )), 123), + (Matrix(( (x, y, z), + (1, 0, 0), + (y, z, x) )), z**2 - x*y), +]) +def test_determinant(method, M, sol): + assert M.det(method=method) == sol + + +def test_issue_13835(): + a = symbols('a') + M = lambda n: Matrix([[i + a*j for i in range(n)] + for j in range(n)]) + assert M(5).det() == 0 + assert M(6).det() == 0 + assert M(7).det() == 0 + + +def test_issue_14517(): + M = Matrix([ + [ 0, 10*I, 10*I, 0], + [10*I, 0, 0, 10*I], + [10*I, 0, 5 + 2*I, 10*I], + [ 0, 10*I, 10*I, 5 + 2*I]]) + ev = M.eigenvals() + # test one random eigenvalue, the computation is a little slow + test_ev = random.choice(list(ev.keys())) + assert (M - test_ev*eye(4)).det() == 0 + + +@pytest.mark.parametrize("method", [ + "bareis", "det_lu", "det_LU", "Bareis", "BAREISS", "BERKOWITZ", "LU"]) +@pytest.mark.parametrize("M, sol", [ + (Matrix(( ( 3, -2, 0, 5), + (-2, 1, -2, 2), + ( 0, -2, 5, 0), + ( 5, 0, 3, 4) )), -289), + (Matrix(( (-5, 2, 3, 4, 5), + ( 1, -4, 3, 4, 5), + ( 1, 2, -3, 4, 5), + ( 1, 2, 3, -2, 5), + ( 1, 2, 3, 4, -1) )), 11664), +]) +def test_legacy_det(method, M, sol): + # Minimal support for legacy keys for 'method' in det() + # Partially copied from test_determinant() + assert M.det(method=method) == sol + + +def eye_Determinant(n): + return Matrix(n, n, lambda i, j: int(i == j)) + +def zeros_Determinant(n): + return Matrix(n, n, lambda i, j: 0) + +def test_det(): + a = Matrix(2, 3, [1, 2, 3, 4, 5, 6]) + raises(NonSquareMatrixError, lambda: a.det()) + + z = zeros_Determinant(2) + ey = eye_Determinant(2) + assert z.det() == 0 + assert ey.det() == 1 + + x = Symbol('x') + a = Matrix(0, 0, []) + b = Matrix(1, 1, [5]) + c = Matrix(2, 2, [1, 2, 3, 4]) + d = Matrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 8]) + e = Matrix(4, 4, + [x, 1, 2, 3, 4, 5, 6, 7, 2, 9, 10, 11, 12, 13, 14, 14]) + from sympy.abc import i, j, k, l, m, n + f = Matrix(3, 3, [i, l, m, 0, j, n, 0, 0, k]) + g = Matrix(3, 3, [i, 0, 0, l, j, 0, m, n, k]) + h = Matrix(3, 3, [x**3, 0, 0, i, x**-1, 0, j, k, x**-2]) + # the method keyword for `det` doesn't kick in until 4x4 matrices, + # so there is no need to test all methods on smaller ones + + assert a.det() == 1 + assert b.det() == 5 + assert c.det() == -2 + assert d.det() == 3 + assert e.det() == 4*x - 24 + assert e.det(method="domain-ge") == 4*x - 24 + assert e.det(method='bareiss') == 4*x - 24 + assert e.det(method='berkowitz') == 4*x - 24 + assert f.det() == i*j*k + assert g.det() == i*j*k + assert h.det() == 1 + raises(ValueError, lambda: e.det(iszerofunc="test")) + +def test_permanent(): + M = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + assert M.per() == 450 + for i in range(1, 12): + assert ones(i, i).per() == ones(i, i).T.per() == factorial(i) + assert (ones(i, i)-eye(i)).per() == (ones(i, i)-eye(i)).T.per() == subfactorial(i) + + a1, a2, a3, a4, a5 = symbols('a_1 a_2 a_3 a_4 a_5') + M = Matrix([a1, a2, a3, a4, a5]) + assert M.per() == M.T.per() == a1 + a2 + a3 + a4 + a5 + +def test_adjugate(): + x = Symbol('x') + e = Matrix(4, 4, + [x, 1, 2, 3, 4, 5, 6, 7, 2, 9, 10, 11, 12, 13, 14, 14]) + + adj = Matrix([ + [ 4, -8, 4, 0], + [ 76, -14*x - 68, 14*x - 8, -4*x + 24], + [-122, 17*x + 142, -21*x + 4, 8*x - 48], + [ 48, -4*x - 72, 8*x, -4*x + 24]]) + assert e.adjugate() == adj + assert e.adjugate(method='bareiss') == adj + assert e.adjugate(method='berkowitz') == adj + assert e.adjugate(method='bird') == adj + assert e.adjugate(method='laplace') == adj + + a = Matrix(2, 3, [1, 2, 3, 4, 5, 6]) + raises(NonSquareMatrixError, lambda: a.adjugate()) + +def test_util(): + R = Rational + + v1 = Matrix(1, 3, [1, 2, 3]) + v2 = Matrix(1, 3, [3, 4, 5]) + assert v1.norm() == sqrt(14) + assert v1.project(v2) == Matrix(1, 3, [R(39)/25, R(52)/25, R(13)/5]) + assert Matrix.zeros(1, 2) == Matrix(1, 2, [0, 0]) + assert ones(1, 2) == Matrix(1, 2, [1, 1]) + assert v1.copy() == v1 + # cofactor + assert eye(3) == eye(3).cofactor_matrix() + test = Matrix([[1, 3, 2], [2, 6, 3], [2, 3, 6]]) + assert test.cofactor_matrix() == \ + Matrix([[27, -6, -6], [-12, 2, 3], [-3, 1, 0]]) + test = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + assert test.cofactor_matrix() == \ + Matrix([[-3, 6, -3], [6, -12, 6], [-3, 6, -3]]) + +def test_cofactor_and_minors(): + x = Symbol('x') + e = Matrix(4, 4, + [x, 1, 2, 3, 4, 5, 6, 7, 2, 9, 10, 11, 12, 13, 14, 14]) + + m = Matrix([ + [ x, 1, 3], + [ 2, 9, 11], + [12, 13, 14]]) + cm = Matrix([ + [ 4, 76, -122, 48], + [-8, -14*x - 68, 17*x + 142, -4*x - 72], + [ 4, 14*x - 8, -21*x + 4, 8*x], + [ 0, -4*x + 24, 8*x - 48, -4*x + 24]]) + sub = Matrix([ + [x, 1, 2], + [4, 5, 6], + [2, 9, 10]]) + + assert e.minor_submatrix(1, 2) == m + assert e.minor_submatrix(-1, -1) == sub + assert e.minor(1, 2) == -17*x - 142 + assert e.cofactor(1, 2) == 17*x + 142 + assert e.cofactor_matrix() == cm + assert e.cofactor_matrix(method="bareiss") == cm + assert e.cofactor_matrix(method="berkowitz") == cm + assert e.cofactor_matrix(method="bird") == cm + assert e.cofactor_matrix(method="laplace") == cm + + raises(ValueError, lambda: e.cofactor(4, 5)) + raises(ValueError, lambda: e.minor(4, 5)) + raises(ValueError, lambda: e.minor_submatrix(4, 5)) + + a = Matrix(2, 3, [1, 2, 3, 4, 5, 6]) + assert a.minor_submatrix(0, 0) == Matrix([[5, 6]]) + + raises(ValueError, lambda: + Matrix(0, 0, []).minor_submatrix(0, 0)) + raises(NonSquareMatrixError, lambda: a.cofactor(0, 0)) + raises(NonSquareMatrixError, lambda: a.minor(0, 0)) + raises(NonSquareMatrixError, lambda: a.cofactor_matrix()) + +def test_charpoly(): + x, y = Symbol('x'), Symbol('y') + z, t = Symbol('z'), Symbol('t') + + from sympy.abc import a,b,c + + m = Matrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9]) + + assert eye_Determinant(3).charpoly(x) == Poly((x - 1)**3, x) + assert eye_Determinant(3).charpoly(y) == Poly((y - 1)**3, y) + assert m.charpoly() == Poly(x**3 - 15*x**2 - 18*x, x) + raises(NonSquareMatrixError, lambda: Matrix([[1], [2]]).charpoly()) + n = Matrix(4, 4, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + assert n.charpoly() == Poly(x**4, x) + + n = Matrix(4, 4, [45, 0, 0, 0, 0, 23, 0, 0, 0, 0, 87, 0, 0, 0, 0, 12]) + assert n.charpoly() == Poly(x**4 - 167*x**3 + 8811*x**2 - 173457*x + 1080540, x) + + n = Matrix(3, 3, [x, 0, 0, a, y, 0, b, c, z]) + assert n.charpoly() == Poly(t**3 - (x+y+z)*t**2 + t*(x*y+y*z+x*z) - x*y*z, t) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_domains.py b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_domains.py new file mode 100644 index 0000000000000000000000000000000000000000..26a54b8879a5c65f3a01b4886d223c08309e733d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_domains.py @@ -0,0 +1,113 @@ +# Test Matrix/DomainMatrix interaction. + + +from sympy import GF, ZZ, QQ, EXRAW +from sympy.polys.matrices import DomainMatrix, DM + +from sympy import ( + Matrix, + MutableMatrix, + ImmutableMatrix, + SparseMatrix, + MutableDenseMatrix, + ImmutableDenseMatrix, + MutableSparseMatrix, + ImmutableSparseMatrix, +) +from sympy import symbols, S, sqrt + +from sympy.testing.pytest import raises + + +x, y = symbols('x y') + + +MATRIX_TYPES = ( + Matrix, + MutableMatrix, + ImmutableMatrix, + SparseMatrix, + MutableDenseMatrix, + ImmutableDenseMatrix, + MutableSparseMatrix, + ImmutableSparseMatrix, +) +IMMUTABLE = ( + ImmutableMatrix, + ImmutableDenseMatrix, + ImmutableSparseMatrix, +) + + +def DMs(items, domain): + return DM(items, domain).to_sparse() + + +def test_Matrix_rep_domain(): + + for Mat in MATRIX_TYPES: + + M = Mat([[1, 2], [3, 4]]) + assert M._rep == DMs([[1, 2], [3, 4]], ZZ) + assert (M / 2)._rep == DMs([[(1,2), 1], [(3,2), 2]], QQ) + if not isinstance(M, IMMUTABLE): + M[0, 0] = x + assert M._rep == DMs([[x, 2], [3, 4]], EXRAW) + + M = Mat([[S(1)/2, 2], [3, 4]]) + assert M._rep == DMs([[(1,2), 2], [3, 4]], QQ) + if not isinstance(M, IMMUTABLE): + M[0, 0] = x + assert M._rep == DMs([[x, 2], [3, 4]], EXRAW) + + dM = DMs([[1, 2], [3, 4]], ZZ) + assert Mat._fromrep(dM)._rep == dM + + # XXX: This is not intended. Perhaps it should be coerced to EXRAW? + # The private _fromrep method is never called like this but perhaps it + # should be guarded. + # + # It is not clear how to integrate domains other than ZZ, QQ and EXRAW with + # the rest of Matrix or if the public type for this needs to be something + # different from Matrix somehow. + K = QQ.algebraic_field(sqrt(2)) + dM = DM([[1, 2], [3, 4]], K) + assert Mat._fromrep(dM)._rep.domain == K + + +def test_Matrix_to_DM(): + + M = Matrix([[1, 2], [3, 4]]) + assert M.to_DM() == DMs([[1, 2], [3, 4]], ZZ) + assert M.to_DM() is not M._rep + assert M.to_DM(field=True) == DMs([[1, 2], [3, 4]], QQ) + assert M.to_DM(domain=QQ) == DMs([[1, 2], [3, 4]], QQ) + assert M.to_DM(domain=QQ[x]) == DMs([[1, 2], [3, 4]], QQ[x]) + assert M.to_DM(domain=GF(3)) == DMs([[1, 2], [0, 1]], GF(3)) + + M = Matrix([[1, 2], [3, 4]]) + M[0, 0] = x + assert M._rep.domain == EXRAW + M[0, 0] = 1 + assert M.to_DM() == DMs([[1, 2], [3, 4]], ZZ) + + M = Matrix([[S(1)/2, 2], [3, 4]]) + assert M.to_DM() == DMs([[QQ(1,2), 2], [3, 4]], QQ) + + M = Matrix([[x, 2], [3, 4]]) + assert M.to_DM() == DMs([[x, 2], [3, 4]], ZZ[x]) + assert M.to_DM(field=True) == DMs([[x, 2], [3, 4]], ZZ.frac_field(x)) + + M = Matrix([[1/x, 2], [3, 4]]) + assert M.to_DM() == DMs([[1/x, 2], [3, 4]], ZZ.frac_field(x)) + + M = Matrix([[1, sqrt(2)], [3, 4]]) + K = QQ.algebraic_field(sqrt(2)) + sqrt2 = K.from_sympy(sqrt(2)) # XXX: Maybe K(sqrt(2)) should work + M_K = DomainMatrix([[K(1), sqrt2], [K(3), K(4)]], (2, 2), K) + assert M.to_DM() == DMs([[1, sqrt(2)], [3, 4]], EXRAW) + assert M.to_DM(extension=True) == M_K.to_sparse() + + # Options cannot be used with the domain parameter + M = Matrix([[1, 2], [3, 4]]) + raises(TypeError, lambda: M.to_DM(domain=QQ, field=True)) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_eigen.py b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_eigen.py new file mode 100644 index 0000000000000000000000000000000000000000..fcf96325519879e0683d29e2ddc32db7bf83baa4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_eigen.py @@ -0,0 +1,712 @@ +from sympy.core.evalf import N +from sympy.core.numbers import (Float, I, Rational) +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices import eye, Matrix +from sympy.core.singleton import S +from sympy.testing.pytest import raises, XFAIL +from sympy.matrices.exceptions import NonSquareMatrixError, MatrixError +from sympy.matrices.expressions.fourier import DFT +from sympy.simplify.simplify import simplify +from sympy.matrices.immutable import ImmutableMatrix +from sympy.testing.pytest import slow +from sympy.testing.matrices import allclose + + +def test_eigen(): + R = Rational + M = Matrix.eye(3) + assert M.eigenvals(multiple=False) == {S.One: 3} + assert M.eigenvals(multiple=True) == [1, 1, 1] + + assert M.eigenvects() == ( + [(1, 3, [Matrix([1, 0, 0]), + Matrix([0, 1, 0]), + Matrix([0, 0, 1])])]) + + assert M.left_eigenvects() == ( + [(1, 3, [Matrix([[1, 0, 0]]), + Matrix([[0, 1, 0]]), + Matrix([[0, 0, 1]])])]) + + M = Matrix([[0, 1, 1], + [1, 0, 0], + [1, 1, 1]]) + + assert M.eigenvals() == {2*S.One: 1, -S.One: 1, S.Zero: 1} + + assert M.eigenvects() == ( + [ + (-1, 1, [Matrix([-1, 1, 0])]), + ( 0, 1, [Matrix([0, -1, 1])]), + ( 2, 1, [Matrix([R(2, 3), R(1, 3), 1])]) + ]) + + assert M.left_eigenvects() == ( + [ + (-1, 1, [Matrix([[-2, 1, 1]])]), + (0, 1, [Matrix([[-1, -1, 1]])]), + (2, 1, [Matrix([[1, 1, 1]])]) + ]) + + a = Symbol('a') + M = Matrix([[a, 0], + [0, 1]]) + + assert M.eigenvals() == {a: 1, S.One: 1} + + M = Matrix([[1, -1], + [1, 3]]) + assert M.eigenvects() == ([(2, 2, [Matrix(2, 1, [-1, 1])])]) + assert M.left_eigenvects() == ([(2, 2, [Matrix([[1, 1]])])]) + + M = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + a = R(15, 2) + b = 3*33**R(1, 2) + c = R(13, 2) + d = (R(33, 8) + 3*b/8) + e = (R(33, 8) - 3*b/8) + + def NS(e, n): + return str(N(e, n)) + r = [ + (a - b/2, 1, [Matrix([(12 + 24/(c - b/2))/((c - b/2)*e) + 3/(c - b/2), + (6 + 12/(c - b/2))/e, 1])]), + ( 0, 1, [Matrix([1, -2, 1])]), + (a + b/2, 1, [Matrix([(12 + 24/(c + b/2))/((c + b/2)*d) + 3/(c + b/2), + (6 + 12/(c + b/2))/d, 1])]), + ] + r1 = [(NS(r[i][0], 2), NS(r[i][1], 2), + [NS(j, 2) for j in r[i][2][0]]) for i in range(len(r))] + r = M.eigenvects() + r2 = [(NS(r[i][0], 2), NS(r[i][1], 2), + [NS(j, 2) for j in r[i][2][0]]) for i in range(len(r))] + assert sorted(r1) == sorted(r2) + + eps = Symbol('eps', real=True) + + M = Matrix([[abs(eps), I*eps ], + [-I*eps, abs(eps) ]]) + + assert M.eigenvects() == ( + [ + ( 0, 1, [Matrix([[-I*eps/abs(eps)], [1]])]), + ( 2*abs(eps), 1, [ Matrix([[I*eps/abs(eps)], [1]]) ] ), + ]) + + assert M.left_eigenvects() == ( + [ + (0, 1, [Matrix([[I*eps/Abs(eps), 1]])]), + (2*Abs(eps), 1, [Matrix([[-I*eps/Abs(eps), 1]])]) + ]) + + M = Matrix(3, 3, [1, 2, 0, 0, 3, 0, 2, -4, 2]) + M._eigenvects = M.eigenvects(simplify=False) + assert max(i.q for i in M._eigenvects[0][2][0]) > 1 + M._eigenvects = M.eigenvects(simplify=True) + assert max(i.q for i in M._eigenvects[0][2][0]) == 1 + + M = Matrix([[Rational(1, 4), 1], [1, 1]]) + assert M.eigenvects() == [ + (Rational(5, 8) - sqrt(73)/8, 1, [Matrix([[-sqrt(73)/8 - Rational(3, 8)], [1]])]), + (Rational(5, 8) + sqrt(73)/8, 1, [Matrix([[Rational(-3, 8) + sqrt(73)/8], [1]])])] + + # issue 10719 + assert Matrix([]).eigenvals() == {} + assert Matrix([]).eigenvals(multiple=True) == [] + assert Matrix([]).eigenvects() == [] + + # issue 15119 + raises(NonSquareMatrixError, + lambda: Matrix([[1, 2], [0, 4], [0, 0]]).eigenvals()) + raises(NonSquareMatrixError, + lambda: Matrix([[1, 0], [3, 4], [5, 6]]).eigenvals()) + raises(NonSquareMatrixError, + lambda: Matrix([[1, 2, 3], [0, 5, 6]]).eigenvals()) + raises(NonSquareMatrixError, + lambda: Matrix([[1, 0, 0], [4, 5, 0]]).eigenvals()) + raises(NonSquareMatrixError, + lambda: Matrix([[1, 2, 3], [0, 5, 6]]).eigenvals( + error_when_incomplete = False)) + raises(NonSquareMatrixError, + lambda: Matrix([[1, 0, 0], [4, 5, 0]]).eigenvals( + error_when_incomplete = False)) + + m = Matrix([[1, 2], [3, 4]]) + assert isinstance(m.eigenvals(simplify=True, multiple=False), dict) + assert isinstance(m.eigenvals(simplify=True, multiple=True), list) + assert isinstance(m.eigenvals(simplify=lambda x: x, multiple=False), dict) + assert isinstance(m.eigenvals(simplify=lambda x: x, multiple=True), list) + + +def test_float_eigenvals(): + m = Matrix([[1, .6, .6], [.6, .9, .9], [.9, .6, .6]]) + evals = [ + Rational(5, 4) - sqrt(385)/20, + sqrt(385)/20 + Rational(5, 4), + S.Zero] + + n_evals = m.eigenvals(rational=True, multiple=True) + n_evals = sorted(n_evals) + s_evals = [x.evalf() for x in evals] + s_evals = sorted(s_evals) + + for x, y in zip(n_evals, s_evals): + assert abs(x-y) < 10**-9 + + +@XFAIL +def test_eigen_vects(): + m = Matrix(2, 2, [1, 0, 0, I]) + raises(NotImplementedError, lambda: m.is_diagonalizable(True)) + # !!! bug because of eigenvects() or roots(x**2 + (-1 - I)*x + I, x) + # see issue 5292 + assert not m.is_diagonalizable(True) + raises(MatrixError, lambda: m.diagonalize(True)) + (P, D) = m.diagonalize(True) + +def test_issue_8240(): + # Eigenvalues of large triangular matrices + x, y = symbols('x y') + n = 200 + + diagonal_variables = [Symbol('x%s' % i) for i in range(n)] + M = [[0 for i in range(n)] for j in range(n)] + for i in range(n): + M[i][i] = diagonal_variables[i] + M = Matrix(M) + + eigenvals = M.eigenvals() + assert len(eigenvals) == n + for i in range(n): + assert eigenvals[diagonal_variables[i]] == 1 + + eigenvals = M.eigenvals(multiple=True) + assert set(eigenvals) == set(diagonal_variables) + + # with multiplicity + M = Matrix([[x, 0, 0], [1, y, 0], [2, 3, x]]) + eigenvals = M.eigenvals() + assert eigenvals == {x: 2, y: 1} + + eigenvals = M.eigenvals(multiple=True) + assert len(eigenvals) == 3 + assert eigenvals.count(x) == 2 + assert eigenvals.count(y) == 1 + + +def test_eigenvals(): + M = Matrix([[0, 1, 1], + [1, 0, 0], + [1, 1, 1]]) + assert M.eigenvals() == {2*S.One: 1, -S.One: 1, S.Zero: 1} + + m = Matrix([ + [3, 0, 0, 0, -3], + [0, -3, -3, 0, 3], + [0, 3, 0, 3, 0], + [0, 0, 3, 0, 3], + [3, 0, 0, 3, 0]]) + + # XXX Used dry-run test because arbitrary symbol that appears in + # CRootOf may not be unique. + assert m.eigenvals() + + +def test_eigenvects(): + M = Matrix([[0, 1, 1], + [1, 0, 0], + [1, 1, 1]]) + vecs = M.eigenvects() + for val, mult, vec_list in vecs: + assert len(vec_list) == 1 + assert M*vec_list[0] == val*vec_list[0] + + +def test_left_eigenvects(): + M = Matrix([[0, 1, 1], + [1, 0, 0], + [1, 1, 1]]) + vecs = M.left_eigenvects() + for val, mult, vec_list in vecs: + assert len(vec_list) == 1 + assert vec_list[0]*M == val*vec_list[0] + + +@slow +def test_bidiagonalize(): + M = Matrix([[1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + assert M.bidiagonalize() == M + assert M.bidiagonalize(upper=False) == M + assert M.bidiagonalize() == M + assert M.bidiagonal_decomposition() == (M, M, M) + assert M.bidiagonal_decomposition(upper=False) == (M, M, M) + assert M.bidiagonalize() == M + + import random + #Real Tests + for real_test in range(2): + test_values = [] + row = 2 + col = 2 + for _ in range(row * col): + value = random.randint(-1000000000, 1000000000) + test_values = test_values + [value] + # L -> Lower Bidiagonalization + # M -> Mutable Matrix + # N -> Immutable Matrix + # 0 -> Bidiagonalized form + # 1,2,3 -> Bidiagonal_decomposition matrices + # 4 -> Product of 1 2 3 + M = Matrix(row, col, test_values) + N = ImmutableMatrix(M) + + N1, N2, N3 = N.bidiagonal_decomposition() + M1, M2, M3 = M.bidiagonal_decomposition() + M0 = M.bidiagonalize() + N0 = N.bidiagonalize() + + N4 = N1 * N2 * N3 + M4 = M1 * M2 * M3 + + N2.simplify() + N4.simplify() + N0.simplify() + + M0.simplify() + M2.simplify() + M4.simplify() + + LM0 = M.bidiagonalize(upper=False) + LM1, LM2, LM3 = M.bidiagonal_decomposition(upper=False) + LN0 = N.bidiagonalize(upper=False) + LN1, LN2, LN3 = N.bidiagonal_decomposition(upper=False) + + LN4 = LN1 * LN2 * LN3 + LM4 = LM1 * LM2 * LM3 + + LN2.simplify() + LN4.simplify() + LN0.simplify() + + LM0.simplify() + LM2.simplify() + LM4.simplify() + + assert M == M4 + assert M2 == M0 + assert N == N4 + assert N2 == N0 + assert M == LM4 + assert LM2 == LM0 + assert N == LN4 + assert LN2 == LN0 + + #Complex Tests + for complex_test in range(2): + test_values = [] + size = 2 + for _ in range(size * size): + real = random.randint(-1000000000, 1000000000) + comp = random.randint(-1000000000, 1000000000) + value = real + comp * I + test_values = test_values + [value] + M = Matrix(size, size, test_values) + N = ImmutableMatrix(M) + # L -> Lower Bidiagonalization + # M -> Mutable Matrix + # N -> Immutable Matrix + # 0 -> Bidiagonalized form + # 1,2,3 -> Bidiagonal_decomposition matrices + # 4 -> Product of 1 2 3 + N1, N2, N3 = N.bidiagonal_decomposition() + M1, M2, M3 = M.bidiagonal_decomposition() + M0 = M.bidiagonalize() + N0 = N.bidiagonalize() + + N4 = N1 * N2 * N3 + M4 = M1 * M2 * M3 + + N2.simplify() + N4.simplify() + N0.simplify() + + M0.simplify() + M2.simplify() + M4.simplify() + + LM0 = M.bidiagonalize(upper=False) + LM1, LM2, LM3 = M.bidiagonal_decomposition(upper=False) + LN0 = N.bidiagonalize(upper=False) + LN1, LN2, LN3 = N.bidiagonal_decomposition(upper=False) + + LN4 = LN1 * LN2 * LN3 + LM4 = LM1 * LM2 * LM3 + + LN2.simplify() + LN4.simplify() + LN0.simplify() + + LM0.simplify() + LM2.simplify() + LM4.simplify() + + assert M == M4 + assert M2 == M0 + assert N == N4 + assert N2 == N0 + assert M == LM4 + assert LM2 == LM0 + assert N == LN4 + assert LN2 == LN0 + + M = Matrix(18, 8, range(1, 145)) + M = M.applyfunc(lambda i: Float(i)) + assert M.bidiagonal_decomposition()[1] == M.bidiagonalize() + assert M.bidiagonal_decomposition(upper=False)[1] == M.bidiagonalize(upper=False) + a, b, c = M.bidiagonal_decomposition() + diff = a * b * c - M + assert abs(max(diff)) < 10**-12 + + +def test_diagonalize(): + m = Matrix(2, 2, [0, -1, 1, 0]) + raises(MatrixError, lambda: m.diagonalize(reals_only=True)) + P, D = m.diagonalize() + assert D.is_diagonal() + assert D == Matrix([ + [-I, 0], + [ 0, I]]) + + # make sure we use floats out if floats are passed in + m = Matrix(2, 2, [0, .5, .5, 0]) + P, D = m.diagonalize() + assert all(isinstance(e, Float) for e in D.values()) + assert all(isinstance(e, Float) for e in P.values()) + + _, D2 = m.diagonalize(reals_only=True) + assert D == D2 + + m = Matrix( + [[0, 1, 0, 0], [1, 0, 0, 0.002], [0.002, 0, 0, 1], [0, 0, 1, 0]]) + P, D = m.diagonalize() + assert allclose(P*D, m*P) + + +def test_is_diagonalizable(): + a, b, c = symbols('a b c') + m = Matrix(2, 2, [a, c, c, b]) + assert m.is_symmetric() + assert m.is_diagonalizable() + assert not Matrix(2, 2, [1, 1, 0, 1]).is_diagonalizable() + + m = Matrix(2, 2, [0, -1, 1, 0]) + assert m.is_diagonalizable() + assert not m.is_diagonalizable(reals_only=True) + + +def test_jordan_form(): + m = Matrix(3, 2, [-3, 1, -3, 20, 3, 10]) + raises(NonSquareMatrixError, lambda: m.jordan_form()) + + # the next two tests test the cases where the old + # algorithm failed due to the fact that the block structure can + # *NOT* be determined from algebraic and geometric multiplicity alone + # This can be seen most easily when one lets compute the J.c.f. of a matrix that + # is in J.c.f already. + m = Matrix(4, 4, [2, 1, 0, 0, + 0, 2, 1, 0, + 0, 0, 2, 0, + 0, 0, 0, 2 + ]) + P, J = m.jordan_form() + assert m == J + + m = Matrix(4, 4, [2, 1, 0, 0, + 0, 2, 0, 0, + 0, 0, 2, 1, + 0, 0, 0, 2 + ]) + P, J = m.jordan_form() + assert m == J + + A = Matrix([[ 2, 4, 1, 0], + [-4, 2, 0, 1], + [ 0, 0, 2, 4], + [ 0, 0, -4, 2]]) + P, J = A.jordan_form() + assert simplify(P*J*P.inv()) == A + + assert Matrix(1, 1, [1]).jordan_form() == (Matrix([1]), Matrix([1])) + assert Matrix(1, 1, [1]).jordan_form(calc_transform=False) == Matrix([1]) + + # If we have eigenvalues in CRootOf form, raise errors + m = Matrix([[3, 0, 0, 0, -3], [0, -3, -3, 0, 3], [0, 3, 0, 3, 0], [0, 0, 3, 0, 3], [3, 0, 0, 3, 0]]) + raises(MatrixError, lambda: m.jordan_form()) + + # make sure that if the input has floats, the output does too + m = Matrix([ + [ 0.6875, 0.125 + 0.1875*sqrt(3)], + [0.125 + 0.1875*sqrt(3), 0.3125]]) + P, J = m.jordan_form() + assert all(isinstance(x, Float) or x == 0 for x in P) + assert all(isinstance(x, Float) or x == 0 for x in J) + + +def test_singular_values(): + x = Symbol('x', real=True) + + A = Matrix([[0, 1*I], [2, 0]]) + # if singular values can be sorted, they should be in decreasing order + assert A.singular_values() == [2, 1] + + A = eye(3) + A[1, 1] = x + A[2, 2] = 5 + vals = A.singular_values() + # since Abs(x) cannot be sorted, test set equality + assert set(vals) == {5, 1, Abs(x)} + + A = Matrix([[sin(x), cos(x)], [-cos(x), sin(x)]]) + vals = [sv.trigsimp() for sv in A.singular_values()] + assert vals == [S.One, S.One] + + A = Matrix([ + [2, 4], + [1, 3], + [0, 0], + [0, 0] + ]) + assert A.singular_values() == \ + [sqrt(sqrt(221) + 15), sqrt(15 - sqrt(221))] + assert A.T.singular_values() == \ + [sqrt(sqrt(221) + 15), sqrt(15 - sqrt(221)), 0, 0] + +def test___eq__(): + assert (Matrix( + [[0, 1, 1], + [1, 0, 0], + [1, 1, 1]]) == {}) is False + + +def test_definite(): + # Examples from Gilbert Strang, "Introduction to Linear Algebra" + # Positive definite matrices + m = Matrix([[2, -1, 0], [-1, 2, -1], [0, -1, 2]]) + assert m.is_positive_definite == True + assert m.is_positive_semidefinite == True + assert m.is_negative_definite == False + assert m.is_negative_semidefinite == False + assert m.is_indefinite == False + + m = Matrix([[5, 4], [4, 5]]) + assert m.is_positive_definite == True + assert m.is_positive_semidefinite == True + assert m.is_negative_definite == False + assert m.is_negative_semidefinite == False + assert m.is_indefinite == False + + # Positive semidefinite matrices + m = Matrix([[2, -1, -1], [-1, 2, -1], [-1, -1, 2]]) + assert m.is_positive_definite == False + assert m.is_positive_semidefinite == True + assert m.is_negative_definite == False + assert m.is_negative_semidefinite == False + assert m.is_indefinite == False + + m = Matrix([[1, 2], [2, 4]]) + assert m.is_positive_definite == False + assert m.is_positive_semidefinite == True + assert m.is_negative_definite == False + assert m.is_negative_semidefinite == False + assert m.is_indefinite == False + + # Examples from Mathematica documentation + # Non-hermitian positive definite matrices + m = Matrix([[2, 3], [4, 8]]) + assert m.is_positive_definite == True + assert m.is_positive_semidefinite == True + assert m.is_negative_definite == False + assert m.is_negative_semidefinite == False + assert m.is_indefinite == False + + # Hermetian matrices + m = Matrix([[1, 2*I], [-I, 4]]) + assert m.is_positive_definite == True + assert m.is_positive_semidefinite == True + assert m.is_negative_definite == False + assert m.is_negative_semidefinite == False + assert m.is_indefinite == False + + # Symbolic matrices examples + a = Symbol('a', positive=True) + b = Symbol('b', negative=True) + m = Matrix([[a, 0, 0], [0, a, 0], [0, 0, a]]) + assert m.is_positive_definite == True + assert m.is_positive_semidefinite == True + assert m.is_negative_definite == False + assert m.is_negative_semidefinite == False + assert m.is_indefinite == False + + m = Matrix([[b, 0, 0], [0, b, 0], [0, 0, b]]) + assert m.is_positive_definite == False + assert m.is_positive_semidefinite == False + assert m.is_negative_definite == True + assert m.is_negative_semidefinite == True + assert m.is_indefinite == False + + m = Matrix([[a, 0], [0, b]]) + assert m.is_positive_definite == False + assert m.is_positive_semidefinite == False + assert m.is_negative_definite == False + assert m.is_negative_semidefinite == False + assert m.is_indefinite == True + + m = Matrix([ + [0.0228202735623867, 0.00518748979085398, + -0.0743036351048907, -0.00709135324903921], + [0.00518748979085398, 0.0349045359786350, + 0.0830317991056637, 0.00233147902806909], + [-0.0743036351048907, 0.0830317991056637, + 1.15859676366277, 0.340359081555988], + [-0.00709135324903921, 0.00233147902806909, + 0.340359081555988, 0.928147644848199] + ]) + assert m.is_positive_definite == True + assert m.is_positive_semidefinite == True + assert m.is_indefinite == False + + # test for issue 19547: https://github.com/sympy/sympy/issues/19547 + m = Matrix([ + [0, 0, 0], + [0, 1, 2], + [0, 2, 1] + ]) + assert not m.is_positive_definite + assert not m.is_positive_semidefinite + + +def test_positive_semidefinite_cholesky(): + from sympy.matrices.eigen import _is_positive_semidefinite_cholesky + + m = Matrix([[0, 0, 0], [0, 0, 0], [0, 0, 0]]) + assert _is_positive_semidefinite_cholesky(m) == True + m = Matrix([[0, 0, 0], [0, 5, -10*I], [0, 10*I, 5]]) + assert _is_positive_semidefinite_cholesky(m) == False + m = Matrix([[1, 0, 0], [0, 0, 0], [0, 0, -1]]) + assert _is_positive_semidefinite_cholesky(m) == False + m = Matrix([[0, 1], [1, 0]]) + assert _is_positive_semidefinite_cholesky(m) == False + + # https://www.value-at-risk.net/cholesky-factorization/ + m = Matrix([[4, -2, -6], [-2, 10, 9], [-6, 9, 14]]) + assert _is_positive_semidefinite_cholesky(m) == True + m = Matrix([[9, -3, 3], [-3, 2, 1], [3, 1, 6]]) + assert _is_positive_semidefinite_cholesky(m) == True + m = Matrix([[4, -2, 2], [-2, 1, -1], [2, -1, 5]]) + assert _is_positive_semidefinite_cholesky(m) == True + m = Matrix([[1, 2, -1], [2, 5, 1], [-1, 1, 9]]) + assert _is_positive_semidefinite_cholesky(m) == False + + +def test_issue_20582(): + A = Matrix([ + [5, -5, -3, 2, -7], + [-2, -5, 0, 2, 1], + [-2, -7, -5, -2, -6], + [7, 10, 3, 9, -2], + [4, -10, 3, -8, -4] + ]) + # XXX Used dry-run test because arbitrary symbol that appears in + # CRootOf may not be unique. + assert A.eigenvects() + +def test_issue_19210(): + t = Symbol('t') + H = Matrix([[3, 0, 0, 0], [0, 1 , 2, 0], [0, 2, 2, 0], [0, 0, 0, 4]]) + A = (-I * H * t).jordan_form() + assert A == (Matrix([ + [0, 1, 0, 0], + [0, 0, -4/(-1 + sqrt(17)), 4/(1 + sqrt(17))], + [0, 0, 1, 1], + [1, 0, 0, 0]]), Matrix([ + [-4*I*t, 0, 0, 0], + [ 0, -3*I*t, 0, 0], + [ 0, 0, t*(-3*I/2 + sqrt(17)*I/2), 0], + [ 0, 0, 0, t*(-sqrt(17)*I/2 - 3*I/2)]])) + + +def test_issue_20275(): + # XXX We use complex expansions because complex exponentials are not + # recognized by polys.domains + A = DFT(3).as_explicit().expand(complex=True) + eigenvects = A.eigenvects() + assert eigenvects[0] == ( + -1, 1, + [Matrix([[1 - sqrt(3)], [1], [1]])] + ) + assert eigenvects[1] == ( + 1, 1, + [Matrix([[1 + sqrt(3)], [1], [1]])] + ) + assert eigenvects[2] == ( + -I, 1, + [Matrix([[0], [-1], [1]])] + ) + + A = DFT(4).as_explicit().expand(complex=True) + eigenvects = A.eigenvects() + assert eigenvects[0] == ( + -1, 1, + [Matrix([[-1], [1], [1], [1]])] + ) + assert eigenvects[1] == ( + 1, 2, + [Matrix([[1], [0], [1], [0]]), Matrix([[2], [1], [0], [1]])] + ) + assert eigenvects[2] == ( + -I, 1, + [Matrix([[0], [-1], [0], [1]])] + ) + + # XXX We skip test for some parts of eigenvectors which are very + # complicated and fragile under expression tree changes + A = DFT(5).as_explicit().expand(complex=True) + eigenvects = A.eigenvects() + assert eigenvects[0] == ( + -1, 1, + [Matrix([[1 - sqrt(5)], [1], [1], [1], [1]])] + ) + assert eigenvects[1] == ( + 1, 2, + [Matrix([[S(1)/2 + sqrt(5)/2], [0], [1], [1], [0]]), + Matrix([[S(1)/2 + sqrt(5)/2], [1], [0], [0], [1]])] + ) + + +def test_issue_20752(): + b = symbols('b', nonzero=True) + m = Matrix([[0, 0, 0], [0, b, 0], [0, 0, b]]) + assert m.is_positive_semidefinite is None + + +def test_issue_25282(): + dd = sd = [0] * 11 + [1] + ds = [2, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0] + ss = ds.copy() + ss[8] = 2 + + def rotate(x, i): + return x[i:] + x[:i] + + mat = [] + for i in range(12): + mat.append(rotate(ss, i) + rotate(sd, i)) + for i in range(12): + mat.append(rotate(ds, i) + rotate(dd, i)) + + assert sum(Matrix(mat).eigenvals().values()) == 24 diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_graph.py b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..0bf3c819a9477387f53560a034d7949fd76a654f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_graph.py @@ -0,0 +1,108 @@ +from sympy.combinatorics import Permutation +from sympy.core.symbol import symbols +from sympy.matrices import Matrix +from sympy.matrices.expressions import ( + PermutationMatrix, BlockDiagMatrix, BlockMatrix) + + +def test_connected_components(): + a, b, c, d, e, f, g, h, i, j, k, l, m = symbols('a:m') + + M = Matrix([ + [a, 0, 0, 0, b, 0, 0, 0, 0, 0, c, 0, 0], + [0, d, 0, 0, 0, e, 0, 0, 0, 0, 0, f, 0], + [0, 0, g, 0, 0, 0, h, 0, 0, 0, 0, 0, i], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [m, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, m, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, m, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [j, 0, 0, 0, k, 0, 0, 1, 0, 0, l, 0, 0], + [0, j, 0, 0, 0, k, 0, 0, 1, 0, 0, l, 0], + [0, 0, j, 0, 0, 0, k, 0, 0, 1, 0, 0, l], + [0, 0, 0, 0, d, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, d, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, d, 0, 0, 0, 0, 0, 1]]) + cc = M.connected_components() + assert cc == [[0, 4, 7, 10], [1, 5, 8, 11], [2, 6, 9, 12], [3]] + + P, B = M.connected_components_decomposition() + p = Permutation([0, 4, 7, 10, 1, 5, 8, 11, 2, 6, 9, 12, 3]) + assert P == PermutationMatrix(p) + + B0 = Matrix([ + [a, b, 0, c], + [m, 1, 0, 0], + [j, k, 1, l], + [0, d, 0, 1]]) + B1 = Matrix([ + [d, e, 0, f], + [m, 1, 0, 0], + [j, k, 1, l], + [0, d, 0, 1]]) + B2 = Matrix([ + [g, h, 0, i], + [m, 1, 0, 0], + [j, k, 1, l], + [0, d, 0, 1]]) + B3 = Matrix([[1]]) + assert B == BlockDiagMatrix(B0, B1, B2, B3) + + +def test_strongly_connected_components(): + M = Matrix([ + [11, 14, 10, 0, 15, 0], + [0, 44, 0, 0, 45, 0], + [1, 4, 0, 0, 5, 0], + [0, 0, 0, 22, 0, 23], + [0, 54, 0, 0, 55, 0], + [0, 0, 0, 32, 0, 33]]) + scc = M.strongly_connected_components() + assert scc == [[1, 4], [0, 2], [3, 5]] + + P, B = M.strongly_connected_components_decomposition() + p = Permutation([1, 4, 0, 2, 3, 5]) + assert P == PermutationMatrix(p) + assert B == BlockMatrix([ + [ + Matrix([[44, 45], [54, 55]]), + Matrix.zeros(2, 2), + Matrix.zeros(2, 2) + ], + [ + Matrix([[14, 15], [4, 5]]), + Matrix([[11, 10], [1, 0]]), + Matrix.zeros(2, 2) + ], + [ + Matrix.zeros(2, 2), + Matrix.zeros(2, 2), + Matrix([[22, 23], [32, 33]]) + ] + ]) + P = P.as_explicit() + B = B.as_explicit() + assert P.T * B * P == M + + P, B = M.strongly_connected_components_decomposition(lower=False) + p = Permutation([3, 5, 0, 2, 1, 4]) + assert P == PermutationMatrix(p) + assert B == BlockMatrix([ + [ + Matrix([[22, 23], [32, 33]]), + Matrix.zeros(2, 2), + Matrix.zeros(2, 2) + ], + [ + Matrix.zeros(2, 2), + Matrix([[11, 10], [1, 0]]), + Matrix([[14, 15], [4, 5]]) + ], + [ + Matrix.zeros(2, 2), + Matrix.zeros(2, 2), + Matrix([[44, 45], [54, 55]]) + ] + ]) + P = P.as_explicit() + B = B.as_explicit() + assert P.T * B * P == M diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_immutable.py b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_immutable.py new file mode 100644 index 0000000000000000000000000000000000000000..2b83c1f9fae7f83be9d5f7dd4b484781dc128faf --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_immutable.py @@ -0,0 +1,136 @@ +from itertools import product + +from sympy.core.relational import (Equality, Unequality) +from sympy.core.singleton import S +from sympy.core.sympify import sympify +from sympy.integrals.integrals import integrate +from sympy.matrices.dense import (Matrix, eye, zeros) +from sympy.matrices.immutable import ImmutableMatrix +from sympy.matrices import SparseMatrix +from sympy.matrices.immutable import \ + ImmutableDenseMatrix, ImmutableSparseMatrix +from sympy.abc import x, y +from sympy.testing.pytest import raises + +IM = ImmutableDenseMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) +ISM = ImmutableSparseMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) +ieye = ImmutableDenseMatrix(eye(3)) + + +def test_creation(): + assert IM.shape == ISM.shape == (3, 3) + assert IM[1, 2] == ISM[1, 2] == 6 + assert IM[2, 2] == ISM[2, 2] == 9 + + +def test_immutability(): + with raises(TypeError): + IM[2, 2] = 5 + with raises(TypeError): + ISM[2, 2] = 5 + + +def test_slicing(): + assert IM[1, :] == ImmutableDenseMatrix([[4, 5, 6]]) + assert IM[:2, :2] == ImmutableDenseMatrix([[1, 2], [4, 5]]) + assert ISM[1, :] == ImmutableSparseMatrix([[4, 5, 6]]) + assert ISM[:2, :2] == ImmutableSparseMatrix([[1, 2], [4, 5]]) + + +def test_subs(): + A = ImmutableMatrix([[1, 2], [3, 4]]) + B = ImmutableMatrix([[1, 2], [x, 4]]) + C = ImmutableMatrix([[-x, x*y], [-(x + y), y**2]]) + assert B.subs(x, 3) == A + assert (x*B).subs(x, 3) == 3*A + assert (x*eye(2) + B).subs(x, 3) == 3*eye(2) + A + assert C.subs([[x, -1], [y, -2]]) == A + assert C.subs([(x, -1), (y, -2)]) == A + assert C.subs({x: -1, y: -2}) == A + assert C.subs({x: y - 1, y: x - 1}, simultaneous=True) == \ + ImmutableMatrix([[1 - y, (x - 1)*(y - 1)], [2 - x - y, (x - 1)**2]]) + + +def test_as_immutable(): + data = [[1, 2], [3, 4]] + X = Matrix(data) + assert sympify(X) == X.as_immutable() == ImmutableMatrix(data) + + data = {(0, 0): 1, (0, 1): 2, (1, 0): 3, (1, 1): 4} + X = SparseMatrix(2, 2, data) + assert sympify(X) == X.as_immutable() == ImmutableSparseMatrix(2, 2, data) + + +def test_function_return_types(): + # Lets ensure that decompositions of immutable matrices remain immutable + # I.e. do MatrixBase methods return the correct class? + X = ImmutableMatrix([[1, 2], [3, 4]]) + Y = ImmutableMatrix([[1], [0]]) + q, r = X.QRdecomposition() + assert (type(q), type(r)) == (ImmutableMatrix, ImmutableMatrix) + + assert type(X.LUsolve(Y)) == ImmutableMatrix + assert type(X.QRsolve(Y)) == ImmutableMatrix + + X = ImmutableMatrix([[5, 2], [2, 7]]) + assert X.T == X + assert X.is_symmetric + assert type(X.cholesky()) == ImmutableMatrix + L, D = X.LDLdecomposition() + assert (type(L), type(D)) == (ImmutableMatrix, ImmutableMatrix) + + X = ImmutableMatrix([[1, 2], [2, 1]]) + assert X.is_diagonalizable() + assert X.det() == -3 + assert X.norm(2) == 3 + + assert type(X.eigenvects()[0][2][0]) == ImmutableMatrix + + assert type(zeros(3, 3).as_immutable().nullspace()[0]) == ImmutableMatrix + + X = ImmutableMatrix([[1, 0], [2, 1]]) + assert type(X.lower_triangular_solve(Y)) == ImmutableMatrix + assert type(X.T.upper_triangular_solve(Y)) == ImmutableMatrix + + assert type(X.minor_submatrix(0, 0)) == ImmutableMatrix + +# issue 6279 +# https://github.com/sympy/sympy/issues/6279 +# Test that Immutable _op_ Immutable => Immutable and not MatExpr + + +def test_immutable_evaluation(): + X = ImmutableMatrix(eye(3)) + A = ImmutableMatrix(3, 3, range(9)) + assert isinstance(X + A, ImmutableMatrix) + assert isinstance(X * A, ImmutableMatrix) + assert isinstance(X * 2, ImmutableMatrix) + assert isinstance(2 * X, ImmutableMatrix) + assert isinstance(A**2, ImmutableMatrix) + + +def test_deterimant(): + assert ImmutableMatrix(4, 4, lambda i, j: i + j).det() == 0 + + +def test_Equality(): + assert Equality(IM, IM) is S.true + assert Unequality(IM, IM) is S.false + assert Equality(IM, IM.subs(1, 2)) is S.false + assert Unequality(IM, IM.subs(1, 2)) is S.true + assert Equality(IM, 2) is S.false + assert Unequality(IM, 2) is S.true + M = ImmutableMatrix([x, y]) + assert Equality(M, IM) is S.false + assert Unequality(M, IM) is S.true + assert Equality(M, M.subs(x, 2)).subs(x, 2) is S.true + assert Unequality(M, M.subs(x, 2)).subs(x, 2) is S.false + assert Equality(M, M.subs(x, 2)).subs(x, 3) is S.false + assert Unequality(M, M.subs(x, 2)).subs(x, 3) is S.true + + +def test_integrate(): + intIM = integrate(IM, x) + assert intIM.shape == IM.shape + assert all(intIM[i, j] == (1 + j + 3*i)*x for i, j in + product(range(3), range(3))) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_interactions.py b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_interactions.py new file mode 100644 index 0000000000000000000000000000000000000000..f4fc3268368e8dd632fc0df187d57ea5e845120c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_interactions.py @@ -0,0 +1,77 @@ +""" +We have a few different kind of Matrices +Matrix, ImmutableMatrix, MatrixExpr + +Here we test the extent to which they cooperate +""" + +from sympy.core.symbol import symbols +from sympy.matrices import (Matrix, MatrixSymbol, eye, Identity, + ImmutableMatrix) +from sympy.matrices.expressions import MatrixExpr, MatAdd +from sympy.matrices.matrixbase import classof +from sympy.testing.pytest import raises + +SM = MatrixSymbol('X', 3, 3) +SV = MatrixSymbol('v', 3, 1) +MM = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) +IM = ImmutableMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) +meye = eye(3) +imeye = ImmutableMatrix(eye(3)) +ideye = Identity(3) +a, b, c = symbols('a,b,c') + + +def test_IM_MM(): + assert isinstance(MM + IM, ImmutableMatrix) + assert isinstance(IM + MM, ImmutableMatrix) + assert isinstance(2*IM + MM, ImmutableMatrix) + assert MM.equals(IM) + + +def test_ME_MM(): + assert isinstance(Identity(3) + MM, MatrixExpr) + assert isinstance(SM + MM, MatAdd) + assert isinstance(MM + SM, MatAdd) + assert (Identity(3) + MM)[1, 1] == 6 + + +def test_equality(): + a, b, c = Identity(3), eye(3), ImmutableMatrix(eye(3)) + for x in [a, b, c]: + for y in [a, b, c]: + assert x.equals(y) + + +def test_matrix_symbol_MM(): + X = MatrixSymbol('X', 3, 3) + Y = eye(3) + X + assert Y[1, 1] == 1 + X[1, 1] + + +def test_matrix_symbol_vector_matrix_multiplication(): + A = MM * SV + B = IM * SV + assert A == B + C = (SV.T * MM.T).T + assert B == C + D = (SV.T * IM.T).T + assert C == D + + +def test_indexing_interactions(): + assert (a * IM)[1, 1] == 5*a + assert (SM + IM)[1, 1] == SM[1, 1] + IM[1, 1] + assert (SM * IM)[1, 1] == SM[1, 0]*IM[0, 1] + SM[1, 1]*IM[1, 1] + \ + SM[1, 2]*IM[2, 1] + + +def test_classof(): + A = Matrix(3, 3, range(9)) + B = ImmutableMatrix(3, 3, range(9)) + C = MatrixSymbol('C', 3, 3) + assert classof(A, A) == Matrix + assert classof(B, B) == ImmutableMatrix + assert classof(A, B) == ImmutableMatrix + assert classof(B, A) == ImmutableMatrix + raises(TypeError, lambda: classof(A, C)) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_matrices.py b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_matrices.py new file mode 100644 index 0000000000000000000000000000000000000000..d9d97341de570e078d652dddce58fb8f5cb99e43 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_matrices.py @@ -0,0 +1,3487 @@ +# +# Code for testing deprecated matrix classes. New test code should not be added +# here. Instead, add it to test_matrixbase.py. +# +# This entire test module and the corresponding sympy/matrices/matrices.py +# module will be removed in a future release. +# +import random +import concurrent.futures +from collections.abc import Hashable + +from sympy.core.add import Add +from sympy.core.function import Function, diff, expand +from sympy.core.numbers import (E, Float, I, Integer, Rational, nan, oo, pi) +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import (Max, Min, sqrt) +from sympy.functions.elementary.trigonometric import (cos, sin, tan) +from sympy.integrals.integrals import integrate +from sympy.matrices.expressions.transpose import transpose +from sympy.physics.quantum.operator import HermitianOperator, Operator, Dagger +from sympy.polys.polytools import (Poly, PurePoly) +from sympy.polys.rootoftools import RootOf +from sympy.printing.str import sstr +from sympy.sets.sets import FiniteSet +from sympy.simplify.simplify import (signsimp, simplify) +from sympy.simplify.trigsimp import trigsimp +from sympy.matrices.exceptions import (ShapeError, MatrixError, + NonSquareMatrixError) +from sympy.matrices.matrixbase import DeferredVector +from sympy.matrices.determinant import _find_reasonable_pivot_naive +from sympy.matrices.utilities import _simplify +from sympy.matrices import ( + GramSchmidt, ImmutableMatrix, ImmutableSparseMatrix, Matrix, + SparseMatrix, casoratian, diag, eye, hessian, + matrix_multiply_elementwise, ones, randMatrix, rot_axis1, rot_axis2, + rot_axis3, wronskian, zeros, MutableDenseMatrix, ImmutableDenseMatrix, + MatrixSymbol, dotprodsimp, rot_ccw_axis1, rot_ccw_axis2, rot_ccw_axis3) +from sympy.matrices.utilities import _dotprodsimp_state +from sympy.core import Tuple, Wild +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.utilities.iterables import flatten, capture, iterable +from sympy.utilities.exceptions import ignore_warnings +from sympy.testing.pytest import (raises, XFAIL, slow, skip, skip_under_pyodide, + warns_deprecated_sympy) +from sympy.assumptions import Q +from sympy.tensor.array import Array +from sympy.tensor.array.array_derivatives import ArrayDerivative +from sympy.matrices.expressions import MatPow +from sympy.algebras import Quaternion + +from sympy import O + +from sympy.abc import a, b, c, d, x, y, z, t + + +# don't re-order this list +classes = (Matrix, SparseMatrix, ImmutableMatrix, ImmutableSparseMatrix) + + +# Test the deprecated matrixmixins +from sympy.matrices.common import _MinimalMatrix, _CastableMatrix +from sympy.matrices.matrices import MatrixSubspaces, MatrixReductions + + +with warns_deprecated_sympy(): + class SubspaceOnlyMatrix(_MinimalMatrix, _CastableMatrix, MatrixSubspaces): + pass + + +with warns_deprecated_sympy(): + class ReductionsOnlyMatrix(_MinimalMatrix, _CastableMatrix, MatrixReductions): + pass + + +def eye_Reductions(n): + return ReductionsOnlyMatrix(n, n, lambda i, j: int(i == j)) + + +def zeros_Reductions(n): + return ReductionsOnlyMatrix(n, n, lambda i, j: 0) + + +def test_args(): + for n, cls in enumerate(classes): + m = cls.zeros(3, 2) + # all should give back the same type of arguments, e.g. ints for shape + assert m.shape == (3, 2) and all(type(i) is int for i in m.shape) + assert m.rows == 3 and type(m.rows) is int + assert m.cols == 2 and type(m.cols) is int + if not n % 2: + assert type(m.flat()) in (list, tuple, Tuple) + else: + assert type(m.todok()) is dict + + +def test_deprecated_mat_smat(): + for cls in Matrix, ImmutableMatrix: + m = cls.zeros(3, 2) + with warns_deprecated_sympy(): + mat = m._mat + assert mat == m.flat() + for cls in SparseMatrix, ImmutableSparseMatrix: + m = cls.zeros(3, 2) + with warns_deprecated_sympy(): + smat = m._smat + assert smat == m.todok() + + +def test_division(): + v = Matrix(1, 2, [x, y]) + assert v/z == Matrix(1, 2, [x/z, y/z]) + + +def test_sum(): + m = Matrix([[1, 2, 3], [x, y, x], [2*y, -50, z*x]]) + assert m + m == Matrix([[2, 4, 6], [2*x, 2*y, 2*x], [4*y, -100, 2*z*x]]) + n = Matrix(1, 2, [1, 2]) + raises(ShapeError, lambda: m + n) + +def test_abs(): + m = Matrix(1, 2, [-3, x]) + n = Matrix(1, 2, [3, Abs(x)]) + assert abs(m) == n + +def test_addition(): + a = Matrix(( + (1, 2), + (3, 1), + )) + + b = Matrix(( + (1, 2), + (3, 0), + )) + + assert a + b == a.add(b) == Matrix([[2, 4], [6, 1]]) + + +def test_fancy_index_matrix(): + for M in (Matrix, SparseMatrix): + a = M(3, 3, range(9)) + assert a == a[:, :] + assert a[1, :] == Matrix(1, 3, [3, 4, 5]) + assert a[:, 1] == Matrix([1, 4, 7]) + assert a[[0, 1], :] == Matrix([[0, 1, 2], [3, 4, 5]]) + assert a[[0, 1], 2] == a[[0, 1], [2]] + assert a[2, [0, 1]] == a[[2], [0, 1]] + assert a[:, [0, 1]] == Matrix([[0, 1], [3, 4], [6, 7]]) + assert a[0, 0] == 0 + assert a[0:2, :] == Matrix([[0, 1, 2], [3, 4, 5]]) + assert a[:, 0:2] == Matrix([[0, 1], [3, 4], [6, 7]]) + assert a[::2, 1] == a[[0, 2], 1] + assert a[1, ::2] == a[1, [0, 2]] + a = M(3, 3, range(9)) + assert a[[0, 2, 1, 2, 1], :] == Matrix([ + [0, 1, 2], + [6, 7, 8], + [3, 4, 5], + [6, 7, 8], + [3, 4, 5]]) + assert a[:, [0,2,1,2,1]] == Matrix([ + [0, 2, 1, 2, 1], + [3, 5, 4, 5, 4], + [6, 8, 7, 8, 7]]) + + a = SparseMatrix.zeros(3) + a[1, 2] = 2 + a[0, 1] = 3 + a[2, 0] = 4 + assert a.extract([1, 1], [2]) == Matrix([ + [2], + [2]]) + assert a.extract([1, 0], [2, 2, 2]) == Matrix([ + [2, 2, 2], + [0, 0, 0]]) + assert a.extract([1, 0, 1, 2], [2, 0, 1, 0]) == Matrix([ + [2, 0, 0, 0], + [0, 0, 3, 0], + [2, 0, 0, 0], + [0, 4, 0, 4]]) + + +def test_multiplication(): + a = Matrix(( + (1, 2), + (3, 1), + (0, 6), + )) + + b = Matrix(( + (1, 2), + (3, 0), + )) + + c = a*b + assert c[0, 0] == 7 + assert c[0, 1] == 2 + assert c[1, 0] == 6 + assert c[1, 1] == 6 + assert c[2, 0] == 18 + assert c[2, 1] == 0 + + try: + eval('c = a @ b') + except SyntaxError: + pass + else: + assert c[0, 0] == 7 + assert c[0, 1] == 2 + assert c[1, 0] == 6 + assert c[1, 1] == 6 + assert c[2, 0] == 18 + assert c[2, 1] == 0 + + h = matrix_multiply_elementwise(a, c) + assert h == a.multiply_elementwise(c) + assert h[0, 0] == 7 + assert h[0, 1] == 4 + assert h[1, 0] == 18 + assert h[1, 1] == 6 + assert h[2, 0] == 0 + assert h[2, 1] == 0 + raises(ShapeError, lambda: matrix_multiply_elementwise(a, b)) + + c = b * Symbol("x") + assert isinstance(c, Matrix) + assert c[0, 0] == x + assert c[0, 1] == 2*x + assert c[1, 0] == 3*x + assert c[1, 1] == 0 + + c2 = x * b + assert c == c2 + + c = 5 * b + assert isinstance(c, Matrix) + assert c[0, 0] == 5 + assert c[0, 1] == 2*5 + assert c[1, 0] == 3*5 + assert c[1, 1] == 0 + + try: + eval('c = 5 @ b') + except SyntaxError: + pass + else: + assert isinstance(c, Matrix) + assert c[0, 0] == 5 + assert c[0, 1] == 2*5 + assert c[1, 0] == 3*5 + assert c[1, 1] == 0 + + +def test_multiplication_inf_zero(): + + M = Matrix([[oo, 0], [0, oo]]) + assert M ** 2 == M + + M = Matrix([[oo, oo], [0, 0]]) + assert M ** 2 == Matrix([[nan, nan], [nan, nan]]) + + A = Matrix([ + [0, 0, 0, -S(1)/2], + [0, 1, 0, 0], + [0, 0, 1, 0], + [-S(1)/2, 0, 0, 0]]) + + B = Matrix([ + [pi*x**2, 0, pi*b*x**4/8 + pi*a*x**4/8 + O(x**5), pi*x**4/2 + pi*b**2*x**6/32 + pi*a*b*x**6/48 + pi*a**2*x**6/32 + O(x**7)], + [0, pi*x**4/4, O(x**6), O(x**8)], + [pi*b*x**4/8 + pi*a*x**4/8 + O(x**5), O(x**6), pi*b**2*x**6/32 + pi*a*b*x**6/48 + pi*a**2*x**6/32 + O(x**7), pi*b*x**6/12 + pi*a*x**6/12 + O(x**7)], + [pi*x**4/2 + pi*b**2*x**6/32 + pi*a*b*x**6/48 + pi*a**2*x**6/32 + O(x**7), O(x**8), pi*b*x**6/12 + pi*a*x**6/12 + O(x**7), pi*x**6/3 + 3*pi*b**2*x**8/64 + pi*a*b*x**8/32 + 3*pi*a**2*x**8/64 + O(x**9)]]) + + C = Matrix([ + [-pi*x**4/4 - pi*b**2*x**6/64 - pi*a*b*x**6/96 - pi*a**2*x**6/64 + O(x**7), O(x**8), -pi*b*x**6/24 - pi*a*x**6/24 + O(x**7), -pi*x**6/6 - 3*pi*b**2*x**8/128 - pi*a*b*x**8/64 - 3*pi*a**2*x**8/128 + O(x**9)], + [ 0, pi*x**4/4, O(x**6), O(x**8)], + [ pi*b*x**4/8 + pi*a*x**4/8 + O(x**5), O(x**6), pi*b**2*x**6/32 + pi*a*b*x**6/48 + pi*a**2*x**6/32 + O(x**7), pi*b*x**6/12 + pi*a*x**6/12 + O(x**7)], + [ -pi*x**2/2, 0, -pi*b*x**4/16 - pi*a*x**4/16 + O(x**5), -pi*x**4/4 - pi*b**2*x**6/64 - pi*a*b*x**6/96 - pi*a**2*x**6/64 + O(x**7)]]) + + C2 = Matrix(4, 4, lambda i, j: Add(*(A[i,k]*B[k,j] for k in range(4)))) + + assert A*B == C == C2 + + +def test_power(): + raises(NonSquareMatrixError, lambda: Matrix((1, 2))**2) + + R = Rational + A = Matrix([[2, 3], [4, 5]]) + assert (A**-3)[:] == [R(-269)/8, R(153)/8, R(51)/2, R(-29)/2] + assert (A**5)[:] == [6140, 8097, 10796, 14237] + A = Matrix([[2, 1, 3], [4, 2, 4], [6, 12, 1]]) + assert (A**3)[:] == [290, 262, 251, 448, 440, 368, 702, 954, 433] + assert A**0 == eye(3) + assert A**1 == A + assert (Matrix([[2]]) ** 100)[0, 0] == 2**100 + assert eye(2)**10000000 == eye(2) + assert Matrix([[1, 2], [3, 4]])**Integer(2) == Matrix([[7, 10], [15, 22]]) + + A = Matrix([[33, 24], [48, 57]]) + assert (A**S.Half)[:] == [5, 2, 4, 7] + A = Matrix([[0, 4], [-1, 5]]) + assert (A**S.Half)**2 == A + + assert Matrix([[1, 0], [1, 1]])**S.Half == Matrix([[1, 0], [S.Half, 1]]) + assert Matrix([[1, 0], [1, 1]])**0.5 == Matrix([[1, 0], [0.5, 1]]) + from sympy.abc import n + assert Matrix([[1, a], [0, 1]])**n == Matrix([[1, a*n], [0, 1]]) + assert Matrix([[b, a], [0, b]])**n == Matrix([[b**n, a*b**(n-1)*n], [0, b**n]]) + assert Matrix([ + [a**n, a**(n - 1)*n, (a**n*n**2 - a**n*n)/(2*a**2)], + [ 0, a**n, a**(n - 1)*n], + [ 0, 0, a**n]]) + assert Matrix([[a, 1, 0], [0, a, 0], [0, 0, b]])**n == Matrix([ + [a**n, a**(n-1)*n, 0], + [0, a**n, 0], + [0, 0, b**n]]) + + A = Matrix([[1, 0], [1, 7]]) + assert A._matrix_pow_by_jordan_blocks(S(3)) == A._eval_pow_by_recursion(3) + A = Matrix([[2]]) + assert A**10 == Matrix([[2**10]]) == A._matrix_pow_by_jordan_blocks(S(10)) == \ + A._eval_pow_by_recursion(10) + + # testing a matrix that cannot be jordan blocked issue 11766 + m = Matrix([[3, 0, 0, 0, -3], [0, -3, -3, 0, 3], [0, 3, 0, 3, 0], [0, 0, 3, 0, 3], [3, 0, 0, 3, 0]]) + raises(MatrixError, lambda: m._matrix_pow_by_jordan_blocks(S(10))) + + # test issue 11964 + raises(MatrixError, lambda: Matrix([[1, 1], [3, 3]])._matrix_pow_by_jordan_blocks(S(-10))) + A = Matrix([[0, 1, 0], [0, 0, 1], [0, 0, 0]]) # Nilpotent jordan block size 3 + assert A**10.0 == Matrix([[0, 0, 0], [0, 0, 0], [0, 0, 0]]) + raises(ValueError, lambda: A**2.1) + raises(ValueError, lambda: A**Rational(3, 2)) + A = Matrix([[8, 1], [3, 2]]) + assert A**10.0 == Matrix([[1760744107, 272388050], [817164150, 126415807]]) + A = Matrix([[0, 0, 1], [0, 0, 1], [0, 0, 1]]) # Nilpotent jordan block size 1 + assert A**10.0 == Matrix([[0, 0, 1], [0, 0, 1], [0, 0, 1]]) + A = Matrix([[0, 1, 0], [0, 0, 1], [0, 0, 1]]) # Nilpotent jordan block size 2 + assert A**10.0 == Matrix([[0, 0, 1], [0, 0, 1], [0, 0, 1]]) + n = Symbol('n', integer=True) + assert isinstance(A**n, MatPow) + n = Symbol('n', integer=True, negative=True) + raises(ValueError, lambda: A**n) + n = Symbol('n', integer=True, nonnegative=True) + assert A**n == Matrix([ + [KroneckerDelta(0, n), KroneckerDelta(1, n), -KroneckerDelta(0, n) - KroneckerDelta(1, n) + 1], + [ 0, KroneckerDelta(0, n), 1 - KroneckerDelta(0, n)], + [ 0, 0, 1]]) + assert A**(n + 2) == Matrix([[0, 0, 1], [0, 0, 1], [0, 0, 1]]) + raises(ValueError, lambda: A**Rational(3, 2)) + A = Matrix([[0, 0, 1], [3, 0, 1], [4, 3, 1]]) + assert A**5.0 == Matrix([[168, 72, 89], [291, 144, 161], [572, 267, 329]]) + assert A**5.0 == A**5 + A = Matrix([[0, 1, 0],[-1, 0, 0],[0, 0, 0]]) + n = Symbol("n") + An = A**n + assert An.subs(n, 2).doit() == A**2 + raises(ValueError, lambda: An.subs(n, -2).doit()) + assert An * An == A**(2*n) + + # concretizing behavior for non-integer and complex powers + A = Matrix([[0,0,0],[0,0,0],[0,0,0]]) + n = Symbol('n', integer=True, positive=True) + assert A**n == A + n = Symbol('n', integer=True, nonnegative=True) + assert A**n == diag(0**n, 0**n, 0**n) + assert (A**n).subs(n, 0) == eye(3) + assert (A**n).subs(n, 1) == zeros(3) + A = Matrix ([[2,0,0],[0,2,0],[0,0,2]]) + assert A**2.1 == diag (2**2.1, 2**2.1, 2**2.1) + assert A**I == diag (2**I, 2**I, 2**I) + A = Matrix([[0, 1, 0], [0, 0, 1], [0, 0, 1]]) + raises(ValueError, lambda: A**2.1) + raises(ValueError, lambda: A**I) + A = Matrix([[S.Half, S.Half], [S.Half, S.Half]]) + assert A**S.Half == A + A = Matrix([[1, 1],[3, 3]]) + assert A**S.Half == Matrix ([[S.Half, S.Half], [3*S.Half, 3*S.Half]]) + + +def test_issue_17247_expression_blowup_1(): + M = Matrix([[1+x, 1-x], [1-x, 1+x]]) + with dotprodsimp(True): + assert M.exp().expand() == Matrix([ + [ (exp(2*x) + exp(2))/2, (-exp(2*x) + exp(2))/2], + [(-exp(2*x) + exp(2))/2, (exp(2*x) + exp(2))/2]]) + +def test_issue_17247_expression_blowup_2(): + M = Matrix([[1+x, 1-x], [1-x, 1+x]]) + with dotprodsimp(True): + P, J = M.jordan_form () + assert P*J*P.inv() + +def test_issue_17247_expression_blowup_3(): + M = Matrix([[1+x, 1-x], [1-x, 1+x]]) + with dotprodsimp(True): + assert M**100 == Matrix([ + [633825300114114700748351602688*x**100 + 633825300114114700748351602688, 633825300114114700748351602688 - 633825300114114700748351602688*x**100], + [633825300114114700748351602688 - 633825300114114700748351602688*x**100, 633825300114114700748351602688*x**100 + 633825300114114700748351602688]]) + +def test_issue_17247_expression_blowup_4(): +# This matrix takes extremely long on current master even with intermediate simplification so an abbreviated version is used. It is left here for test in case of future optimizations. +# M = Matrix(S('''[ +# [ -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64, -9/32 - I/16, 183/256 - 97*I/128, 3/64 + 13*I/64, -23/32 - 59*I/256, 15/128 - 3*I/32, 19/256 + 551*I/1024], +# [-149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512, -219/128 + 115*I/256, 6301/4096 - 6609*I/1024, 119/128 + 143*I/128, -10879/2048 + 4343*I/4096, 129/256 - 549*I/512, 42533/16384 + 29103*I/8192], +# [ 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64, -9/32 - I/16, 183/256 - 97*I/128, 3/64 + 13*I/64, -23/32 - 59*I/256], +# [ -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512, -219/128 + 115*I/256, 6301/4096 - 6609*I/1024, 119/128 + 143*I/128, -10879/2048 + 4343*I/4096], +# [ 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64, -9/32 - I/16, 183/256 - 97*I/128], +# [ 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512, -219/128 + 115*I/256, 6301/4096 - 6609*I/1024], +# [ -2, 17/4 - 13*I/2, 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64], +# [ 1/4 + 13*I/4, -825/64 - 147*I/32, 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512], +# [ -4*I, 27/2 + 6*I, -2, 17/4 - 13*I/2, 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64], +# [ 1/4 + 5*I/2, -23/8 - 57*I/16, 1/4 + 13*I/4, -825/64 - 147*I/32, 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128], +# [ -4, 9 - 5*I, -4*I, 27/2 + 6*I, -2, 17/4 - 13*I/2, 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16], +# [ -2*I, 119/8 + 29*I/4, 1/4 + 5*I/2, -23/8 - 57*I/16, 1/4 + 13*I/4, -825/64 - 147*I/32, 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128]]''')) +# assert M**10 == Matrix([ +# [ 7*(-221393644768594642173548179825793834595 - 1861633166167425978847110897013541127952*I)/9671406556917033397649408, 15*(31670992489131684885307005100073928751695 + 10329090958303458811115024718207404523808*I)/77371252455336267181195264, 7*(-3710978679372178839237291049477017392703 + 1377706064483132637295566581525806894169*I)/19342813113834066795298816, (9727707023582419994616144751727760051598 - 59261571067013123836477348473611225724433*I)/9671406556917033397649408, (31896723509506857062605551443641668183707 + 54643444538699269118869436271152084599580*I)/38685626227668133590597632, (-2024044860947539028275487595741003997397402 + 130959428791783397562960461903698670485863*I)/309485009821345068724781056, 3*(26190251453797590396533756519358368860907 - 27221191754180839338002754608545400941638*I)/77371252455336267181195264, (1154643595139959842768960128434994698330461 + 3385496216250226964322872072260446072295634*I)/618970019642690137449562112, 3*(-31849347263064464698310044805285774295286 - 11877437776464148281991240541742691164309*I)/77371252455336267181195264, (4661330392283532534549306589669150228040221 - 4171259766019818631067810706563064103956871*I)/1237940039285380274899124224, (9598353794289061833850770474812760144506 + 358027153990999990968244906482319780943983*I)/309485009821345068724781056, (-9755135335127734571547571921702373498554177 - 4837981372692695195747379349593041939686540*I)/2475880078570760549798248448], +# [(-379516731607474268954110071392894274962069 - 422272153179747548473724096872271700878296*I)/77371252455336267181195264, (41324748029613152354787280677832014263339501 - 12715121258662668420833935373453570749288074*I)/1237940039285380274899124224, (-339216903907423793947110742819264306542397 + 494174755147303922029979279454787373566517*I)/77371252455336267181195264, (-18121350839962855576667529908850640619878381 - 37413012454129786092962531597292531089199003*I)/1237940039285380274899124224, (2489661087330511608618880408199633556675926 + 1137821536550153872137379935240732287260863*I)/309485009821345068724781056, (-136644109701594123227587016790354220062972119 + 110130123468183660555391413889600443583585272*I)/4951760157141521099596496896, (1488043981274920070468141664150073426459593 - 9691968079933445130866371609614474474327650*I)/1237940039285380274899124224, 27*(4636797403026872518131756991410164760195942 + 3369103221138229204457272860484005850416533*I)/4951760157141521099596496896, (-8534279107365915284081669381642269800472363 + 2241118846262661434336333368511372725482742*I)/1237940039285380274899124224, (60923350128174260992536531692058086830950875 - 263673488093551053385865699805250505661590126*I)/9903520314283042199192993792, (18520943561240714459282253753348921824172569 + 24846649186468656345966986622110971925703604*I)/4951760157141521099596496896, (-232781130692604829085973604213529649638644431 + 35981505277760667933017117949103953338570617*I)/9903520314283042199192993792], +# [ (8742968295129404279528270438201520488950 + 3061473358639249112126847237482570858327*I)/4835703278458516698824704, (-245657313712011778432792959787098074935273 + 253113767861878869678042729088355086740856*I)/38685626227668133590597632, (1947031161734702327107371192008011621193 - 19462330079296259148177542369999791122762*I)/9671406556917033397649408, (552856485625209001527688949522750288619217 + 392928441196156725372494335248099016686580*I)/77371252455336267181195264, (-44542866621905323121630214897126343414629 + 3265340021421335059323962377647649632959*I)/19342813113834066795298816, (136272594005759723105646069956434264218730 - 330975364731707309489523680957584684763587*I)/38685626227668133590597632, (27392593965554149283318732469825168894401 + 75157071243800133880129376047131061115278*I)/38685626227668133590597632, 7*(-357821652913266734749960136017214096276154 - 45509144466378076475315751988405961498243*I)/309485009821345068724781056, (104485001373574280824835174390219397141149 - 99041000529599568255829489765415726168162*I)/77371252455336267181195264, (1198066993119982409323525798509037696321291 + 4249784165667887866939369628840569844519936*I)/618970019642690137449562112, (-114985392587849953209115599084503853611014 - 52510376847189529234864487459476242883449*I)/77371252455336267181195264, (6094620517051332877965959223269600650951573 - 4683469779240530439185019982269137976201163*I)/1237940039285380274899124224], +# [ (611292255597977285752123848828590587708323 - 216821743518546668382662964473055912169502*I)/77371252455336267181195264, (-1144023204575811464652692396337616594307487 + 12295317806312398617498029126807758490062855*I)/309485009821345068724781056, (-374093027769390002505693378578475235158281 - 573533923565898290299607461660384634333639*I)/77371252455336267181195264, (47405570632186659000138546955372796986832987 - 2837476058950808941605000274055970055096534*I)/1237940039285380274899124224, (-571573207393621076306216726219753090535121 + 533381457185823100878764749236639320783831*I)/77371252455336267181195264, (-7096548151856165056213543560958582513797519 - 24035731898756040059329175131592138642195366*I)/618970019642690137449562112, (2396762128833271142000266170154694033849225 + 1448501087375679588770230529017516492953051*I)/309485009821345068724781056, (-150609293845161968447166237242456473262037053 + 92581148080922977153207018003184520294188436*I)/4951760157141521099596496896, 5*(270278244730804315149356082977618054486347 - 1997830155222496880429743815321662710091562*I)/1237940039285380274899124224, (62978424789588828258068912690172109324360330 + 44803641177219298311493356929537007630129097*I)/2475880078570760549798248448, 19*(-451431106327656743945775812536216598712236 + 114924966793632084379437683991151177407937*I)/1237940039285380274899124224, (63417747628891221594106738815256002143915995 - 261508229397507037136324178612212080871150958*I)/9903520314283042199192993792], +# [ (-2144231934021288786200752920446633703357 + 2305614436009705803670842248131563850246*I)/1208925819614629174706176, (-90720949337459896266067589013987007078153 - 221951119475096403601562347412753844534569*I)/19342813113834066795298816, (11590973613116630788176337262688659880376 + 6514520676308992726483494976339330626159*I)/4835703278458516698824704, 3*(-131776217149000326618649542018343107657237 + 79095042939612668486212006406818285287004*I)/38685626227668133590597632, (10100577916793945997239221374025741184951 - 28631383488085522003281589065994018550748*I)/9671406556917033397649408, 67*(10090295594251078955008130473573667572549 + 10449901522697161049513326446427839676762*I)/77371252455336267181195264, (-54270981296988368730689531355811033930513 - 3413683117592637309471893510944045467443*I)/19342813113834066795298816, (440372322928679910536575560069973699181278 - 736603803202303189048085196176918214409081*I)/77371252455336267181195264, (33220374714789391132887731139763250155295 + 92055083048787219934030779066298919603554*I)/38685626227668133590597632, 5*(-594638554579967244348856981610805281527116 - 82309245323128933521987392165716076704057*I)/309485009821345068724781056, (128056368815300084550013708313312073721955 - 114619107488668120303579745393765245911404*I)/77371252455336267181195264, 21*(59839959255173222962789517794121843393573 + 241507883613676387255359616163487405826334*I)/618970019642690137449562112], +# [ (-13454485022325376674626653802541391955147 + 184471402121905621396582628515905949793486*I)/19342813113834066795298816, (-6158730123400322562149780662133074862437105 - 3416173052604643794120262081623703514107476*I)/154742504910672534362390528, (770558003844914708453618983120686116100419 - 127758381209767638635199674005029818518766*I)/77371252455336267181195264, (-4693005771813492267479835161596671660631703 + 12703585094750991389845384539501921531449948*I)/309485009821345068724781056, (-295028157441149027913545676461260860036601 - 841544569970643160358138082317324743450770*I)/77371252455336267181195264, (56716442796929448856312202561538574275502893 + 7216818824772560379753073185990186711454778*I)/1237940039285380274899124224, 15*(-87061038932753366532685677510172566368387 + 61306141156647596310941396434445461895538*I)/154742504910672534362390528, (-3455315109680781412178133042301025723909347 - 24969329563196972466388460746447646686670670*I)/618970019642690137449562112, (2453418854160886481106557323699250865361849 + 1497886802326243014471854112161398141242514*I)/309485009821345068724781056, (-151343224544252091980004429001205664193082173 + 90471883264187337053549090899816228846836628*I)/4951760157141521099596496896, (1652018205533026103358164026239417416432989 - 9959733619236515024261775397109724431400162*I)/1237940039285380274899124224, 3*(40676374242956907656984876692623172736522006 + 31023357083037817469535762230872667581366205*I)/4951760157141521099596496896], +# [ (-1226990509403328460274658603410696548387 - 4131739423109992672186585941938392788458*I)/1208925819614629174706176, (162392818524418973411975140074368079662703 + 23706194236915374831230612374344230400704*I)/9671406556917033397649408, (-3935678233089814180000602553655565621193 + 2283744757287145199688061892165659502483*I)/1208925819614629174706176, (-2400210250844254483454290806930306285131 - 315571356806370996069052930302295432758205*I)/19342813113834066795298816, (13365917938215281056563183751673390817910 + 15911483133819801118348625831132324863881*I)/4835703278458516698824704, 3*(-215950551370668982657516660700301003897855 + 51684341999223632631602864028309400489378*I)/38685626227668133590597632, (20886089946811765149439844691320027184765 - 30806277083146786592790625980769214361844*I)/9671406556917033397649408, (562180634592713285745940856221105667874855 + 1031543963988260765153550559766662245114916*I)/77371252455336267181195264, (-65820625814810177122941758625652476012867 - 12429918324787060890804395323920477537595*I)/19342813113834066795298816, (319147848192012911298771180196635859221089 - 402403304933906769233365689834404519960394*I)/38685626227668133590597632, (23035615120921026080284733394359587955057 + 115351677687031786114651452775242461310624*I)/38685626227668133590597632, (-3426830634881892756966440108592579264936130 - 1022954961164128745603407283836365128598559*I)/309485009821345068724781056], +# [ (-192574788060137531023716449082856117537757 - 69222967328876859586831013062387845780692*I)/19342813113834066795298816, (2736383768828013152914815341491629299773262 - 2773252698016291897599353862072533475408743*I)/77371252455336267181195264, (-23280005281223837717773057436155921656805 + 214784953368021840006305033048142888879224*I)/19342813113834066795298816, (-3035247484028969580570400133318947903462326 - 2195168903335435855621328554626336958674325*I)/77371252455336267181195264, (984552428291526892214541708637840971548653 - 64006622534521425620714598573494988589378*I)/77371252455336267181195264, (-3070650452470333005276715136041262898509903 + 7286424705750810474140953092161794621989080*I)/154742504910672534362390528, (-147848877109756404594659513386972921139270 - 416306113044186424749331418059456047650861*I)/38685626227668133590597632, (55272118474097814260289392337160619494260781 + 7494019668394781211907115583302403519488058*I)/1237940039285380274899124224, (-581537886583682322424771088996959213068864 + 542191617758465339135308203815256798407429*I)/77371252455336267181195264, (-6422548983676355789975736799494791970390991 - 23524183982209004826464749309156698827737702*I)/618970019642690137449562112, 7*(180747195387024536886923192475064903482083 + 84352527693562434817771649853047924991804*I)/154742504910672534362390528, (-135485179036717001055310712747643466592387031 + 102346575226653028836678855697782273460527608*I)/4951760157141521099596496896], +# [ (3384238362616083147067025892852431152105 + 156724444932584900214919898954874618256*I)/604462909807314587353088, (-59558300950677430189587207338385764871866 + 114427143574375271097298201388331237478857*I)/4835703278458516698824704, (-1356835789870635633517710130971800616227 - 7023484098542340388800213478357340875410*I)/1208925819614629174706176, (234884918567993750975181728413524549575881 + 79757294640629983786895695752733890213506*I)/9671406556917033397649408, (-7632732774935120473359202657160313866419 + 2905452608512927560554702228553291839465*I)/1208925819614629174706176, (52291747908702842344842889809762246649489 - 520996778817151392090736149644507525892649*I)/19342813113834066795298816, (17472406829219127839967951180375981717322 + 23464704213841582137898905375041819568669*I)/4835703278458516698824704, (-911026971811893092350229536132730760943307 + 150799318130900944080399439626714846752360*I)/38685626227668133590597632, (26234457233977042811089020440646443590687 - 45650293039576452023692126463683727692890*I)/9671406556917033397649408, 3*(288348388717468992528382586652654351121357 + 454526517721403048270274049572136109264668*I)/77371252455336267181195264, (-91583492367747094223295011999405657956347 - 12704691128268298435362255538069612411331*I)/19342813113834066795298816, (411208730251327843849027957710164064354221 - 569898526380691606955496789378230959965898*I)/38685626227668133590597632], +# [ (27127513117071487872628354831658811211795 - 37765296987901990355760582016892124833857*I)/4835703278458516698824704, (1741779916057680444272938534338833170625435 + 3083041729779495966997526404685535449810378*I)/77371252455336267181195264, 3*(-60642236251815783728374561836962709533401 - 24630301165439580049891518846174101510744*I)/19342813113834066795298816, 3*(445885207364591681637745678755008757483408 - 350948497734812895032502179455610024541643*I)/38685626227668133590597632, (-47373295621391195484367368282471381775684 + 219122969294089357477027867028071400054973*I)/19342813113834066795298816, (-2801565819673198722993348253876353741520438 - 2250142129822658548391697042460298703335701*I)/77371252455336267181195264, (801448252275607253266997552356128790317119 - 50890367688077858227059515894356594900558*I)/77371252455336267181195264, (-5082187758525931944557763799137987573501207 + 11610432359082071866576699236013484487676124*I)/309485009821345068724781056, (-328925127096560623794883760398247685166830 - 643447969697471610060622160899409680422019*I)/77371252455336267181195264, 15*(2954944669454003684028194956846659916299765 + 33434406416888505837444969347824812608566*I)/1237940039285380274899124224, (-415749104352001509942256567958449835766827 + 479330966144175743357171151440020955412219*I)/77371252455336267181195264, 3*(-4639987285852134369449873547637372282914255 - 11994411888966030153196659207284951579243273*I)/1237940039285380274899124224], +# [ (-478846096206269117345024348666145495601 + 1249092488629201351470551186322814883283*I)/302231454903657293676544, (-17749319421930878799354766626365926894989 - 18264580106418628161818752318217357231971*I)/1208925819614629174706176, (2801110795431528876849623279389579072819 + 363258850073786330770713557775566973248*I)/604462909807314587353088, (-59053496693129013745775512127095650616252 + 78143588734197260279248498898321500167517*I)/4835703278458516698824704, (-283186724922498212468162690097101115349 - 6443437753863179883794497936345437398276*I)/1208925819614629174706176, (188799118826748909206887165661384998787543 + 84274736720556630026311383931055307398820*I)/9671406556917033397649408, (-5482217151670072904078758141270295025989 + 1818284338672191024475557065444481298568*I)/1208925819614629174706176, (56564463395350195513805521309731217952281 - 360208541416798112109946262159695452898431*I)/19342813113834066795298816, 11*(1259539805728870739006416869463689438068 + 1409136581547898074455004171305324917387*I)/4835703278458516698824704, 5*(-123701190701414554945251071190688818343325 + 30997157322590424677294553832111902279712*I)/38685626227668133590597632, (16130917381301373033736295883982414239781 - 32752041297570919727145380131926943374516*I)/9671406556917033397649408, (650301385108223834347093740500375498354925 + 899526407681131828596801223402866051809258*I)/77371252455336267181195264], +# [ (9011388245256140876590294262420614839483 + 8167917972423946282513000869327525382672*I)/1208925819614629174706176, (-426393174084720190126376382194036323028924 + 180692224825757525982858693158209545430621*I)/9671406556917033397649408, (24588556702197802674765733448108154175535 - 45091766022876486566421953254051868331066*I)/4835703278458516698824704, (1872113939365285277373877183750416985089691 + 3030392393733212574744122057679633775773130*I)/77371252455336267181195264, (-222173405538046189185754954524429864167549 - 75193157893478637039381059488387511299116*I)/19342813113834066795298816, (2670821320766222522963689317316937579844558 - 2645837121493554383087981511645435472169191*I)/77371252455336267181195264, 5*(-2100110309556476773796963197283876204940 + 41957457246479840487980315496957337371937*I)/19342813113834066795298816, (-5733743755499084165382383818991531258980593 - 3328949988392698205198574824396695027195732*I)/154742504910672534362390528, (707827994365259025461378911159398206329247 - 265730616623227695108042528694302299777294*I)/77371252455336267181195264, (-1442501604682933002895864804409322823788319 + 11504137805563265043376405214378288793343879*I)/309485009821345068724781056, (-56130472299445561499538726459719629522285 - 61117552419727805035810982426639329818864*I)/9671406556917033397649408, (39053692321126079849054272431599539429908717 - 10209127700342570953247177602860848130710666*I)/1237940039285380274899124224]]) + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512], + [ 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64], + [ -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128], + [ 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16], + [ 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M**10 == Matrix(S('''[ + [ 7369525394972778926719607798014571861/604462909807314587353088 - 229284202061790301477392339912557559*I/151115727451828646838272, -19704281515163975949388435612632058035/1208925819614629174706176 + 14319858347987648723768698170712102887*I/302231454903657293676544, -3623281909451783042932142262164941211/604462909807314587353088 - 6039240602494288615094338643452320495*I/604462909807314587353088, 109260497799140408739847239685705357695/2417851639229258349412352 - 7427566006564572463236368211555511431*I/2417851639229258349412352, -16095803767674394244695716092817006641/2417851639229258349412352 + 10336681897356760057393429626719177583*I/1208925819614629174706176, -42207883340488041844332828574359769743/2417851639229258349412352 - 182332262671671273188016400290188468499*I/4835703278458516698824704], + [50566491050825573392726324995779608259/1208925819614629174706176 - 90047007594468146222002432884052362145*I/2417851639229258349412352, 74273703462900000967697427843983822011/1208925819614629174706176 + 265947522682943571171988741842776095421*I/1208925819614629174706176, -116900341394390200556829767923360888429/2417851639229258349412352 - 53153263356679268823910621474478756845*I/2417851639229258349412352, 195407378023867871243426523048612490249/1208925819614629174706176 - 1242417915995360200584837585002906728929*I/9671406556917033397649408, -863597594389821970177319682495878193/302231454903657293676544 + 476936100741548328800725360758734300481*I/9671406556917033397649408, -3154451590535653853562472176601754835575/19342813113834066795298816 - 232909875490506237386836489998407329215*I/2417851639229258349412352], + [ -1715444997702484578716037230949868543/302231454903657293676544 + 5009695651321306866158517287924120777*I/302231454903657293676544, -30551582497996879620371947949342101301/604462909807314587353088 - 7632518367986526187139161303331519629*I/151115727451828646838272, 312680739924495153190604170938220575/18889465931478580854784 - 108664334509328818765959789219208459*I/75557863725914323419136, -14693696966703036206178521686918865509/604462909807314587353088 + 72345386220900843930147151999899692401*I/1208925819614629174706176, -8218872496728882299722894680635296519/1208925819614629174706176 - 16776782833358893712645864791807664983*I/1208925819614629174706176, 143237839169380078671242929143670635137/2417851639229258349412352 + 2883817094806115974748882735218469447*I/2417851639229258349412352], + [ 3087979417831061365023111800749855987/151115727451828646838272 + 34441942370802869368851419102423997089*I/604462909807314587353088, -148309181940158040917731426845476175667/604462909807314587353088 - 263987151804109387844966835369350904919*I/9671406556917033397649408, 50259518594816377378747711930008883165/1208925819614629174706176 - 95713974916869240305450001443767979653*I/2417851639229258349412352, 153466447023875527996457943521467271119/2417851639229258349412352 + 517285524891117105834922278517084871349*I/2417851639229258349412352, -29184653615412989036678939366291205575/604462909807314587353088 - 27551322282526322041080173287022121083*I/1208925819614629174706176, 196404220110085511863671393922447671649/1208925819614629174706176 - 1204712019400186021982272049902206202145*I/9671406556917033397649408], + [ -2632581805949645784625606590600098779/151115727451828646838272 - 589957435912868015140272627522612771*I/37778931862957161709568, 26727850893953715274702844733506310247/302231454903657293676544 - 10825791956782128799168209600694020481*I/302231454903657293676544, -1036348763702366164044671908440791295/151115727451828646838272 + 3188624571414467767868303105288107375*I/151115727451828646838272, -36814959939970644875593411585393242449/604462909807314587353088 - 18457555789119782404850043842902832647*I/302231454903657293676544, 12454491297984637815063964572803058647/604462909807314587353088 - 340489532842249733975074349495329171*I/302231454903657293676544, -19547211751145597258386735573258916681/604462909807314587353088 + 87299583775782199663414539883938008933*I/1208925819614629174706176], + [ -40281994229560039213253423262678393183/604462909807314587353088 - 2939986850065527327299273003299736641*I/604462909807314587353088, 331940684638052085845743020267462794181/2417851639229258349412352 - 284574901963624403933361315517248458969*I/1208925819614629174706176, 6453843623051745485064693628073010961/302231454903657293676544 + 36062454107479732681350914931391590957*I/604462909807314587353088, -147665869053634695632880753646441962067/604462909807314587353088 - 305987938660447291246597544085345123927*I/9671406556917033397649408, 107821369195275772166593879711259469423/2417851639229258349412352 - 11645185518211204108659001435013326687*I/302231454903657293676544, 64121228424717666402009446088588091619/1208925819614629174706176 + 265557133337095047883844369272389762133*I/1208925819614629174706176]]''')) + +def test_issue_17247_expression_blowup_5(): + M = Matrix(6, 6, lambda i, j: 1 + (-1)**(i+j)*I) + with dotprodsimp(True): + assert M.charpoly('x') == PurePoly(x**6 + (-6 - 6*I)*x**5 + 36*I*x**4, x, domain='EX') + +def test_issue_17247_expression_blowup_6(): + M = Matrix(8, 8, [x+i for i in range (64)]) + with dotprodsimp(True): + assert M.det('bareiss') == 0 + +def test_issue_17247_expression_blowup_7(): + M = Matrix(6, 6, lambda i, j: 1 + (-1)**(i+j)*I) + with dotprodsimp(True): + assert M.det('berkowitz') == 0 + +def test_issue_17247_expression_blowup_8(): + M = Matrix(8, 8, [x+i for i in range (64)]) + with dotprodsimp(True): + assert M.det('lu') == 0 + +def test_issue_17247_expression_blowup_9(): + M = Matrix(8, 8, [x+i for i in range (64)]) + with dotprodsimp(True): + assert M.rref() == (Matrix([ + [1, 0, -1, -2, -3, -4, -5, -6], + [0, 1, 2, 3, 4, 5, 6, 7], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0]]), (0, 1)) + +def test_issue_17247_expression_blowup_10(): + M = Matrix(6, 6, lambda i, j: 1 + (-1)**(i+j)*I) + with dotprodsimp(True): + assert M.cofactor(0, 0) == 0 + +def test_issue_17247_expression_blowup_11(): + M = Matrix(6, 6, lambda i, j: 1 + (-1)**(i+j)*I) + with dotprodsimp(True): + assert M.cofactor_matrix() == Matrix(6, 6, [0]*36) + +def test_issue_17247_expression_blowup_12(): + M = Matrix(6, 6, lambda i, j: 1 + (-1)**(i+j)*I) + with dotprodsimp(True): + assert M.eigenvals() == {6: 1, 6*I: 1, 0: 4} + +def test_issue_17247_expression_blowup_13(): + M = Matrix([ + [ 0, 1 - x, x + 1, 1 - x], + [1 - x, x + 1, 0, x + 1], + [ 0, 1 - x, x + 1, 1 - x], + [ 0, 0, 1 - x, 0]]) + + ev = M.eigenvects() + assert ev[0] == (0, 2, [Matrix([0, -1, 0, 1])]) + assert ev[1][0] == x - sqrt(2)*(x - 1) + 1 + assert ev[1][1] == 1 + assert ev[1][2][0].expand(deep=False, numer=True) == Matrix([ + [(-x + sqrt(2)*(x - 1) - 1)/(x - 1)], + [-4*x/(x**2 - 2*x + 1) + (x + 1)*(x - sqrt(2)*(x - 1) + 1)/(x**2 - 2*x + 1)], + [(-x + sqrt(2)*(x - 1) - 1)/(x - 1)], + [1] + ]) + + assert ev[2][0] == x + sqrt(2)*(x - 1) + 1 + assert ev[2][1] == 1 + assert ev[2][2][0].expand(deep=False, numer=True) == Matrix([ + [(-x - sqrt(2)*(x - 1) - 1)/(x - 1)], + [-4*x/(x**2 - 2*x + 1) + (x + 1)*(x + sqrt(2)*(x - 1) + 1)/(x**2 - 2*x + 1)], + [(-x - sqrt(2)*(x - 1) - 1)/(x - 1)], + [1] + ]) + + +def test_issue_17247_expression_blowup_14(): + M = Matrix(8, 8, ([1+x, 1-x]*4 + [1-x, 1+x]*4)*4) + with dotprodsimp(True): + assert M.echelon_form() == Matrix([ + [x + 1, 1 - x, x + 1, 1 - x, x + 1, 1 - x, x + 1, 1 - x], + [ 0, 4*x, 0, 4*x, 0, 4*x, 0, 4*x], + [ 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0]]) + +def test_issue_17247_expression_blowup_15(): + M = Matrix(8, 8, ([1+x, 1-x]*4 + [1-x, 1+x]*4)*4) + with dotprodsimp(True): + assert M.rowspace() == [Matrix([[x + 1, 1 - x, x + 1, 1 - x, x + 1, 1 - x, x + 1, 1 - x]]), Matrix([[0, 4*x, 0, 4*x, 0, 4*x, 0, 4*x]])] + +def test_issue_17247_expression_blowup_16(): + M = Matrix(8, 8, ([1+x, 1-x]*4 + [1-x, 1+x]*4)*4) + with dotprodsimp(True): + assert M.columnspace() == [Matrix([[x + 1],[1 - x],[x + 1],[1 - x],[x + 1],[1 - x],[x + 1],[1 - x]]), Matrix([[1 - x],[x + 1],[1 - x],[x + 1],[1 - x],[x + 1],[1 - x],[x + 1]])] + +def test_issue_17247_expression_blowup_17(): + M = Matrix(8, 8, [x+i for i in range (64)]) + with dotprodsimp(True): + assert M.nullspace() == [ + Matrix([[1],[-2],[1],[0],[0],[0],[0],[0]]), + Matrix([[2],[-3],[0],[1],[0],[0],[0],[0]]), + Matrix([[3],[-4],[0],[0],[1],[0],[0],[0]]), + Matrix([[4],[-5],[0],[0],[0],[1],[0],[0]]), + Matrix([[5],[-6],[0],[0],[0],[0],[1],[0]]), + Matrix([[6],[-7],[0],[0],[0],[0],[0],[1]])] + +def test_issue_17247_expression_blowup_18(): + M = Matrix(6, 6, ([1+x, 1-x]*3 + [1-x, 1+x]*3)*3) + with dotprodsimp(True): + assert not M.is_nilpotent() + +def test_issue_17247_expression_blowup_19(): + M = Matrix(S('''[ + [ -3/4, 0, 1/4 + I/2, 0], + [ 0, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 1/2 - I, 0, 0, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert not M.is_diagonalizable() + +def test_issue_17247_expression_blowup_20(): + M = Matrix([ + [x + 1, 1 - x, 0, 0], + [1 - x, x + 1, 0, x + 1], + [ 0, 1 - x, x + 1, 0], + [ 0, 0, 0, x + 1]]) + with dotprodsimp(True): + assert M.diagonalize() == (Matrix([ + [1, 1, 0, (x + 1)/(x - 1)], + [1, -1, 0, 0], + [1, 1, 1, 0], + [0, 0, 0, 1]]), + Matrix([ + [2, 0, 0, 0], + [0, 2*x, 0, 0], + [0, 0, x + 1, 0], + [0, 0, 0, x + 1]])) + +def test_issue_17247_expression_blowup_21(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.inv(method='GE') == Matrix(S('''[ + [-26194832/3470993 - 31733264*I/3470993, 156352/3470993 + 10325632*I/3470993, 0, -7741283181072/3306971225785 + 2999007604624*I/3306971225785], + [4408224/3470993 - 9675328*I/3470993, -2422272/3470993 + 1523712*I/3470993, 0, -1824666489984/3306971225785 - 1401091949952*I/3306971225785], + [-26406945676288/22270005630769 + 10245925485056*I/22270005630769, 7453523312640/22270005630769 + 1601616519168*I/22270005630769, 633088/6416033 - 140288*I/6416033, 872209227109521408/21217636514687010905 + 6066405081802389504*I/21217636514687010905], + [0, 0, 0, -11328/952745 + 87616*I/952745]]''')) + +def test_issue_17247_expression_blowup_22(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.inv(method='LU') == Matrix(S('''[ + [-26194832/3470993 - 31733264*I/3470993, 156352/3470993 + 10325632*I/3470993, 0, -7741283181072/3306971225785 + 2999007604624*I/3306971225785], + [4408224/3470993 - 9675328*I/3470993, -2422272/3470993 + 1523712*I/3470993, 0, -1824666489984/3306971225785 - 1401091949952*I/3306971225785], + [-26406945676288/22270005630769 + 10245925485056*I/22270005630769, 7453523312640/22270005630769 + 1601616519168*I/22270005630769, 633088/6416033 - 140288*I/6416033, 872209227109521408/21217636514687010905 + 6066405081802389504*I/21217636514687010905], + [0, 0, 0, -11328/952745 + 87616*I/952745]]''')) + +def test_issue_17247_expression_blowup_23(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.inv(method='ADJ').expand() == Matrix(S('''[ + [-26194832/3470993 - 31733264*I/3470993, 156352/3470993 + 10325632*I/3470993, 0, -7741283181072/3306971225785 + 2999007604624*I/3306971225785], + [4408224/3470993 - 9675328*I/3470993, -2422272/3470993 + 1523712*I/3470993, 0, -1824666489984/3306971225785 - 1401091949952*I/3306971225785], + [-26406945676288/22270005630769 + 10245925485056*I/22270005630769, 7453523312640/22270005630769 + 1601616519168*I/22270005630769, 633088/6416033 - 140288*I/6416033, 872209227109521408/21217636514687010905 + 6066405081802389504*I/21217636514687010905], + [0, 0, 0, -11328/952745 + 87616*I/952745]]''')) + +def test_issue_17247_expression_blowup_24(): + M = SparseMatrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.inv(method='CH') == Matrix(S('''[ + [-26194832/3470993 - 31733264*I/3470993, 156352/3470993 + 10325632*I/3470993, 0, -7741283181072/3306971225785 + 2999007604624*I/3306971225785], + [4408224/3470993 - 9675328*I/3470993, -2422272/3470993 + 1523712*I/3470993, 0, -1824666489984/3306971225785 - 1401091949952*I/3306971225785], + [-26406945676288/22270005630769 + 10245925485056*I/22270005630769, 7453523312640/22270005630769 + 1601616519168*I/22270005630769, 633088/6416033 - 140288*I/6416033, 872209227109521408/21217636514687010905 + 6066405081802389504*I/21217636514687010905], + [0, 0, 0, -11328/952745 + 87616*I/952745]]''')) + +def test_issue_17247_expression_blowup_25(): + M = SparseMatrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.inv(method='LDL') == Matrix(S('''[ + [-26194832/3470993 - 31733264*I/3470993, 156352/3470993 + 10325632*I/3470993, 0, -7741283181072/3306971225785 + 2999007604624*I/3306971225785], + [4408224/3470993 - 9675328*I/3470993, -2422272/3470993 + 1523712*I/3470993, 0, -1824666489984/3306971225785 - 1401091949952*I/3306971225785], + [-26406945676288/22270005630769 + 10245925485056*I/22270005630769, 7453523312640/22270005630769 + 1601616519168*I/22270005630769, 633088/6416033 - 140288*I/6416033, 872209227109521408/21217636514687010905 + 6066405081802389504*I/21217636514687010905], + [0, 0, 0, -11328/952745 + 87616*I/952745]]''')) + +def test_issue_17247_expression_blowup_26(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64, -9/32 - I/16, 183/256 - 97*I/128], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512, -219/128 + 115*I/256, 6301/4096 - 6609*I/1024], + [ 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64], + [ -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512], + [ 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64], + [ 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128], + [ -2, 17/4 - 13*I/2, 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16], + [ 1/4 + 13*I/4, -825/64 - 147*I/32, 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.rank() == 4 + +def test_issue_17247_expression_blowup_27(): + M = Matrix([ + [ 0, 1 - x, x + 1, 1 - x], + [1 - x, x + 1, 0, x + 1], + [ 0, 1 - x, x + 1, 1 - x], + [ 0, 0, 1 - x, 0]]) + with dotprodsimp(True): + P, J = M.jordan_form() + assert P.expand() == Matrix(S('''[ + [ 0, 4*x/(x**2 - 2*x + 1), -(-17*x**4 + 12*sqrt(2)*x**4 - 4*sqrt(2)*x**3 + 6*x**3 - 6*x - 4*sqrt(2)*x + 12*sqrt(2) + 17)/(-7*x**4 + 5*sqrt(2)*x**4 - 6*sqrt(2)*x**3 + 8*x**3 - 2*x**2 + 8*x + 6*sqrt(2)*x - 5*sqrt(2) - 7), -(12*sqrt(2)*x**4 + 17*x**4 - 6*x**3 - 4*sqrt(2)*x**3 - 4*sqrt(2)*x + 6*x - 17 + 12*sqrt(2))/(7*x**4 + 5*sqrt(2)*x**4 - 6*sqrt(2)*x**3 - 8*x**3 + 2*x**2 - 8*x + 6*sqrt(2)*x - 5*sqrt(2) + 7)], + [x - 1, x/(x - 1) + 1/(x - 1), (-7*x**3 + 5*sqrt(2)*x**3 - x**2 + sqrt(2)*x**2 - sqrt(2)*x - x - 5*sqrt(2) - 7)/(-3*x**3 + 2*sqrt(2)*x**3 - 2*sqrt(2)*x**2 + 3*x**2 + 2*sqrt(2)*x + 3*x - 3 - 2*sqrt(2)), (7*x**3 + 5*sqrt(2)*x**3 + x**2 + sqrt(2)*x**2 - sqrt(2)*x + x - 5*sqrt(2) + 7)/(2*sqrt(2)*x**3 + 3*x**3 - 3*x**2 - 2*sqrt(2)*x**2 - 3*x + 2*sqrt(2)*x - 2*sqrt(2) + 3)], + [ 0, 1, -(-3*x**2 + 2*sqrt(2)*x**2 + 2*x - 3 - 2*sqrt(2))/(-x**2 + sqrt(2)*x**2 - 2*sqrt(2)*x + 1 + sqrt(2)), -(2*sqrt(2)*x**2 + 3*x**2 - 2*x - 2*sqrt(2) + 3)/(x**2 + sqrt(2)*x**2 - 2*sqrt(2)*x - 1 + sqrt(2))], + [1 - x, 0, 1, 1]]''')).expand() + assert J == Matrix(S('''[ + [0, 1, 0, 0], + [0, 0, 0, 0], + [0, 0, x - sqrt(2)*(x - 1) + 1, 0], + [0, 0, 0, x + sqrt(2)*(x - 1) + 1]]''')) + +def test_issue_17247_expression_blowup_28(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.singular_values() == S('''[ + sqrt(14609315/131072 + sqrt(64789115132571/2147483648 - 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3) + 76627253330829751075/(35184372088832*sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))) - 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)))/2 + sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))/2), + sqrt(14609315/131072 - sqrt(64789115132571/2147483648 - 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3) + 76627253330829751075/(35184372088832*sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))) - 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)))/2 + sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))/2), + sqrt(14609315/131072 - sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))/2 + sqrt(64789115132571/2147483648 - 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3) - 76627253330829751075/(35184372088832*sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))) - 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)))/2), + sqrt(14609315/131072 - sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))/2 - sqrt(64789115132571/2147483648 - 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3) - 76627253330829751075/(35184372088832*sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))) - 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)))/2)]''') + + +def test_issue_16823(): + # This still needs to be fixed if not using dotprodsimp. + M = Matrix(S('''[ + [1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I,1/4+1/2*I,-129/64-9/64*I,1/4-5/16*I,65/128+87/64*I,-9/32-1/16*I,183/256-97/128*I,3/64+13/64*I,-23/32-59/256*I,15/128-3/32*I,19/256+551/1024*I], + [21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I,125/64+87/64*I,-2063/256+541/128*I,85/256-33/16*I,805/128+2415/512*I,-219/128+115/256*I,6301/4096-6609/1024*I,119/128+143/128*I,-10879/2048+4343/4096*I,129/256-549/512*I,42533/16384+29103/8192*I], + [-2,17/4-13/2*I,1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I,1/4+1/2*I,-129/64-9/64*I,1/4-5/16*I,65/128+87/64*I,-9/32-1/16*I,183/256-97/128*I,3/64+13/64*I,-23/32-59/256*I], + [1/4+13/4*I,-825/64-147/32*I,21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I,125/64+87/64*I,-2063/256+541/128*I,85/256-33/16*I,805/128+2415/512*I,-219/128+115/256*I,6301/4096-6609/1024*I,119/128+143/128*I,-10879/2048+4343/4096*I], + [-4*I,27/2+6*I,-2,17/4-13/2*I,1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I,1/4+1/2*I,-129/64-9/64*I,1/4-5/16*I,65/128+87/64*I,-9/32-1/16*I,183/256-97/128*I], + [1/4+5/2*I,-23/8-57/16*I,1/4+13/4*I,-825/64-147/32*I,21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I,125/64+87/64*I,-2063/256+541/128*I,85/256-33/16*I,805/128+2415/512*I,-219/128+115/256*I,6301/4096-6609/1024*I], + [-4,9-5*I,-4*I,27/2+6*I,-2,17/4-13/2*I,1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I,1/4+1/2*I,-129/64-9/64*I,1/4-5/16*I,65/128+87/64*I], + [-2*I,119/8+29/4*I,1/4+5/2*I,-23/8-57/16*I,1/4+13/4*I,-825/64-147/32*I,21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I,125/64+87/64*I,-2063/256+541/128*I,85/256-33/16*I,805/128+2415/512*I], + [0,-6,-4,9-5*I,-4*I,27/2+6*I,-2,17/4-13/2*I,1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I,1/4+1/2*I,-129/64-9/64*I], + [1,-9/4+3*I,-2*I,119/8+29/4*I,1/4+5/2*I,-23/8-57/16*I,1/4+13/4*I,-825/64-147/32*I,21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I,125/64+87/64*I,-2063/256+541/128*I], + [0,-4*I,0,-6,-4,9-5*I,-4*I,27/2+6*I,-2,17/4-13/2*I,1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I], + [0,1/4+1/2*I,1,-9/4+3*I,-2*I,119/8+29/4*I,1/4+5/2*I,-23/8-57/16*I,1/4+13/4*I,-825/64-147/32*I,21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I]]''')) + with dotprodsimp(True): + assert M.rank() == 8 + + +def test_issue_18531(): + # solve_linear_system still needs fixing but the rref works. + M = Matrix([ + [1, 1, 1, 1, 1, 0, 1, 0, 0], + [1 + sqrt(2), -1 + sqrt(2), 1 - sqrt(2), -sqrt(2) - 1, 1, 1, -1, 1, 1], + [-5 + 2*sqrt(2), -5 - 2*sqrt(2), -5 - 2*sqrt(2), -5 + 2*sqrt(2), -7, 2, -7, -2, 0], + [-3*sqrt(2) - 1, 1 - 3*sqrt(2), -1 + 3*sqrt(2), 1 + 3*sqrt(2), -7, -5, 7, -5, 3], + [7 - 4*sqrt(2), 4*sqrt(2) + 7, 4*sqrt(2) + 7, 7 - 4*sqrt(2), 7, -12, 7, 12, 0], + [-1 + 3*sqrt(2), 1 + 3*sqrt(2), -3*sqrt(2) - 1, 1 - 3*sqrt(2), 7, -5, -7, -5, 3], + [-3 + 2*sqrt(2), -3 - 2*sqrt(2), -3 - 2*sqrt(2), -3 + 2*sqrt(2), -1, 2, -1, -2, 0], + [1 - sqrt(2), -sqrt(2) - 1, 1 + sqrt(2), -1 + sqrt(2), -1, 1, 1, 1, 1] + ]) + with dotprodsimp(True): + assert M.rref() == (Matrix([ + [1, 0, 0, 0, 0, 0, 0, 0, S(1)/2], + [0, 1, 0, 0, 0, 0, 0, 0, -S(1)/2], + [0, 0, 1, 0, 0, 0, 0, 0, S(1)/2], + [0, 0, 0, 1, 0, 0, 0, 0, -S(1)/2], + [0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, -S(1)/2], + [0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, -S(1)/2]]), (0, 1, 2, 3, 4, 5, 6, 7)) + + +def test_creation(): + raises(ValueError, lambda: Matrix(5, 5, range(20))) + raises(ValueError, lambda: Matrix(5, -1, [])) + raises(IndexError, lambda: Matrix((1, 2))[2]) + with raises(IndexError): + Matrix((1, 2))[3] = 5 + + assert Matrix() == Matrix([]) == Matrix(0, 0, []) + assert Matrix([[]]) == Matrix(1, 0, []) + assert Matrix([[], []]) == Matrix(2, 0, []) + + # anything used to be allowed in a matrix + with warns_deprecated_sympy(): + assert Matrix([[[1], (2,)]]).tolist() == [[[1], (2,)]] + with warns_deprecated_sympy(): + assert Matrix([[[1], (2,)]]).T.tolist() == [[[1]], [(2,)]] + M = Matrix([[0]]) + with warns_deprecated_sympy(): + M[0, 0] = S.EmptySet + + a = Matrix([[x, 0], [0, 0]]) + m = a + assert m.cols == m.rows + assert m.cols == 2 + assert m[:] == [x, 0, 0, 0] + + b = Matrix(2, 2, [x, 0, 0, 0]) + m = b + assert m.cols == m.rows + assert m.cols == 2 + assert m[:] == [x, 0, 0, 0] + + assert a == b + + assert Matrix(b) == b + + c23 = Matrix(2, 3, range(1, 7)) + c13 = Matrix(1, 3, range(7, 10)) + c = Matrix([c23, c13]) + assert c.cols == 3 + assert c.rows == 3 + assert c[:] == [1, 2, 3, 4, 5, 6, 7, 8, 9] + + assert Matrix(eye(2)) == eye(2) + assert ImmutableMatrix(ImmutableMatrix(eye(2))) == ImmutableMatrix(eye(2)) + assert ImmutableMatrix(c) == c.as_immutable() + assert Matrix(ImmutableMatrix(c)) == ImmutableMatrix(c).as_mutable() + + assert c is not Matrix(c) + + dat = [[ones(3,2), ones(3,3)*2], [ones(2,3)*3, ones(2,2)*4]] + M = Matrix(dat) + assert M == Matrix([ + [1, 1, 2, 2, 2], + [1, 1, 2, 2, 2], + [1, 1, 2, 2, 2], + [3, 3, 3, 4, 4], + [3, 3, 3, 4, 4]]) + assert M.tolist() != dat + # keep block form if evaluate=False + assert Matrix(dat, evaluate=False).tolist() == dat + A = MatrixSymbol("A", 2, 2) + dat = [ones(2), A] + assert Matrix(dat) == Matrix([ + [ 1, 1], + [ 1, 1], + [A[0, 0], A[0, 1]], + [A[1, 0], A[1, 1]]]) + with warns_deprecated_sympy(): + assert Matrix(dat, evaluate=False).tolist() == [[i] for i in dat] + + # 0-dim tolerance + assert Matrix([ones(2), ones(0)]) == Matrix([ones(2)]) + raises(ValueError, lambda: Matrix([ones(2), ones(0, 3)])) + raises(ValueError, lambda: Matrix([ones(2), ones(3, 0)])) + + # mix of Matrix and iterable + M = Matrix([[1, 2], [3, 4]]) + M2 = Matrix([M, (5, 6)]) + assert M2 == Matrix([[1, 2], [3, 4], [5, 6]]) + + +def test_irregular_block(): + assert Matrix.irregular(3, ones(2,1), ones(3,3)*2, ones(2,2)*3, + ones(1,1)*4, ones(2,2)*5, ones(1,2)*6, ones(1,2)*7) == Matrix([ + [1, 2, 2, 2, 3, 3], + [1, 2, 2, 2, 3, 3], + [4, 2, 2, 2, 5, 5], + [6, 6, 7, 7, 5, 5]]) + + +def test_tolist(): + lst = [[S.One, S.Half, x*y, S.Zero], [x, y, z, x**2], [y, -S.One, z*x, 3]] + m = Matrix(lst) + assert m.tolist() == lst + + +def test_as_mutable(): + assert zeros(0, 3).as_mutable() == zeros(0, 3) + assert zeros(0, 3).as_immutable() == ImmutableMatrix(zeros(0, 3)) + assert zeros(3, 0).as_immutable() == ImmutableMatrix(zeros(3, 0)) + + +def test_slicing(): + m0 = eye(4) + assert m0[:3, :3] == eye(3) + assert m0[2:4, 0:2] == zeros(2) + + m1 = Matrix(3, 3, lambda i, j: i + j) + assert m1[0, :] == Matrix(1, 3, (0, 1, 2)) + assert m1[1:3, 1] == Matrix(2, 1, (2, 3)) + + m2 = Matrix([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]) + assert m2[:, -1] == Matrix(4, 1, [3, 7, 11, 15]) + assert m2[-2:, :] == Matrix([[8, 9, 10, 11], [12, 13, 14, 15]]) + + +def test_submatrix_assignment(): + m = zeros(4) + m[2:4, 2:4] = eye(2) + assert m == Matrix(((0, 0, 0, 0), + (0, 0, 0, 0), + (0, 0, 1, 0), + (0, 0, 0, 1))) + m[:2, :2] = eye(2) + assert m == eye(4) + m[:, 0] = Matrix(4, 1, (1, 2, 3, 4)) + assert m == Matrix(((1, 0, 0, 0), + (2, 1, 0, 0), + (3, 0, 1, 0), + (4, 0, 0, 1))) + m[:, :] = zeros(4) + assert m == zeros(4) + m[:, :] = [(1, 2, 3, 4), (5, 6, 7, 8), (9, 10, 11, 12), (13, 14, 15, 16)] + assert m == Matrix(((1, 2, 3, 4), + (5, 6, 7, 8), + (9, 10, 11, 12), + (13, 14, 15, 16))) + m[:2, 0] = [0, 0] + assert m == Matrix(((0, 2, 3, 4), + (0, 6, 7, 8), + (9, 10, 11, 12), + (13, 14, 15, 16))) + + +def test_extract(): + m = Matrix(4, 3, lambda i, j: i*3 + j) + assert m.extract([0, 1, 3], [0, 1]) == Matrix(3, 2, [0, 1, 3, 4, 9, 10]) + assert m.extract([0, 3], [0, 0, 2]) == Matrix(2, 3, [0, 0, 2, 9, 9, 11]) + assert m.extract(range(4), range(3)) == m + raises(IndexError, lambda: m.extract([4], [0])) + raises(IndexError, lambda: m.extract([0], [3])) + + +def test_reshape(): + m0 = eye(3) + assert m0.reshape(1, 9) == Matrix(1, 9, (1, 0, 0, 0, 1, 0, 0, 0, 1)) + m1 = Matrix(3, 4, lambda i, j: i + j) + assert m1.reshape( + 4, 3) == Matrix(((0, 1, 2), (3, 1, 2), (3, 4, 2), (3, 4, 5))) + assert m1.reshape(2, 6) == Matrix(((0, 1, 2, 3, 1, 2), (3, 4, 2, 3, 4, 5))) + + +def test_applyfunc(): + m0 = eye(3) + assert m0.applyfunc(lambda x: 2*x) == eye(3)*2 + assert m0.applyfunc(lambda x: 0) == zeros(3) + + +def test_expand(): + m0 = Matrix([[x*(x + y), 2], [((x + y)*y)*x, x*(y + x*(x + y))]]) + # Test if expand() returns a matrix + m1 = m0.expand() + assert m1 == Matrix( + [[x*y + x**2, 2], [x*y**2 + y*x**2, x*y + y*x**2 + x**3]]) + + a = Symbol('a', real=True) + + assert Matrix([exp(I*a)]).expand(complex=True) == \ + Matrix([cos(a) + I*sin(a)]) + + assert Matrix([[0, 1, 2], [0, 0, -1], [0, 0, 0]]).exp() == Matrix([ + [1, 1, Rational(3, 2)], + [0, 1, -1], + [0, 0, 1]] + ) + +def test_refine(): + m0 = Matrix([[Abs(x)**2, sqrt(x**2)], + [sqrt(x**2)*Abs(y)**2, sqrt(y**2)*Abs(x)**2]]) + m1 = m0.refine(Q.real(x) & Q.real(y)) + assert m1 == Matrix([[x**2, Abs(x)], [y**2*Abs(x), x**2*Abs(y)]]) + + m1 = m0.refine(Q.positive(x) & Q.positive(y)) + assert m1 == Matrix([[x**2, x], [x*y**2, x**2*y]]) + + m1 = m0.refine(Q.negative(x) & Q.negative(y)) + assert m1 == Matrix([[x**2, -x], [-x*y**2, -x**2*y]]) + +def test_random(): + M = randMatrix(3, 3) + M = randMatrix(3, 3, seed=3) + assert M == randMatrix(3, 3, seed=3) + + M = randMatrix(3, 4, 0, 150) + M = randMatrix(3, seed=4, symmetric=True) + assert M == randMatrix(3, seed=4, symmetric=True) + + S = M.copy() + S.simplify() + assert S == M # doesn't fail when elements are Numbers, not int + + rng = random.Random(4) + assert M == randMatrix(3, symmetric=True, prng=rng) + + # Ensure symmetry + for size in (10, 11): # Test odd and even + for percent in (100, 70, 30): + M = randMatrix(size, symmetric=True, percent=percent, prng=rng) + assert M == M.T + + M = randMatrix(10, min=1, percent=70) + zero_count = 0 + for i in range(M.shape[0]): + for j in range(M.shape[1]): + if M[i, j] == 0: + zero_count += 1 + assert zero_count == 30 + +def test_inverse(): + A = eye(4) + assert A.inv() == eye(4) + assert A.inv(method="LU") == eye(4) + assert A.inv(method="ADJ") == eye(4) + assert A.inv(method="CH") == eye(4) + assert A.inv(method="LDL") == eye(4) + assert A.inv(method="QR") == eye(4) + A = Matrix([[2, 3, 5], + [3, 6, 2], + [8, 3, 6]]) + Ainv = A.inv() + assert A*Ainv == eye(3) + assert A.inv(method="LU") == Ainv + assert A.inv(method="ADJ") == Ainv + assert A.inv(method="CH") == Ainv + assert A.inv(method="LDL") == Ainv + assert A.inv(method="QR") == Ainv + + AA = Matrix([[0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0], + [1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0], + [1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1], + [1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0], + [1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1], + [0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0], + [1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1], + [0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1], + [1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0], + [0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0], + [1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0], + [0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1], + [1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0], + [0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0], + [1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1], + [0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1], + [1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1], + [0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1], + [0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1], + [0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0], + [0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0]]) + assert AA.inv(method="BLOCK") * AA == eye(AA.shape[0]) + # test that immutability is not a problem + cls = ImmutableMatrix + m = cls([[48, 49, 31], + [ 9, 71, 94], + [59, 28, 65]]) + assert all(type(m.inv(s)) is cls for s in 'GE ADJ LU CH LDL QR'.split()) + cls = ImmutableSparseMatrix + m = cls([[48, 49, 31], + [ 9, 71, 94], + [59, 28, 65]]) + assert all(type(m.inv(s)) is cls for s in 'GE ADJ LU CH LDL QR'.split()) + + +def test_jacobian_hessian(): + L = Matrix(1, 2, [x**2*y, 2*y**2 + x*y]) + syms = [x, y] + assert L.jacobian(syms) == Matrix([[2*x*y, x**2], [y, 4*y + x]]) + + L = Matrix(1, 2, [x, x**2*y**3]) + assert L.jacobian(syms) == Matrix([[1, 0], [2*x*y**3, x**2*3*y**2]]) + + f = x**2*y + syms = [x, y] + assert hessian(f, syms) == Matrix([[2*y, 2*x], [2*x, 0]]) + + f = x**2*y**3 + assert hessian(f, syms) == \ + Matrix([[2*y**3, 6*x*y**2], [6*x*y**2, 6*x**2*y]]) + + f = z + x*y**2 + g = x**2 + 2*y**3 + ans = Matrix([[0, 2*y], + [2*y, 2*x]]) + assert ans == hessian(f, Matrix([x, y])) + assert ans == hessian(f, Matrix([x, y]).T) + assert hessian(f, (y, x), [g]) == Matrix([ + [ 0, 6*y**2, 2*x], + [6*y**2, 2*x, 2*y], + [ 2*x, 2*y, 0]]) + + +def test_wronskian(): + assert wronskian([cos(x), sin(x)], x) == cos(x)**2 + sin(x)**2 + assert wronskian([exp(x), exp(2*x)], x) == exp(3*x) + assert wronskian([exp(x), x], x) == exp(x) - x*exp(x) + assert wronskian([1, x, x**2], x) == 2 + w1 = -6*exp(x)*sin(x)*x + 6*cos(x)*exp(x)*x**2 - 6*exp(x)*cos(x)*x - \ + exp(x)*cos(x)*x**3 + exp(x)*sin(x)*x**3 + assert wronskian([exp(x), cos(x), x**3], x).expand() == w1 + assert wronskian([exp(x), cos(x), x**3], x, method='berkowitz').expand() \ + == w1 + w2 = -x**3*cos(x)**2 - x**3*sin(x)**2 - 6*x*cos(x)**2 - 6*x*sin(x)**2 + assert wronskian([sin(x), cos(x), x**3], x).expand() == w2 + assert wronskian([sin(x), cos(x), x**3], x, method='berkowitz').expand() \ + == w2 + assert wronskian([], x) == 1 + + +def test_subs(): + assert Matrix([[1, x], [x, 4]]).subs(x, 5) == Matrix([[1, 5], [5, 4]]) + assert Matrix([[x, 2], [x + y, 4]]).subs([[x, -1], [y, -2]]) == \ + Matrix([[-1, 2], [-3, 4]]) + assert Matrix([[x, 2], [x + y, 4]]).subs([(x, -1), (y, -2)]) == \ + Matrix([[-1, 2], [-3, 4]]) + assert Matrix([[x, 2], [x + y, 4]]).subs({x: -1, y: -2}) == \ + Matrix([[-1, 2], [-3, 4]]) + assert Matrix([x*y]).subs({x: y - 1, y: x - 1}, simultaneous=True) == \ + Matrix([(x - 1)*(y - 1)]) + + for cls in classes: + assert Matrix([[2, 0], [0, 2]]) == cls.eye(2).subs(1, 2) + +def test_xreplace(): + assert Matrix([[1, x], [x, 4]]).xreplace({x: 5}) == \ + Matrix([[1, 5], [5, 4]]) + assert Matrix([[x, 2], [x + y, 4]]).xreplace({x: -1, y: -2}) == \ + Matrix([[-1, 2], [-3, 4]]) + for cls in classes: + assert Matrix([[2, 0], [0, 2]]) == cls.eye(2).xreplace({1: 2}) + +def test_simplify(): + n = Symbol('n') + f = Function('f') + + M = Matrix([[ 1/x + 1/y, (x + x*y) / x ], + [ (f(x) + y*f(x))/f(x), 2 * (1/n - cos(n * pi)/n) / pi ]]) + M.simplify() + assert M == Matrix([[ (x + y)/(x * y), 1 + y ], + [ 1 + y, 2*((1 - 1*cos(pi*n))/(pi*n)) ]]) + eq = (1 + x)**2 + M = Matrix([[eq]]) + M.simplify() + assert M == Matrix([[eq]]) + M.simplify(ratio=oo) + assert M == Matrix([[eq.simplify(ratio=oo)]]) + + +def test_transpose(): + M = Matrix([[1, 2, 3, 4, 5, 6, 7, 8, 9, 0], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]]) + assert M.T == Matrix( [ [1, 1], + [2, 2], + [3, 3], + [4, 4], + [5, 5], + [6, 6], + [7, 7], + [8, 8], + [9, 9], + [0, 0] ]) + assert M.T.T == M + assert M.T == M.transpose() + + +def test_conjugate(): + M = Matrix([[0, I, 5], + [1, 2, 0]]) + + assert M.T == Matrix([[0, 1], + [I, 2], + [5, 0]]) + + assert M.C == Matrix([[0, -I, 5], + [1, 2, 0]]) + assert M.C == M.conjugate() + + assert M.H == M.T.C + assert M.H == Matrix([[ 0, 1], + [-I, 2], + [ 5, 0]]) + + +def test_conj_dirac(): + raises(AttributeError, lambda: eye(3).D) + + M = Matrix([[1, I, I, I], + [0, 1, I, I], + [0, 0, 1, I], + [0, 0, 0, 1]]) + + assert M.D == Matrix([[ 1, 0, 0, 0], + [-I, 1, 0, 0], + [-I, -I, -1, 0], + [-I, -I, I, -1]]) + + +def test_trace(): + M = Matrix([[1, 0, 0], + [0, 5, 0], + [0, 0, 8]]) + assert M.trace() == 14 + + +def test_shape(): + M = Matrix([[x, 0, 0], + [0, y, 0]]) + assert M.shape == (2, 3) + + +def test_col_row_op(): + M = Matrix([[x, 0, 0], + [0, y, 0]]) + M.row_op(1, lambda r, j: r + j + 1) + assert M == Matrix([[x, 0, 0], + [1, y + 2, 3]]) + + M.col_op(0, lambda c, j: c + y**j) + assert M == Matrix([[x + 1, 0, 0], + [1 + y, y + 2, 3]]) + + # neither row nor slice give copies that allow the original matrix to + # be changed + assert M.row(0) == Matrix([[x + 1, 0, 0]]) + r1 = M.row(0) + r1[0] = 42 + assert M[0, 0] == x + 1 + r1 = M[0, :-1] # also testing negative slice + r1[0] = 42 + assert M[0, 0] == x + 1 + c1 = M.col(0) + assert c1 == Matrix([x + 1, 1 + y]) + c1[0] = 0 + assert M[0, 0] == x + 1 + c1 = M[:, 0] + c1[0] = 42 + assert M[0, 0] == x + 1 + + +def test_row_mult(): + M = Matrix([[1,2,3], + [4,5,6]]) + M.row_mult(1,3) + assert M[1,0] == 12 + assert M[0,0] == 1 + assert M[1,2] == 18 + + +def test_row_add(): + M = Matrix([[1,2,3], + [4,5,6], + [1,1,1]]) + M.row_add(2,0,5) + assert M[0,0] == 6 + assert M[1,0] == 4 + assert M[0,2] == 8 + + +def test_zip_row_op(): + for cls in classes[:2]: # XXX: immutable matrices don't support row ops + M = cls.eye(3) + M.zip_row_op(1, 0, lambda v, u: v + 2*u) + assert M == cls([[1, 0, 0], + [2, 1, 0], + [0, 0, 1]]) + + M = cls.eye(3)*2 + M[0, 1] = -1 + M.zip_row_op(1, 0, lambda v, u: v + 2*u); M + assert M == cls([[2, -1, 0], + [4, 0, 0], + [0, 0, 2]]) + +def test_issue_3950(): + m = Matrix([1, 2, 3]) + a = Matrix([1, 2, 3]) + b = Matrix([2, 2, 3]) + assert not (m in []) + assert not (m in [1]) + assert m != 1 + assert m == a + assert m != b + + +def test_issue_3981(): + class Index1: + def __index__(self): + return 1 + + class Index2: + def __index__(self): + return 2 + index1 = Index1() + index2 = Index2() + + m = Matrix([1, 2, 3]) + + assert m[index2] == 3 + + m[index2] = 5 + assert m[2] == 5 + + m = Matrix([[1, 2, 3], [4, 5, 6]]) + assert m[index1, index2] == 6 + assert m[1, index2] == 6 + assert m[index1, 2] == 6 + + m[index1, index2] = 4 + assert m[1, 2] == 4 + m[1, index2] = 6 + assert m[1, 2] == 6 + m[index1, 2] = 8 + assert m[1, 2] == 8 + + +def test_evalf(): + a = Matrix([sqrt(5), 6]) + assert all(a.evalf()[i] == a[i].evalf() for i in range(2)) + assert all(a.evalf(2)[i] == a[i].evalf(2) for i in range(2)) + assert all(a.n(2)[i] == a[i].n(2) for i in range(2)) + + +def test_is_symbolic(): + a = Matrix([[x, x], [x, x]]) + assert a.is_symbolic() is True + a = Matrix([[1, 2, 3, 4], [5, 6, 7, 8]]) + assert a.is_symbolic() is False + a = Matrix([[1, 2, 3, 4], [5, 6, x, 8]]) + assert a.is_symbolic() is True + a = Matrix([[1, x, 3]]) + assert a.is_symbolic() is True + a = Matrix([[1, 2, 3]]) + assert a.is_symbolic() is False + a = Matrix([[1], [x], [3]]) + assert a.is_symbolic() is True + a = Matrix([[1], [2], [3]]) + assert a.is_symbolic() is False + + +def test_is_upper(): + a = Matrix([[1, 2, 3]]) + assert a.is_upper is True + a = Matrix([[1], [2], [3]]) + assert a.is_upper is False + a = zeros(4, 2) + assert a.is_upper is True + + +def test_is_lower(): + a = Matrix([[1, 2, 3]]) + assert a.is_lower is False + a = Matrix([[1], [2], [3]]) + assert a.is_lower is True + + +def test_is_nilpotent(): + a = Matrix(4, 4, [0, 2, 1, 6, 0, 0, 1, 2, 0, 0, 0, 3, 0, 0, 0, 0]) + assert a.is_nilpotent() + a = Matrix([[1, 0], [0, 1]]) + assert not a.is_nilpotent() + a = Matrix([]) + assert a.is_nilpotent() + + +def test_zeros_ones_fill(): + n, m = 3, 5 + + a = zeros(n, m) + a.fill( 5 ) + + b = 5 * ones(n, m) + + assert a == b + assert a.rows == b.rows == 3 + assert a.cols == b.cols == 5 + assert a.shape == b.shape == (3, 5) + assert zeros(2) == zeros(2, 2) + assert ones(2) == ones(2, 2) + assert zeros(2, 3) == Matrix(2, 3, [0]*6) + assert ones(2, 3) == Matrix(2, 3, [1]*6) + + a.fill(0) + assert a == zeros(n, m) + + +def test_empty_zeros(): + a = zeros(0) + assert a == Matrix() + a = zeros(0, 2) + assert a.rows == 0 + assert a.cols == 2 + a = zeros(2, 0) + assert a.rows == 2 + assert a.cols == 0 + + +def test_issue_3749(): + a = Matrix([[x**2, x*y], [x*sin(y), x*cos(y)]]) + assert a.diff(x) == Matrix([[2*x, y], [sin(y), cos(y)]]) + assert Matrix([ + [x, -x, x**2], + [exp(x), 1/x - exp(-x), x + 1/x]]).limit(x, oo) == \ + Matrix([[oo, -oo, oo], [oo, 0, oo]]) + assert Matrix([ + [(exp(x) - 1)/x, 2*x + y*x, x**x ], + [1/x, abs(x), abs(sin(x + 1))]]).limit(x, 0) == \ + Matrix([[1, 0, 1], [oo, 0, sin(1)]]) + assert a.integrate(x) == Matrix([ + [Rational(1, 3)*x**3, y*x**2/2], + [x**2*sin(y)/2, x**2*cos(y)/2]]) + + +def test_inv_iszerofunc(): + A = eye(4) + A.col_swap(0, 1) + for method in "GE", "LU": + assert A.inv(method=method, iszerofunc=lambda x: x == 0) == \ + A.inv(method="ADJ") + + +def test_jacobian_metrics(): + rho, phi = symbols("rho,phi") + X = Matrix([rho*cos(phi), rho*sin(phi)]) + Y = Matrix([rho, phi]) + J = X.jacobian(Y) + assert J == X.jacobian(Y.T) + assert J == (X.T).jacobian(Y) + assert J == (X.T).jacobian(Y.T) + g = J.T*eye(J.shape[0])*J + g = g.applyfunc(trigsimp) + assert g == Matrix([[1, 0], [0, rho**2]]) + + +def test_jacobian2(): + rho, phi = symbols("rho,phi") + X = Matrix([rho*cos(phi), rho*sin(phi), rho**2]) + Y = Matrix([rho, phi]) + J = Matrix([ + [cos(phi), -rho*sin(phi)], + [sin(phi), rho*cos(phi)], + [ 2*rho, 0], + ]) + assert X.jacobian(Y) == J + + +def test_issue_4564(): + X = Matrix([exp(x + y + z), exp(x + y + z), exp(x + y + z)]) + Y = Matrix([x, y, z]) + for i in range(1, 3): + for j in range(1, 3): + X_slice = X[:i, :] + Y_slice = Y[:j, :] + J = X_slice.jacobian(Y_slice) + assert J.rows == i + assert J.cols == j + for k in range(j): + assert J[:, k] == X_slice + + +def test_nonvectorJacobian(): + X = Matrix([[exp(x + y + z), exp(x + y + z)], + [exp(x + y + z), exp(x + y + z)]]) + raises(TypeError, lambda: X.jacobian(Matrix([x, y, z]))) + X = X[0, :] + Y = Matrix([[x, y], [x, z]]) + raises(TypeError, lambda: X.jacobian(Y)) + raises(TypeError, lambda: X.jacobian(Matrix([ [x, y], [x, z] ]))) + + +def test_vec(): + m = Matrix([[1, 3], [2, 4]]) + m_vec = m.vec() + assert m_vec.cols == 1 + for i in range(4): + assert m_vec[i] == i + 1 + + +def test_vech(): + m = Matrix([[1, 2], [2, 3]]) + m_vech = m.vech() + assert m_vech.cols == 1 + for i in range(3): + assert m_vech[i] == i + 1 + m_vech = m.vech(diagonal=False) + assert m_vech[0] == 2 + + m = Matrix([[1, x*(x + y)], [y*x + x**2, 1]]) + m_vech = m.vech(diagonal=False) + assert m_vech[0] == y*x + x**2 + + m = Matrix([[1, x*(x + y)], [y*x, 1]]) + m_vech = m.vech(diagonal=False, check_symmetry=False) + assert m_vech[0] == y*x + + raises(ShapeError, lambda: Matrix([[1, 3]]).vech()) + raises(ValueError, lambda: Matrix([[1, 3], [2, 4]]).vech()) + raises(ShapeError, lambda: Matrix([[1, 3]]).vech()) + raises(ValueError, lambda: Matrix([[1, 3], [2, 4]]).vech()) + + +def test_diag(): + # mostly tested in testcommonmatrix.py + assert diag([1, 2, 3]) == Matrix([1, 2, 3]) + m = [1, 2, [3]] + raises(ValueError, lambda: diag(m)) + assert diag(m, strict=False) == Matrix([1, 2, 3]) + + +def test_get_diag_blocks1(): + a = Matrix([[1, 2], [2, 3]]) + b = Matrix([[3, x], [y, 3]]) + c = Matrix([[3, x, 3], [y, 3, z], [x, y, z]]) + assert a.get_diag_blocks() == [a] + assert b.get_diag_blocks() == [b] + assert c.get_diag_blocks() == [c] + + +def test_get_diag_blocks2(): + a = Matrix([[1, 2], [2, 3]]) + b = Matrix([[3, x], [y, 3]]) + c = Matrix([[3, x, 3], [y, 3, z], [x, y, z]]) + assert diag(a, b, b).get_diag_blocks() == [a, b, b] + assert diag(a, b, c).get_diag_blocks() == [a, b, c] + assert diag(a, c, b).get_diag_blocks() == [a, c, b] + assert diag(c, c, b).get_diag_blocks() == [c, c, b] + + +def test_inv_block(): + a = Matrix([[1, 2], [2, 3]]) + b = Matrix([[3, x], [y, 3]]) + c = Matrix([[3, x, 3], [y, 3, z], [x, y, z]]) + A = diag(a, b, b) + assert A.inv(try_block_diag=True) == diag(a.inv(), b.inv(), b.inv()) + A = diag(a, b, c) + assert A.inv(try_block_diag=True) == diag(a.inv(), b.inv(), c.inv()) + A = diag(a, c, b) + assert A.inv(try_block_diag=True) == diag(a.inv(), c.inv(), b.inv()) + A = diag(a, a, b, a, c, a) + assert A.inv(try_block_diag=True) == diag( + a.inv(), a.inv(), b.inv(), a.inv(), c.inv(), a.inv()) + assert A.inv(try_block_diag=True, method="ADJ") == diag( + a.inv(method="ADJ"), a.inv(method="ADJ"), b.inv(method="ADJ"), + a.inv(method="ADJ"), c.inv(method="ADJ"), a.inv(method="ADJ")) + + +def test_creation_args(): + """ + Check that matrix dimensions can be specified using any reasonable type + (see issue 4614). + """ + raises(ValueError, lambda: zeros(3, -1)) + raises(TypeError, lambda: zeros(1, 2, 3, 4)) + assert zeros(int(3)) == zeros(3) + assert zeros(Integer(3)) == zeros(3) + raises(ValueError, lambda: zeros(3.)) + assert eye(int(3)) == eye(3) + assert eye(Integer(3)) == eye(3) + raises(ValueError, lambda: eye(3.)) + assert ones(int(3), Integer(4)) == ones(3, 4) + raises(TypeError, lambda: Matrix(5)) + raises(TypeError, lambda: Matrix(1, 2)) + raises(ValueError, lambda: Matrix([1, [2]])) + + +def test_diagonal_symmetrical(): + m = Matrix(2, 2, [0, 1, 1, 0]) + assert not m.is_diagonal() + assert m.is_symmetric() + assert m.is_symmetric(simplify=False) + + m = Matrix(2, 2, [1, 0, 0, 1]) + assert m.is_diagonal() + + m = diag(1, 2, 3) + assert m.is_diagonal() + assert m.is_symmetric() + + m = Matrix(3, 3, [1, 0, 0, 0, 2, 0, 0, 0, 3]) + assert m == diag(1, 2, 3) + + m = Matrix(2, 3, zeros(2, 3)) + assert not m.is_symmetric() + assert m.is_diagonal() + + m = Matrix(((5, 0), (0, 6), (0, 0))) + assert m.is_diagonal() + + m = Matrix(((5, 0, 0), (0, 6, 0))) + assert m.is_diagonal() + + m = Matrix(3, 3, [1, x**2 + 2*x + 1, y, (x + 1)**2, 2, 0, y, 0, 3]) + assert m.is_symmetric() + assert not m.is_symmetric(simplify=False) + assert m.expand().is_symmetric(simplify=False) + + +def test_diagonalization(): + m = Matrix([[1, 2+I], [2-I, 3]]) + assert m.is_diagonalizable() + + m = Matrix(3, 2, [-3, 1, -3, 20, 3, 10]) + assert not m.is_diagonalizable() + assert not m.is_symmetric() + raises(NonSquareMatrixError, lambda: m.diagonalize()) + + # diagonalizable + m = diag(1, 2, 3) + (P, D) = m.diagonalize() + assert P == eye(3) + assert D == m + + m = Matrix(2, 2, [0, 1, 1, 0]) + assert m.is_symmetric() + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + + m = Matrix(2, 2, [1, 0, 0, 3]) + assert m.is_symmetric() + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + assert P == eye(2) + assert D == m + + m = Matrix(2, 2, [1, 1, 0, 0]) + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + + m = Matrix(3, 3, [1, 2, 0, 0, 3, 0, 2, -4, 2]) + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + for i in P: + assert i.as_numer_denom()[1] == 1 + + m = Matrix(2, 2, [1, 0, 0, 0]) + assert m.is_diagonal() + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + assert P == Matrix([[0, 1], [1, 0]]) + + # diagonalizable, complex only + m = Matrix(2, 2, [0, 1, -1, 0]) + assert not m.is_diagonalizable(True) + raises(MatrixError, lambda: m.diagonalize(True)) + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + + # not diagonalizable + m = Matrix(2, 2, [0, 1, 0, 0]) + assert not m.is_diagonalizable() + raises(MatrixError, lambda: m.diagonalize()) + + m = Matrix(3, 3, [-3, 1, -3, 20, 3, 10, 2, -2, 4]) + assert not m.is_diagonalizable() + raises(MatrixError, lambda: m.diagonalize()) + + # symbolic + a, b, c, d = symbols('a b c d') + m = Matrix(2, 2, [a, c, c, b]) + assert m.is_symmetric() + assert m.is_diagonalizable() + + +def test_issue_15887(): + # Mutable matrix should not use cache + a = MutableDenseMatrix([[0, 1], [1, 0]]) + assert a.is_diagonalizable() is True + a[1, 0] = 0 + assert a.is_diagonalizable() is False + + a = MutableDenseMatrix([[0, 1], [1, 0]]) + a.diagonalize() + a[1, 0] = 0 + raises(MatrixError, lambda: a.diagonalize()) + + +def test_jordan_form(): + + m = Matrix(3, 2, [-3, 1, -3, 20, 3, 10]) + raises(NonSquareMatrixError, lambda: m.jordan_form()) + + # diagonalizable + m = Matrix(3, 3, [7, -12, 6, 10, -19, 10, 12, -24, 13]) + Jmust = Matrix(3, 3, [-1, 0, 0, 0, 1, 0, 0, 0, 1]) + P, J = m.jordan_form() + assert Jmust == J + assert Jmust == m.diagonalize()[1] + + # m = Matrix(3, 3, [0, 6, 3, 1, 3, 1, -2, 2, 1]) + # m.jordan_form() # very long + # m.jordan_form() # + + # diagonalizable, complex only + + # Jordan cells + # complexity: one of eigenvalues is zero + m = Matrix(3, 3, [0, 1, 0, -4, 4, 0, -2, 1, 2]) + # The blocks are ordered according to the value of their eigenvalues, + # in order to make the matrix compatible with .diagonalize() + Jmust = Matrix(3, 3, [2, 1, 0, 0, 2, 0, 0, 0, 2]) + P, J = m.jordan_form() + assert Jmust == J + + # complexity: all of eigenvalues are equal + m = Matrix(3, 3, [2, 6, -15, 1, 1, -5, 1, 2, -6]) + # Jmust = Matrix(3, 3, [-1, 0, 0, 0, -1, 1, 0, 0, -1]) + # same here see 1456ff + Jmust = Matrix(3, 3, [-1, 1, 0, 0, -1, 0, 0, 0, -1]) + P, J = m.jordan_form() + assert Jmust == J + + # complexity: two of eigenvalues are zero + m = Matrix(3, 3, [4, -5, 2, 5, -7, 3, 6, -9, 4]) + Jmust = Matrix(3, 3, [0, 1, 0, 0, 0, 0, 0, 0, 1]) + P, J = m.jordan_form() + assert Jmust == J + + m = Matrix(4, 4, [6, 5, -2, -3, -3, -1, 3, 3, 2, 1, -2, -3, -1, 1, 5, 5]) + Jmust = Matrix(4, 4, [2, 1, 0, 0, + 0, 2, 0, 0, + 0, 0, 2, 1, + 0, 0, 0, 2] + ) + P, J = m.jordan_form() + assert Jmust == J + + m = Matrix(4, 4, [6, 2, -8, -6, -3, 2, 9, 6, 2, -2, -8, -6, -1, 0, 3, 4]) + # Jmust = Matrix(4, 4, [2, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0, -2]) + # same here see 1456ff + Jmust = Matrix(4, 4, [-2, 0, 0, 0, + 0, 2, 1, 0, + 0, 0, 2, 0, + 0, 0, 0, 2]) + P, J = m.jordan_form() + assert Jmust == J + + m = Matrix(4, 4, [5, 4, 2, 1, 0, 1, -1, -1, -1, -1, 3, 0, 1, 1, -1, 2]) + assert not m.is_diagonalizable() + Jmust = Matrix(4, 4, [1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 4, 1, 0, 0, 0, 4]) + P, J = m.jordan_form() + assert Jmust == J + + # checking for maximum precision to remain unchanged + m = Matrix([[Float('1.0', precision=110), Float('2.0', precision=110)], + [Float('3.14159265358979323846264338327', precision=110), Float('4.0', precision=110)]]) + P, J = m.jordan_form() + for term in J.values(): + if isinstance(term, Float): + assert term._prec == 110 + + +def test_jordan_form_complex_issue_9274(): + A = Matrix([[ 2, 4, 1, 0], + [-4, 2, 0, 1], + [ 0, 0, 2, 4], + [ 0, 0, -4, 2]]) + p = 2 - 4*I + q = 2 + 4*I + Jmust1 = Matrix([[p, 1, 0, 0], + [0, p, 0, 0], + [0, 0, q, 1], + [0, 0, 0, q]]) + Jmust2 = Matrix([[q, 1, 0, 0], + [0, q, 0, 0], + [0, 0, p, 1], + [0, 0, 0, p]]) + P, J = A.jordan_form() + assert J == Jmust1 or J == Jmust2 + assert simplify(P*J*P.inv()) == A + +def test_issue_10220(): + # two non-orthogonal Jordan blocks with eigenvalue 1 + M = Matrix([[1, 0, 0, 1], + [0, 1, 1, 0], + [0, 0, 1, 1], + [0, 0, 0, 1]]) + P, J = M.jordan_form() + assert P == Matrix([[0, 1, 0, 1], + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0]]) + assert J == Matrix([ + [1, 1, 0, 0], + [0, 1, 1, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]) + +def test_jordan_form_issue_15858(): + A = Matrix([ + [1, 1, 1, 0], + [-2, -1, 0, -1], + [0, 0, -1, -1], + [0, 0, 2, 1]]) + (P, J) = A.jordan_form() + assert P.expand() == Matrix([ + [ -I, -I/2, I, I/2], + [-1 + I, 0, -1 - I, 0], + [ 0, -S(1)/2 - I/2, 0, -S(1)/2 + I/2], + [ 0, 1, 0, 1]]) + assert J == Matrix([ + [-I, 1, 0, 0], + [0, -I, 0, 0], + [0, 0, I, 1], + [0, 0, 0, I]]) + +def test_Matrix_berkowitz_charpoly(): + UA, K_i, K_w = symbols('UA K_i K_w') + + A = Matrix([[-K_i - UA + K_i**2/(K_i + K_w), K_i*K_w/(K_i + K_w)], + [ K_i*K_w/(K_i + K_w), -K_w + K_w**2/(K_i + K_w)]]) + + charpoly = A.charpoly(x) + + assert charpoly == \ + Poly(x**2 + (K_i*UA + K_w*UA + 2*K_i*K_w)/(K_i + K_w)*x + + K_i*K_w*UA/(K_i + K_w), x, domain='ZZ(K_i,K_w,UA)') + + assert type(charpoly) is PurePoly + + A = Matrix([[1, 3], [2, 0]]) + assert A.charpoly() == A.charpoly(x) == PurePoly(x**2 - x - 6) + + A = Matrix([[1, 2], [x, 0]]) + p = A.charpoly(x) + assert p.gen != x + assert p.as_expr().subs(p.gen, x) == x**2 - 3*x + + +def test_exp_jordan_block(): + l = Symbol('lamda') + + m = Matrix.jordan_block(1, l) + assert m._eval_matrix_exp_jblock() == Matrix([[exp(l)]]) + + m = Matrix.jordan_block(3, l) + assert m._eval_matrix_exp_jblock() == \ + Matrix([ + [exp(l), exp(l), exp(l)/2], + [0, exp(l), exp(l)], + [0, 0, exp(l)]]) + + +def test_exp(): + m = Matrix([[3, 4], [0, -2]]) + m_exp = Matrix([[exp(3), -4*exp(-2)/5 + 4*exp(3)/5], [0, exp(-2)]]) + assert m.exp() == m_exp + assert exp(m) == m_exp + + m = Matrix([[1, 0], [0, 1]]) + assert m.exp() == Matrix([[E, 0], [0, E]]) + assert exp(m) == Matrix([[E, 0], [0, E]]) + + m = Matrix([[1, -1], [1, 1]]) + assert m.exp() == Matrix([[E*cos(1), -E*sin(1)], [E*sin(1), E*cos(1)]]) + + +def test_log(): + l = Symbol('lamda') + + m = Matrix.jordan_block(1, l) + assert m._eval_matrix_log_jblock() == Matrix([[log(l)]]) + + m = Matrix.jordan_block(4, l) + assert m._eval_matrix_log_jblock() == \ + Matrix( + [ + [log(l), 1/l, -1/(2*l**2), 1/(3*l**3)], + [0, log(l), 1/l, -1/(2*l**2)], + [0, 0, log(l), 1/l], + [0, 0, 0, log(l)] + ] + ) + + m = Matrix( + [[0, 0, 1], + [0, 0, 0], + [-1, 0, 0]] + ) + raises(MatrixError, lambda: m.log()) + + +def test_has(): + A = Matrix(((x, y), (2, 3))) + assert A.has(x) + assert not A.has(z) + assert A.has(Symbol) + + A = A.subs(x, 2) + assert not A.has(x) + + +def test_find_reasonable_pivot_naive_finds_guaranteed_nonzero1(): + # Test if matrices._find_reasonable_pivot_naive() + # finds a guaranteed non-zero pivot when the + # some of the candidate pivots are symbolic expressions. + # Keyword argument: simpfunc=None indicates that no simplifications + # should be performed during the search. + x = Symbol('x') + column = Matrix(3, 1, [x, cos(x)**2 + sin(x)**2, S.Half]) + pivot_offset, pivot_val, pivot_assumed_nonzero, simplified =\ + _find_reasonable_pivot_naive(column) + assert pivot_val == S.Half + +def test_find_reasonable_pivot_naive_finds_guaranteed_nonzero2(): + # Test if matrices._find_reasonable_pivot_naive() + # finds a guaranteed non-zero pivot when the + # some of the candidate pivots are symbolic expressions. + # Keyword argument: simpfunc=_simplify indicates that the search + # should attempt to simplify candidate pivots. + x = Symbol('x') + column = Matrix(3, 1, + [x, + cos(x)**2+sin(x)**2+x**2, + cos(x)**2+sin(x)**2]) + pivot_offset, pivot_val, pivot_assumed_nonzero, simplified =\ + _find_reasonable_pivot_naive(column, simpfunc=_simplify) + assert pivot_val == 1 + +def test_find_reasonable_pivot_naive_simplifies(): + # Test if matrices._find_reasonable_pivot_naive() + # simplifies candidate pivots, and reports + # their offsets correctly. + x = Symbol('x') + column = Matrix(3, 1, + [x, + cos(x)**2+sin(x)**2+x, + cos(x)**2+sin(x)**2]) + pivot_offset, pivot_val, pivot_assumed_nonzero, simplified =\ + _find_reasonable_pivot_naive(column, simpfunc=_simplify) + + assert len(simplified) == 2 + assert simplified[0][0] == 1 + assert simplified[0][1] == 1+x + assert simplified[1][0] == 2 + assert simplified[1][1] == 1 + +def test_errors(): + raises(ValueError, lambda: Matrix([[1, 2], [1]])) + raises(IndexError, lambda: Matrix([[1, 2]])[1.2, 5]) + raises(IndexError, lambda: Matrix([[1, 2]])[1, 5.2]) + raises(ValueError, lambda: randMatrix(3, c=4, symmetric=True)) + raises(ValueError, lambda: Matrix([1, 2]).reshape(4, 6)) + raises(ShapeError, + lambda: Matrix([[1, 2], [3, 4]]).copyin_matrix([1, 0], Matrix([1, 2]))) + raises(TypeError, lambda: Matrix([[1, 2], [3, 4]]).copyin_list([0, + 1], set())) + raises(NonSquareMatrixError, lambda: Matrix([[1, 2, 3], [2, 3, 0]]).inv()) + raises(ShapeError, + lambda: Matrix(1, 2, [1, 2]).row_join(Matrix([[1, 2], [3, 4]]))) + raises( + ShapeError, lambda: Matrix([1, 2]).col_join(Matrix([[1, 2], [3, 4]]))) + raises(ShapeError, lambda: Matrix([1]).row_insert(1, Matrix([[1, + 2], [3, 4]]))) + raises(ShapeError, lambda: Matrix([1]).col_insert(1, Matrix([[1, + 2], [3, 4]]))) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).trace()) + raises(TypeError, lambda: Matrix([1]).applyfunc(1)) + raises(ValueError, lambda: Matrix([[1, 2], [3, 4]]).minor(4, 5)) + raises(ValueError, lambda: Matrix([[1, 2], [3, 4]]).minor_submatrix(4, 5)) + raises(TypeError, lambda: Matrix([1, 2, 3]).cross(1)) + raises(TypeError, lambda: Matrix([1, 2, 3]).dot(1)) + raises(ShapeError, lambda: Matrix([1, 2, 3]).dot(Matrix([1, 2]))) + raises(ShapeError, lambda: Matrix([1, 2]).dot([])) + raises(TypeError, lambda: Matrix([1, 2]).dot('a')) + raises(ShapeError, lambda: Matrix([1, 2]).dot([1, 2, 3])) + raises(NonSquareMatrixError, lambda: Matrix([1, 2, 3]).exp()) + raises(ShapeError, lambda: Matrix([[1, 2], [3, 4]]).normalized()) + raises(ValueError, lambda: Matrix([1, 2]).inv(method='not a method')) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).inverse_GE()) + raises(ValueError, lambda: Matrix([[1, 2], [1, 2]]).inverse_GE()) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).inverse_ADJ()) + raises(ValueError, lambda: Matrix([[1, 2], [1, 2]]).inverse_ADJ()) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).inverse_LU()) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).is_nilpotent()) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).det()) + raises(ValueError, + lambda: Matrix([[1, 2], [3, 4]]).det(method='Not a real method')) + raises(ValueError, + lambda: Matrix([[1, 2, 3, 4], [5, 6, 7, 8], + [9, 10, 11, 12], [13, 14, 15, 16]]).det(iszerofunc="Not function")) + raises(ValueError, + lambda: Matrix([[1, 2, 3, 4], [5, 6, 7, 8], + [9, 10, 11, 12], [13, 14, 15, 16]]).det(iszerofunc=False)) + raises(ValueError, + lambda: hessian(Matrix([[1, 2], [3, 4]]), Matrix([[1, 2], [2, 1]]))) + raises(ValueError, lambda: hessian(Matrix([[1, 2], [3, 4]]), [])) + raises(ValueError, lambda: hessian(Symbol('x')**2, 'a')) + raises(IndexError, lambda: eye(3)[5, 2]) + raises(IndexError, lambda: eye(3)[2, 5]) + M = Matrix(((1, 2, 3, 4), (5, 6, 7, 8), (9, 10, 11, 12), (13, 14, 15, 16))) + raises(ValueError, lambda: M.det('method=LU_decomposition()')) + V = Matrix([[10, 10, 10]]) + M = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(ValueError, lambda: M.row_insert(4.7, V)) + M = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(ValueError, lambda: M.col_insert(-4.2, V)) + +def test_len(): + assert len(Matrix()) == 0 + assert len(Matrix([[1, 2]])) == len(Matrix([[1], [2]])) == 2 + assert len(Matrix(0, 2, lambda i, j: 0)) == \ + len(Matrix(2, 0, lambda i, j: 0)) == 0 + assert len(Matrix([[0, 1, 2], [3, 4, 5]])) == 6 + assert Matrix([1]) == Matrix([[1]]) + assert not Matrix() + assert Matrix() == Matrix([]) + + +def test_integrate(): + A = Matrix(((1, 4, x), (y, 2, 4), (10, 5, x**2))) + assert A.integrate(x) == \ + Matrix(((x, 4*x, x**2/2), (x*y, 2*x, 4*x), (10*x, 5*x, x**3/3))) + assert A.integrate(y) == \ + Matrix(((y, 4*y, x*y), (y**2/2, 2*y, 4*y), (10*y, 5*y, y*x**2))) + + +def test_limit(): + A = Matrix(((1, 4, sin(x)/x), (y, 2, 4), (10, 5, x**2 + 1))) + assert A.limit(x, 0) == Matrix(((1, 4, 1), (y, 2, 4), (10, 5, 1))) + + +def test_diff(): + A = MutableDenseMatrix(((1, 4, x), (y, 2, 4), (10, 5, x**2 + 1))) + assert isinstance(A.diff(x), type(A)) + assert A.diff(x) == MutableDenseMatrix(((0, 0, 1), (0, 0, 0), (0, 0, 2*x))) + assert A.diff(y) == MutableDenseMatrix(((0, 0, 0), (1, 0, 0), (0, 0, 0))) + + assert diff(A, x) == MutableDenseMatrix(((0, 0, 1), (0, 0, 0), (0, 0, 2*x))) + assert diff(A, y) == MutableDenseMatrix(((0, 0, 0), (1, 0, 0), (0, 0, 0))) + + A_imm = A.as_immutable() + assert isinstance(A_imm.diff(x), type(A_imm)) + assert A_imm.diff(x) == ImmutableDenseMatrix(((0, 0, 1), (0, 0, 0), (0, 0, 2*x))) + assert A_imm.diff(y) == ImmutableDenseMatrix(((0, 0, 0), (1, 0, 0), (0, 0, 0))) + + assert diff(A_imm, x) == ImmutableDenseMatrix(((0, 0, 1), (0, 0, 0), (0, 0, 2*x))) + assert diff(A_imm, y) == ImmutableDenseMatrix(((0, 0, 0), (1, 0, 0), (0, 0, 0))) + + assert A.diff(x, evaluate=False) == ArrayDerivative(A, x, evaluate=False) + assert diff(A, x, evaluate=False) == ArrayDerivative(A, x, evaluate=False) + + +def test_diff_by_matrix(): + + # Derive matrix by matrix: + + A = MutableDenseMatrix([[x, y], [z, t]]) + assert A.diff(A) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + assert diff(A, A) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + + A_imm = A.as_immutable() + assert A_imm.diff(A_imm) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + assert diff(A_imm, A_imm) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + + # Derive a constant matrix: + assert A.diff(a) == MutableDenseMatrix([[0, 0], [0, 0]]) + + B = ImmutableDenseMatrix([a, b]) + assert A.diff(B) == Array.zeros(2, 1, 2, 2) + assert A.diff(A) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + + # Test diff with tuples: + + dB = B.diff([[a, b]]) + assert dB.shape == (2, 2, 1) + assert dB == Array([[[1], [0]], [[0], [1]]]) + + f = Function("f") + fxyz = f(x, y, z) + assert fxyz.diff([[x, y, z]]) == Array([fxyz.diff(x), fxyz.diff(y), fxyz.diff(z)]) + assert fxyz.diff(([x, y, z], 2)) == Array([ + [fxyz.diff(x, 2), fxyz.diff(x, y), fxyz.diff(x, z)], + [fxyz.diff(x, y), fxyz.diff(y, 2), fxyz.diff(y, z)], + [fxyz.diff(x, z), fxyz.diff(z, y), fxyz.diff(z, 2)], + ]) + + expr = sin(x)*exp(y) + assert expr.diff([[x, y]]) == Array([cos(x)*exp(y), sin(x)*exp(y)]) + assert expr.diff(y, ((x, y),)) == Array([cos(x)*exp(y), sin(x)*exp(y)]) + assert expr.diff(x, ((x, y),)) == Array([-sin(x)*exp(y), cos(x)*exp(y)]) + assert expr.diff(((y, x),), [[x, y]]) == Array([[cos(x)*exp(y), -sin(x)*exp(y)], [sin(x)*exp(y), cos(x)*exp(y)]]) + + # Test different notations: + + assert fxyz.diff(x).diff(y).diff(x) == fxyz.diff(((x, y, z),), 3)[0, 1, 0] + assert fxyz.diff(z).diff(y).diff(x) == fxyz.diff(((x, y, z),), 3)[2, 1, 0] + assert fxyz.diff([[x, y, z]], ((z, y, x),)) == Array([[fxyz.diff(i).diff(j) for i in (x, y, z)] for j in (z, y, x)]) + + # Test scalar derived by matrix remains matrix: + res = x.diff(Matrix([[x, y]])) + assert isinstance(res, ImmutableDenseMatrix) + assert res == Matrix([[1, 0]]) + res = (x**3).diff(Matrix([[x, y]])) + assert isinstance(res, ImmutableDenseMatrix) + assert res == Matrix([[3*x**2, 0]]) + + +def test_getattr(): + A = Matrix(((1, 4, x), (y, 2, 4), (10, 5, x**2 + 1))) + raises(AttributeError, lambda: A.nonexistantattribute) + assert getattr(A, 'diff')(x) == Matrix(((0, 0, 1), (0, 0, 0), (0, 0, 2*x))) + + +def test_hessenberg(): + A = Matrix([[3, 4, 1], [2, 4, 5], [0, 1, 2]]) + assert A.is_upper_hessenberg + A = A.T + assert A.is_lower_hessenberg + A[0, -1] = 1 + assert A.is_lower_hessenberg is False + + A = Matrix([[3, 4, 1], [2, 4, 5], [3, 1, 2]]) + assert not A.is_upper_hessenberg + + A = zeros(5, 2) + assert A.is_upper_hessenberg + + +def test_cholesky(): + raises(NonSquareMatrixError, lambda: Matrix((1, 2)).cholesky()) + raises(ValueError, lambda: Matrix(((1, 2), (3, 4))).cholesky()) + raises(ValueError, lambda: Matrix(((5 + I, 0), (0, 1))).cholesky()) + raises(ValueError, lambda: Matrix(((1, 5), (5, 1))).cholesky()) + raises(ValueError, lambda: Matrix(((1, 2), (3, 4))).cholesky(hermitian=False)) + assert Matrix(((5 + I, 0), (0, 1))).cholesky(hermitian=False) == Matrix([ + [sqrt(5 + I), 0], [0, 1]]) + A = Matrix(((1, 5), (5, 1))) + L = A.cholesky(hermitian=False) + assert L == Matrix([[1, 0], [5, 2*sqrt(6)*I]]) + assert L*L.T == A + A = Matrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11))) + L = A.cholesky() + assert L * L.T == A + assert L.is_lower + assert L == Matrix([[5, 0, 0], [3, 3, 0], [-1, 1, 3]]) + A = Matrix(((4, -2*I, 2 + 2*I), (2*I, 2, -1 + I), (2 - 2*I, -1 - I, 11))) + assert A.cholesky().expand() == Matrix(((2, 0, 0), (I, 1, 0), (1 - I, 0, 3))) + + raises(NonSquareMatrixError, lambda: SparseMatrix((1, 2)).cholesky()) + raises(ValueError, lambda: SparseMatrix(((1, 2), (3, 4))).cholesky()) + raises(ValueError, lambda: SparseMatrix(((5 + I, 0), (0, 1))).cholesky()) + raises(ValueError, lambda: SparseMatrix(((1, 5), (5, 1))).cholesky()) + raises(ValueError, lambda: SparseMatrix(((1, 2), (3, 4))).cholesky(hermitian=False)) + assert SparseMatrix(((5 + I, 0), (0, 1))).cholesky(hermitian=False) == Matrix([ + [sqrt(5 + I), 0], [0, 1]]) + A = SparseMatrix(((1, 5), (5, 1))) + L = A.cholesky(hermitian=False) + assert L == Matrix([[1, 0], [5, 2*sqrt(6)*I]]) + assert L*L.T == A + A = SparseMatrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11))) + L = A.cholesky() + assert L * L.T == A + assert L.is_lower + assert L == Matrix([[5, 0, 0], [3, 3, 0], [-1, 1, 3]]) + A = SparseMatrix(((4, -2*I, 2 + 2*I), (2*I, 2, -1 + I), (2 - 2*I, -1 - I, 11))) + assert A.cholesky() == Matrix(((2, 0, 0), (I, 1, 0), (1 - I, 0, 3))) + + +def test_matrix_norm(): + # Vector Tests + # Test columns and symbols + x = Symbol('x', real=True) + v = Matrix([cos(x), sin(x)]) + assert trigsimp(v.norm(2)) == 1 + assert v.norm(10) == Pow(cos(x)**10 + sin(x)**10, Rational(1, 10)) + + # Test Rows + A = Matrix([[5, Rational(3, 2)]]) + assert A.norm() == Pow(25 + Rational(9, 4), S.Half) + assert A.norm(oo) == max(A) + assert A.norm(-oo) == min(A) + + # Matrix Tests + # Intuitive test + A = Matrix([[1, 1], [1, 1]]) + assert A.norm(2) == 2 + assert A.norm(-2) == 0 + assert A.norm('frobenius') == 2 + assert eye(10).norm(2) == eye(10).norm(-2) == 1 + assert A.norm(oo) == 2 + + # Test with Symbols and more complex entries + A = Matrix([[3, y, y], [x, S.Half, -pi]]) + assert (A.norm('fro') + == sqrt(Rational(37, 4) + 2*abs(y)**2 + pi**2 + x**2)) + + # Check non-square + A = Matrix([[1, 2, -3], [4, 5, Rational(13, 2)]]) + assert A.norm(2) == sqrt(Rational(389, 8) + sqrt(78665)/8) + assert A.norm(-2) is S.Zero + assert A.norm('frobenius') == sqrt(389)/2 + + # Test properties of matrix norms + # https://en.wikipedia.org/wiki/Matrix_norm#Definition + # Two matrices + A = Matrix([[1, 2], [3, 4]]) + B = Matrix([[5, 5], [-2, 2]]) + C = Matrix([[0, -I], [I, 0]]) + D = Matrix([[1, 0], [0, -1]]) + L = [A, B, C, D] + alpha = Symbol('alpha', real=True) + + for order in ['fro', 2, -2]: + # Zero Check + assert zeros(3).norm(order) is S.Zero + # Check Triangle Inequality for all Pairs of Matrices + for X in L: + for Y in L: + dif = (X.norm(order) + Y.norm(order) - + (X + Y).norm(order)) + assert (dif >= 0) + # Scalar multiplication linearity + for M in [A, B, C, D]: + dif = simplify((alpha*M).norm(order) - + abs(alpha) * M.norm(order)) + assert dif == 0 + + # Test Properties of Vector Norms + # https://en.wikipedia.org/wiki/Vector_norm + # Two column vectors + a = Matrix([1, 1 - 1*I, -3]) + b = Matrix([S.Half, 1*I, 1]) + c = Matrix([-1, -1, -1]) + d = Matrix([3, 2, I]) + e = Matrix([Integer(1e2), Rational(1, 1e2), 1]) + L = [a, b, c, d, e] + alpha = Symbol('alpha', real=True) + + for order in [1, 2, -1, -2, S.Infinity, S.NegativeInfinity, pi]: + # Zero Check + if order > 0: + assert Matrix([0, 0, 0]).norm(order) is S.Zero + # Triangle inequality on all pairs + if order >= 1: # Triangle InEq holds only for these norms + for X in L: + for Y in L: + dif = (X.norm(order) + Y.norm(order) - + (X + Y).norm(order)) + assert simplify(dif >= 0) is S.true + # Linear to scalar multiplication + if order in [1, 2, -1, -2, S.Infinity, S.NegativeInfinity]: + for X in L: + dif = simplify((alpha*X).norm(order) - + (abs(alpha) * X.norm(order))) + assert dif == 0 + + # ord=1 + M = Matrix(3, 3, [1, 3, 0, -2, -1, 0, 3, 9, 6]) + assert M.norm(1) == 13 + + +def test_condition_number(): + x = Symbol('x', real=True) + A = eye(3) + A[0, 0] = 10 + A[2, 2] = Rational(1, 10) + assert A.condition_number() == 100 + + A[1, 1] = x + assert A.condition_number() == Max(10, Abs(x)) / Min(Rational(1, 10), Abs(x)) + + M = Matrix([[cos(x), sin(x)], [-sin(x), cos(x)]]) + Mc = M.condition_number() + assert all(Float(1.).epsilon_eq(Mc.subs(x, val).evalf()) for val in + [Rational(1, 5), S.Half, Rational(1, 10), pi/2, pi, pi*Rational(7, 4) ]) + + #issue 10782 + assert Matrix([]).condition_number() == 0 + + +def test_equality(): + A = Matrix(((1, 2, 3), (4, 5, 6), (7, 8, 9))) + B = Matrix(((9, 8, 7), (6, 5, 4), (3, 2, 1))) + assert A == A[:, :] + assert not A != A[:, :] + assert not A == B + assert A != B + assert A != 10 + assert not A == 10 + + # A SparseMatrix can be equal to a Matrix + C = SparseMatrix(((1, 0, 0), (0, 1, 0), (0, 0, 1))) + D = Matrix(((1, 0, 0), (0, 1, 0), (0, 0, 1))) + assert C == D + assert not C != D + + +def test_col_join(): + assert eye(3).col_join(Matrix([[7, 7, 7]])) == \ + Matrix([[1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [7, 7, 7]]) + + +def test_row_insert(): + r4 = Matrix([[4, 4, 4]]) + for i in range(-4, 5): + l = [1, 0, 0] + l.insert(i, 4) + assert flatten(eye(3).row_insert(i, r4).col(0).tolist()) == l + + +def test_col_insert(): + c4 = Matrix([4, 4, 4]) + for i in range(-4, 5): + l = [0, 0, 0] + l.insert(i, 4) + assert flatten(zeros(3).col_insert(i, c4).row(0).tolist()) == l + + +def test_normalized(): + assert Matrix([3, 4]).normalized() == \ + Matrix([Rational(3, 5), Rational(4, 5)]) + + # Zero vector trivial cases + assert Matrix([0, 0, 0]).normalized() == Matrix([0, 0, 0]) + + # Machine precision error truncation trivial cases + m = Matrix([0,0,1.e-100]) + assert m.normalized( + iszerofunc=lambda x: x.evalf(n=10, chop=True).is_zero + ) == Matrix([0, 0, 0]) + + +def test_print_nonzero(): + assert capture(lambda: eye(3).print_nonzero()) == \ + '[X ]\n[ X ]\n[ X]\n' + assert capture(lambda: eye(3).print_nonzero('.')) == \ + '[. ]\n[ . ]\n[ .]\n' + + +def test_zeros_eye(): + assert Matrix.eye(3) == eye(3) + assert Matrix.zeros(3) == zeros(3) + assert ones(3, 4) == Matrix(3, 4, [1]*12) + + i = Matrix([[1, 0], [0, 1]]) + z = Matrix([[0, 0], [0, 0]]) + for cls in classes: + m = cls.eye(2) + assert i == m # but m == i will fail if m is immutable + assert i == eye(2, cls=cls) + assert type(m) == cls + m = cls.zeros(2) + assert z == m + assert z == zeros(2, cls=cls) + assert type(m) == cls + + +def test_is_zero(): + assert Matrix().is_zero_matrix + assert Matrix([[0, 0], [0, 0]]).is_zero_matrix + assert zeros(3, 4).is_zero_matrix + assert not eye(3).is_zero_matrix + assert Matrix([[x, 0], [0, 0]]).is_zero_matrix == None + assert SparseMatrix([[x, 0], [0, 0]]).is_zero_matrix == None + assert ImmutableMatrix([[x, 0], [0, 0]]).is_zero_matrix == None + assert ImmutableSparseMatrix([[x, 0], [0, 0]]).is_zero_matrix == None + assert Matrix([[x, 1], [0, 0]]).is_zero_matrix == False + a = Symbol('a', nonzero=True) + assert Matrix([[a, 0], [0, 0]]).is_zero_matrix == False + + +def test_rotation_matrices(): + # This tests the rotation matrices by rotating about an axis and back. + theta = pi/3 + r3_plus = rot_axis3(theta) + r3_minus = rot_axis3(-theta) + r2_plus = rot_axis2(theta) + r2_minus = rot_axis2(-theta) + r1_plus = rot_axis1(theta) + r1_minus = rot_axis1(-theta) + assert r3_minus*r3_plus*eye(3) == eye(3) + assert r2_minus*r2_plus*eye(3) == eye(3) + assert r1_minus*r1_plus*eye(3) == eye(3) + + # Check the correctness of the trace of the rotation matrix + assert r1_plus.trace() == 1 + 2*cos(theta) + assert r2_plus.trace() == 1 + 2*cos(theta) + assert r3_plus.trace() == 1 + 2*cos(theta) + + # Check that a rotation with zero angle doesn't change anything. + assert rot_axis1(0) == eye(3) + assert rot_axis2(0) == eye(3) + assert rot_axis3(0) == eye(3) + + # Check left-hand convention + # see Issue #24529 + q1 = Quaternion.from_axis_angle([1, 0, 0], pi / 2) + q2 = Quaternion.from_axis_angle([0, 1, 0], pi / 2) + q3 = Quaternion.from_axis_angle([0, 0, 1], pi / 2) + assert rot_axis1(- pi / 2) == q1.to_rotation_matrix() + assert rot_axis2(- pi / 2) == q2.to_rotation_matrix() + assert rot_axis3(- pi / 2) == q3.to_rotation_matrix() + # Check right-hand convention + assert rot_ccw_axis1(+ pi / 2) == q1.to_rotation_matrix() + assert rot_ccw_axis2(+ pi / 2) == q2.to_rotation_matrix() + assert rot_ccw_axis3(+ pi / 2) == q3.to_rotation_matrix() + + +def test_DeferredVector(): + assert str(DeferredVector("vector")[4]) == "vector[4]" + assert sympify(DeferredVector("d")) == DeferredVector("d") + raises(IndexError, lambda: DeferredVector("d")[-1]) + assert str(DeferredVector("d")) == "d" + assert repr(DeferredVector("test")) == "DeferredVector('test')" + +def test_DeferredVector_not_iterable(): + assert not iterable(DeferredVector('X')) + +def test_DeferredVector_Matrix(): + raises(TypeError, lambda: Matrix(DeferredVector("V"))) + +def test_GramSchmidt(): + R = Rational + m1 = Matrix(1, 2, [1, 2]) + m2 = Matrix(1, 2, [2, 3]) + assert GramSchmidt([m1, m2]) == \ + [Matrix(1, 2, [1, 2]), Matrix(1, 2, [R(2)/5, R(-1)/5])] + assert GramSchmidt([m1.T, m2.T]) == \ + [Matrix(2, 1, [1, 2]), Matrix(2, 1, [R(2)/5, R(-1)/5])] + # from wikipedia + assert GramSchmidt([Matrix([3, 1]), Matrix([2, 2])], True) == [ + Matrix([3*sqrt(10)/10, sqrt(10)/10]), + Matrix([-sqrt(10)/10, 3*sqrt(10)/10])] + # https://github.com/sympy/sympy/issues/9488 + L = FiniteSet(Matrix([1])) + assert GramSchmidt(L) == [Matrix([[1]])] + + +def test_casoratian(): + assert casoratian([1, 2, 3, 4], 1) == 0 + assert casoratian([1, 2, 3, 4], 1, zero=False) == 0 + + +def test_zero_dimension_multiply(): + assert (Matrix()*zeros(0, 3)).shape == (0, 3) + assert zeros(3, 0)*zeros(0, 3) == zeros(3, 3) + assert zeros(0, 3)*zeros(3, 0) == Matrix() + + +def test_slice_issue_2884(): + m = Matrix(2, 2, range(4)) + assert m[1, :] == Matrix([[2, 3]]) + assert m[-1, :] == Matrix([[2, 3]]) + assert m[:, 1] == Matrix([[1, 3]]).T + assert m[:, -1] == Matrix([[1, 3]]).T + raises(IndexError, lambda: m[2, :]) + raises(IndexError, lambda: m[2, 2]) + + +def test_slice_issue_3401(): + assert zeros(0, 3)[:, -1].shape == (0, 1) + assert zeros(3, 0)[0, :] == Matrix(1, 0, []) + + +def test_copyin(): + s = zeros(3, 3) + s[3] = 1 + assert s[:, 0] == Matrix([0, 1, 0]) + assert s[3] == 1 + assert s[3: 4] == [1] + s[1, 1] = 42 + assert s[1, 1] == 42 + assert s[1, 1:] == Matrix([[42, 0]]) + s[1, 1:] = Matrix([[5, 6]]) + assert s[1, :] == Matrix([[1, 5, 6]]) + s[1, 1:] = [[42, 43]] + assert s[1, :] == Matrix([[1, 42, 43]]) + s[0, 0] = 17 + assert s[:, :1] == Matrix([17, 1, 0]) + s[0, 0] = [1, 1, 1] + assert s[:, 0] == Matrix([1, 1, 1]) + s[0, 0] = Matrix([1, 1, 1]) + assert s[:, 0] == Matrix([1, 1, 1]) + s[0, 0] = SparseMatrix([1, 1, 1]) + assert s[:, 0] == Matrix([1, 1, 1]) + + +def test_invertible_check(): + # sometimes a singular matrix will have a pivot vector shorter than + # the number of rows in a matrix... + assert Matrix([[1, 2], [1, 2]]).rref() == (Matrix([[1, 2], [0, 0]]), (0,)) + raises(ValueError, lambda: Matrix([[1, 2], [1, 2]]).inv()) + m = Matrix([ + [-1, -1, 0], + [ x, 1, 1], + [ 1, x, -1], + ]) + assert len(m.rref()[1]) != m.rows + # in addition, unless simplify=True in the call to rref, the identity + # matrix will be returned even though m is not invertible + assert m.rref()[0] != eye(3) + assert m.rref(simplify=signsimp)[0] != eye(3) + raises(ValueError, lambda: m.inv(method="ADJ")) + raises(ValueError, lambda: m.inv(method="GE")) + raises(ValueError, lambda: m.inv(method="LU")) + + +def test_issue_3959(): + x, y = symbols('x, y') + e = x*y + assert e.subs(x, Matrix([3, 5, 3])) == Matrix([3, 5, 3])*y + + +def test_issue_5964(): + assert str(Matrix([[1, 2], [3, 4]])) == 'Matrix([[1, 2], [3, 4]])' + + +def test_issue_7604(): + x, y = symbols("x y") + assert sstr(Matrix([[x, 2*y], [y**2, x + 3]])) == \ + 'Matrix([\n[ x, 2*y],\n[y**2, x + 3]])' + + +def test_is_Identity(): + assert eye(3).is_Identity + assert eye(3).as_immutable().is_Identity + assert not zeros(3).is_Identity + assert not ones(3).is_Identity + # issue 6242 + assert not Matrix([[1, 0, 0]]).is_Identity + # issue 8854 + assert SparseMatrix(3,3, {(0,0):1, (1,1):1, (2,2):1}).is_Identity + assert not SparseMatrix(2,3, range(6)).is_Identity + assert not SparseMatrix(3,3, {(0,0):1, (1,1):1}).is_Identity + assert not SparseMatrix(3,3, {(0,0):1, (1,1):1, (2,2):1, (0,1):2, (0,2):3}).is_Identity + + +def test_dot(): + assert ones(1, 3).dot(ones(3, 1)) == 3 + assert ones(1, 3).dot([1, 1, 1]) == 3 + assert Matrix([1, 2, 3]).dot(Matrix([1, 2, 3])) == 14 + assert Matrix([1, 2, 3*I]).dot(Matrix([I, 2, 3*I])) == -5 + I + assert Matrix([1, 2, 3*I]).dot(Matrix([I, 2, 3*I]), hermitian=False) == -5 + I + assert Matrix([1, 2, 3*I]).dot(Matrix([I, 2, 3*I]), hermitian=True) == 13 + I + assert Matrix([1, 2, 3*I]).dot(Matrix([I, 2, 3*I]), hermitian=True, conjugate_convention="physics") == 13 - I + assert Matrix([1, 2, 3*I]).dot(Matrix([4, 5*I, 6]), hermitian=True, conjugate_convention="right") == 4 + 8*I + assert Matrix([1, 2, 3*I]).dot(Matrix([4, 5*I, 6]), hermitian=True, conjugate_convention="left") == 4 - 8*I + assert Matrix([I, 2*I]).dot(Matrix([I, 2*I]), hermitian=False, conjugate_convention="left") == -5 + assert Matrix([I, 2*I]).dot(Matrix([I, 2*I]), conjugate_convention="left") == 5 + raises(ValueError, lambda: Matrix([1, 2]).dot(Matrix([3, 4]), hermitian=True, conjugate_convention="test")) + + +def test_dual(): + B_x, B_y, B_z, E_x, E_y, E_z = symbols( + 'B_x B_y B_z E_x E_y E_z', real=True) + F = Matrix(( + ( 0, E_x, E_y, E_z), + (-E_x, 0, B_z, -B_y), + (-E_y, -B_z, 0, B_x), + (-E_z, B_y, -B_x, 0) + )) + Fd = Matrix(( + ( 0, -B_x, -B_y, -B_z), + (B_x, 0, E_z, -E_y), + (B_y, -E_z, 0, E_x), + (B_z, E_y, -E_x, 0) + )) + assert F.dual().equals(Fd) + assert eye(3).dual().equals(zeros(3)) + assert F.dual().dual().equals(-F) + + +def test_anti_symmetric(): + assert Matrix([1, 2]).is_anti_symmetric() is False + m = Matrix(3, 3, [0, x**2 + 2*x + 1, y, -(x + 1)**2, 0, x*y, -y, -x*y, 0]) + assert m.is_anti_symmetric() is True + assert m.is_anti_symmetric(simplify=False) is None + assert m.is_anti_symmetric(simplify=lambda x: x) is None + + # tweak to fail + m[2, 1] = -m[2, 1] + assert m.is_anti_symmetric() is None + # untweak + m[2, 1] = -m[2, 1] + + m = m.expand() + assert m.is_anti_symmetric(simplify=False) is True + m[0, 0] = 1 + assert m.is_anti_symmetric() is False + + +def test_normalize_sort_diogonalization(): + A = Matrix(((1, 2), (2, 1))) + P, Q = A.diagonalize(normalize=True) + assert P*P.T == P.T*P == eye(P.cols) + P, Q = A.diagonalize(normalize=True, sort=True) + assert P*P.T == P.T*P == eye(P.cols) + assert P*Q*P.inv() == A + + +def test_issue_5321(): + raises(ValueError, lambda: Matrix([[1, 2, 3], Matrix(0, 1, [])])) + + +def test_issue_5320(): + assert Matrix.hstack(eye(2), 2*eye(2)) == Matrix([ + [1, 0, 2, 0], + [0, 1, 0, 2] + ]) + assert Matrix.vstack(eye(2), 2*eye(2)) == Matrix([ + [1, 0], + [0, 1], + [2, 0], + [0, 2] + ]) + cls = SparseMatrix + assert cls.hstack(cls(eye(2)), cls(2*eye(2))) == Matrix([ + [1, 0, 2, 0], + [0, 1, 0, 2] + ]) + +def test_issue_11944(): + A = Matrix([[1]]) + AIm = sympify(A) + assert Matrix.hstack(AIm, A) == Matrix([[1, 1]]) + assert Matrix.vstack(AIm, A) == Matrix([[1], [1]]) + +def test_cross(): + a = [1, 2, 3] + b = [3, 4, 5] + col = Matrix([-2, 4, -2]) + row = col.T + + def test(M, ans): + assert ans == M + assert type(M) == cls + for cls in classes: + A = cls(a) + B = cls(b) + test(A.cross(B), col) + test(A.cross(B.T), col) + test(A.T.cross(B.T), row) + test(A.T.cross(B), row) + raises(ShapeError, lambda: + Matrix(1, 2, [1, 1]).cross(Matrix(1, 2, [1, 1]))) + +def test_hat_vee(): + v1 = Matrix([x, y, z]) + v2 = Matrix([a, b, c]) + assert v1.hat() * v2 == v1.cross(v2) + assert v1.hat().is_anti_symmetric() + assert v1.hat().vee() == v1 + +def test_hash(): + for cls in classes[-2:]: + s = {cls.eye(1), cls.eye(1)} + assert len(s) == 1 and s.pop() == cls.eye(1) + # issue 3979 + for cls in classes[:2]: + assert not isinstance(cls.eye(1), Hashable) + + +@XFAIL +def test_issue_3979(): + # when this passes, delete this and change the [1:2] + # to [:2] in the test_hash above for issue 3979 + cls = classes[0] + raises(AttributeError, lambda: hash(cls.eye(1))) + + +def test_adjoint(): + dat = [[0, I], [1, 0]] + ans = Matrix([[0, 1], [-I, 0]]) + for cls in classes: + assert ans == cls(dat).adjoint() + + +def test_adjoint_with_operator(): + # Regression test for issue 25130: adjoint() should propagate to operators + import sympy.physics.quantum + a = sympy.physics.quantum.operator.Operator('a') + a_dag = sympy.physics.quantum.Dagger(a) + dat = [[0, I * a], [0, a_dag]] + ans = Matrix([[0, 0], [-I * a_dag, a]]) + for cls in classes: + assert ans == cls(dat).adjoint() + + +def test_simplify_immutable(): + assert simplify(ImmutableMatrix([[sin(x)**2 + cos(x)**2]])) == \ + ImmutableMatrix([[1]]) + +def test_replace(): + F, G = symbols('F, G', cls=Function) + K = Matrix(2, 2, lambda i, j: G(i+j)) + M = Matrix(2, 2, lambda i, j: F(i+j)) + N = M.replace(F, G) + assert N == K + + +def test_atoms(): + m = Matrix([[1, 2], [x, 1 - 1/x]]) + assert m.atoms() == {S.One,S(2),S.NegativeOne, x} + assert m.atoms(Symbol) == {x} + + +def test_pinv(): + # Pseudoinverse of an invertible matrix is the inverse. + A1 = Matrix([[a, b], [c, d]]) + assert simplify(A1.pinv(method="RD")) == simplify(A1.inv()) + + # Test the four properties of the pseudoinverse for various matrices. + As = [Matrix([[13, 104], [2212, 3], [-3, 5]]), + Matrix([[1, 7, 9], [11, 17, 19]]), + Matrix([a, b])] + + for A in As: + A_pinv = A.pinv(method="RD") + AAp = A * A_pinv + ApA = A_pinv * A + assert simplify(AAp * A) == A + assert simplify(ApA * A_pinv) == A_pinv + assert AAp.H == AAp + assert ApA.H == ApA + + # XXX Pinv with diagonalization makes expression too complicated. + for A in As: + A_pinv = simplify(A.pinv(method="ED")) + AAp = A * A_pinv + ApA = A_pinv * A + assert simplify(AAp * A) == A + assert simplify(ApA * A_pinv) == A_pinv + assert AAp.H == AAp + assert ApA.H == ApA + + # XXX Computing pinv using diagonalization makes an expression that + # is too complicated to simplify. + # A1 = Matrix([[a, b], [c, d]]) + # assert simplify(A1.pinv(method="ED")) == simplify(A1.inv()) + # so this is tested numerically at a fixed random point + + from sympy.core.numbers import comp + q = A1.pinv(method="ED") + w = A1.inv() + reps = {a: -73633, b: 11362, c: 55486, d: 62570} + assert all( + comp(i.n(), j.n()) + for i, j in zip(q.subs(reps), w.subs(reps)) + ) + + +@slow +def test_pinv_rank_deficient_when_diagonalization_fails(): + # Test the four properties of the pseudoinverse for matrices when + # diagonalization of A.H*A fails. + As = [ + Matrix([ + [61, 89, 55, 20, 71, 0], + [62, 96, 85, 85, 16, 0], + [69, 56, 17, 4, 54, 0], + [10, 54, 91, 41, 71, 0], + [ 7, 30, 10, 48, 90, 0], + [0, 0, 0, 0, 0, 0]]) + ] + for A in As: + A_pinv = A.pinv(method="ED") + AAp = A * A_pinv + ApA = A_pinv * A + assert AAp.H == AAp + + # Here ApA.H and ApA are equivalent expressions but they are very + # complicated expressions involving RootOfs. Using simplify would be + # too slow and so would evalf so we substitute approximate values for + # the RootOfs and then evalf which is less accurate but good enough to + # confirm that these two matrices are equivalent. + # + # assert ApA.H == ApA # <--- would fail (structural equality) + # assert simplify(ApA.H - ApA).is_zero_matrix # <--- too slow + # (ApA.H - ApA).evalf() # <--- too slow + + def allclose(M1, M2): + rootofs = M1.atoms(RootOf) + rootofs_approx = {r: r.evalf() for r in rootofs} + diff_approx = (M1 - M2).xreplace(rootofs_approx).evalf() + return all(abs(e) < 1e-10 for e in diff_approx) + + assert allclose(ApA.H, ApA) + + +def test_issue_7201(): + assert ones(0, 1) + ones(0, 1) == Matrix(0, 1, []) + assert ones(1, 0) + ones(1, 0) == Matrix(1, 0, []) + +def test_free_symbols(): + for M in ImmutableMatrix, ImmutableSparseMatrix, Matrix, SparseMatrix: + assert M([[x], [0]]).free_symbols == {x} + +def test_from_ndarray(): + """See issue 7465.""" + try: + from numpy import array + except ImportError: + skip('NumPy must be available to test creating matrices from ndarrays') + + assert Matrix(array([1, 2, 3])) == Matrix([1, 2, 3]) + assert Matrix(array([[1, 2, 3]])) == Matrix([[1, 2, 3]]) + assert Matrix(array([[1, 2, 3], [4, 5, 6]])) == \ + Matrix([[1, 2, 3], [4, 5, 6]]) + assert Matrix(array([x, y, z])) == Matrix([x, y, z]) + raises(NotImplementedError, + lambda: Matrix(array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]))) + assert Matrix([array([1, 2]), array([3, 4])]) == Matrix([[1, 2], [3, 4]]) + assert Matrix([array([1, 2]), [3, 4]]) == Matrix([[1, 2], [3, 4]]) + assert Matrix([array([]), array([])]) == Matrix(2, 0, []) != Matrix(0, 0, []) + +def test_17522_numpy(): + from sympy.matrices.common import _matrixify + try: + from numpy import array, matrix + except ImportError: + skip('NumPy must be available to test indexing matrixified NumPy ndarrays and matrices') + + m = _matrixify(array([[1, 2], [3, 4]])) + assert m[3] == 4 + assert list(m) == [1, 2, 3, 4] + + with ignore_warnings(PendingDeprecationWarning): + m = _matrixify(matrix([[1, 2], [3, 4]])) + assert m[3] == 4 + assert list(m) == [1, 2, 3, 4] + +def test_17522_mpmath(): + from sympy.matrices.common import _matrixify + try: + from mpmath import matrix + except ImportError: + skip('mpmath must be available to test indexing matrixified mpmath matrices') + + m = _matrixify(matrix([[1, 2], [3, 4]])) + assert m[3] == 4.0 + assert list(m) == [1.0, 2.0, 3.0, 4.0] + +def test_17522_scipy(): + from sympy.matrices.common import _matrixify + try: + from scipy.sparse import csr_matrix + except ImportError: + skip('SciPy must be available to test indexing matrixified SciPy sparse matrices') + + m = _matrixify(csr_matrix([[1, 2], [3, 4]])) + assert m[3] == 4 + assert list(m) == [1, 2, 3, 4] + +def test_hermitian(): + a = Matrix([[1, I], [-I, 1]]) + assert a.is_hermitian + a[0, 0] = 2*I + assert a.is_hermitian is False + a[0, 0] = x + assert a.is_hermitian is None + a[0, 1] = a[1, 0]*I + assert a.is_hermitian is False + b = HermitianOperator("b") + c = Operator("c") + assert Matrix([[b]]).is_hermitian is True + assert Matrix([[b, c], [Dagger(c), b]]).is_hermitian is True + assert Matrix([[b, c], [c, b]]).is_hermitian is False + assert Matrix([[b, c], [transpose(c), b]]).is_hermitian is False + +def test_doit(): + a = Matrix([[Add(x,x, evaluate=False)]]) + assert a[0] != 2*x + assert a.doit() == Matrix([[2*x]]) + +def test_issue_9457_9467_9876(): + # for row_del(index) + M = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + M.row_del(1) + assert M == Matrix([[1, 2, 3], [3, 4, 5]]) + N = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + N.row_del(-2) + assert N == Matrix([[1, 2, 3], [3, 4, 5]]) + O = Matrix([[1, 2, 3], [5, 6, 7], [9, 10, 11]]) + O.row_del(-1) + assert O == Matrix([[1, 2, 3], [5, 6, 7]]) + P = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(IndexError, lambda: P.row_del(10)) + Q = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(IndexError, lambda: Q.row_del(-10)) + + # for col_del(index) + M = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + M.col_del(1) + assert M == Matrix([[1, 3], [2, 4], [3, 5]]) + N = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + N.col_del(-2) + assert N == Matrix([[1, 3], [2, 4], [3, 5]]) + P = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(IndexError, lambda: P.col_del(10)) + Q = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(IndexError, lambda: Q.col_del(-10)) + +def test_issue_9422(): + x, y = symbols('x y', commutative=False) + a, b = symbols('a b') + M = eye(2) + M1 = Matrix(2, 2, [x, y, y, z]) + assert y*x*M != x*y*M + assert b*a*M == a*b*M + assert x*M1 != M1*x + assert a*M1 == M1*a + assert y*x*M == Matrix([[y*x, 0], [0, y*x]]) + + +def test_issue_10770(): + M = Matrix([]) + a = ['col_insert', 'row_join'], Matrix([9, 6, 3]) + b = ['row_insert', 'col_join'], a[1].T + c = ['row_insert', 'col_insert'], Matrix([[1, 2], [3, 4]]) + for ops, m in (a, b, c): + for op in ops: + f = getattr(M, op) + new = f(m) if 'join' in op else f(42, m) + assert new == m and id(new) != id(m) + + +def test_issue_10658(): + A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + assert A.extract([0, 1, 2], [True, True, False]) == \ + Matrix([[1, 2], [4, 5], [7, 8]]) + assert A.extract([0, 1, 2], [True, False, False]) == Matrix([[1], [4], [7]]) + assert A.extract([True, False, False], [0, 1, 2]) == Matrix([[1, 2, 3]]) + assert A.extract([True, False, True], [0, 1, 2]) == \ + Matrix([[1, 2, 3], [7, 8, 9]]) + assert A.extract([0, 1, 2], [False, False, False]) == Matrix(3, 0, []) + assert A.extract([False, False, False], [0, 1, 2]) == Matrix(0, 3, []) + assert A.extract([True, False, True], [False, True, False]) == \ + Matrix([[2], [8]]) + +def test_opportunistic_simplification(): + # this test relates to issue #10718, #9480, #11434 + + # issue #9480 + m = Matrix([[-5 + 5*sqrt(2), -5], [-5*sqrt(2)/2 + 5, -5*sqrt(2)/2]]) + assert m.rank() == 1 + + # issue #10781 + m = Matrix([[3+3*sqrt(3)*I, -9],[4,-3+3*sqrt(3)*I]]) + assert simplify(m.rref()[0] - Matrix([[1, -9/(3 + 3*sqrt(3)*I)], [0, 0]])) == zeros(2, 2) + + # issue #11434 + ax,ay,bx,by,cx,cy,dx,dy,ex,ey,t0,t1 = symbols('a_x a_y b_x b_y c_x c_y d_x d_y e_x e_y t_0 t_1') + m = Matrix([[ax,ay,ax*t0,ay*t0,0],[bx,by,bx*t0,by*t0,0],[cx,cy,cx*t0,cy*t0,1],[dx,dy,dx*t0,dy*t0,1],[ex,ey,2*ex*t1-ex*t0,2*ey*t1-ey*t0,0]]) + assert m.rank() == 4 + +def test_partial_pivoting(): + # example from https://en.wikipedia.org/wiki/Pivot_element + # partial pivoting with back substitution gives a perfect result + # naive pivoting give an error ~1e-13, so anything better than + # 1e-15 is good + mm=Matrix([[0.003, 59.14, 59.17], [5.291, -6.13, 46.78]]) + assert (mm.rref()[0] - Matrix([[1.0, 0, 10.0], + [ 0, 1.0, 1.0]])).norm() < 1e-15 + + # issue #11549 + m_mixed = Matrix([[6e-17, 1.0, 4], + [ -1.0, 0, 8], + [ 0, 0, 1]]) + m_float = Matrix([[6e-17, 1.0, 4.], + [ -1.0, 0., 8.], + [ 0., 0., 1.]]) + m_inv = Matrix([[ 0, -1.0, 8.0], + [1.0, 6.0e-17, -4.0], + [ 0, 0, 1]]) + # this example is numerically unstable and involves a matrix with a norm >= 8, + # this comparing the difference of the results with 1e-15 is numerically sound. + assert (m_mixed.inv() - m_inv).norm() < 1e-15 + assert (m_float.inv() - m_inv).norm() < 1e-15 + +def test_iszero_substitution(): + """ When doing numerical computations, all elements that pass + the iszerofunc test should be set to numerically zero if they + aren't already. """ + + # Matrix from issue #9060 + m = Matrix([[0.9, -0.1, -0.2, 0],[-0.8, 0.9, -0.4, 0],[-0.1, -0.8, 0.6, 0]]) + m_rref = m.rref(iszerofunc=lambda x: abs(x)<6e-15)[0] + m_correct = Matrix([[1.0, 0, -0.301369863013699, 0],[ 0, 1.0, -0.712328767123288, 0],[ 0, 0, 0, 0]]) + m_diff = m_rref - m_correct + assert m_diff.norm() < 1e-15 + # if a zero-substitution wasn't made, this entry will be -1.11022302462516e-16 + assert m_rref[2,2] == 0 + +def test_issue_11238(): + from sympy.geometry.point import Point + xx = 8*tan(pi*Rational(13, 45))/(tan(pi*Rational(13, 45)) + sqrt(3)) + yy = (-8*sqrt(3)*tan(pi*Rational(13, 45))**2 + 24*tan(pi*Rational(13, 45)))/(-3 + tan(pi*Rational(13, 45))**2) + p1 = Point(0, 0) + p2 = Point(1, -sqrt(3)) + p0 = Point(xx,yy) + m1 = Matrix([p1 - simplify(p0), p2 - simplify(p0)]) + m2 = Matrix([p1 - p0, p2 - p0]) + m3 = Matrix([simplify(p1 - p0), simplify(p2 - p0)]) + + # This system has expressions which are zero and + # cannot be easily proved to be such, so without + # numerical testing, these assertions will fail. + Z = lambda x: abs(x.n()) < 1e-20 + assert m1.rank(simplify=True, iszerofunc=Z) == 1 + assert m2.rank(simplify=True, iszerofunc=Z) == 1 + assert m3.rank(simplify=True, iszerofunc=Z) == 1 + +def test_as_real_imag(): + m1 = Matrix(2,2,[1,2,3,4]) + m2 = m1*S.ImaginaryUnit + m3 = m1 + m2 + + for kls in classes: + a,b = kls(m3).as_real_imag() + assert list(a) == list(m1) + assert list(b) == list(m1) + +def test_deprecated(): + # Maintain tests for deprecated functions. We must capture + # the deprecation warnings. When the deprecated functionality is + # removed, the corresponding tests should be removed. + + m = Matrix(3, 3, [0, 1, 0, -4, 4, 0, -2, 1, 2]) + P, Jcells = m.jordan_cells() + assert Jcells[1] == Matrix(1, 1, [2]) + assert Jcells[0] == Matrix(2, 2, [2, 1, 0, 2]) + + +def test_issue_14489(): + from sympy.core.mod import Mod + A = Matrix([-1, 1, 2]) + B = Matrix([10, 20, -15]) + + assert Mod(A, 3) == Matrix([2, 1, 2]) + assert Mod(B, 4) == Matrix([2, 0, 1]) + +def test_issue_14943(): + # Test that __array__ accepts the optional dtype argument + try: + from numpy import array + except ImportError: + skip('NumPy must be available to test creating matrices from ndarrays') + + M = Matrix([[1,2], [3,4]]) + assert array(M, dtype=float).dtype.name == 'float64' + +def test_case_6913(): + m = MatrixSymbol('m', 1, 1) + a = Symbol("a") + a = m[0, 0]>0 + assert str(a) == 'm[0, 0] > 0' + +def test_issue_11948(): + A = MatrixSymbol('A', 3, 3) + a = Wild('a') + assert A.match(a) == {a: A} + +def test_gramschmidt_conjugate_dot(): + vecs = [Matrix([1, I]), Matrix([1, -I])] + assert Matrix.orthogonalize(*vecs) == \ + [Matrix([[1], [I]]), Matrix([[1], [-I]])] + + vecs = [Matrix([1, I, 0]), Matrix([I, 0, -I])] + assert Matrix.orthogonalize(*vecs) == \ + [Matrix([[1], [I], [0]]), Matrix([[I/2], [S(1)/2], [-I]])] + + mat = Matrix([[1, I], [1, -I]]) + Q, R = mat.QRdecomposition() + assert Q * Q.H == Matrix.eye(2) + +def test_issue_8207(): + a = Matrix(MatrixSymbol('a', 3, 1)) + b = Matrix(MatrixSymbol('b', 3, 1)) + c = a.dot(b) + d = diff(c, a[0, 0]) + e = diff(d, a[0, 0]) + assert d == b[0, 0] + assert e == 0 + +def test_func(): + from sympy.simplify.simplify import nthroot + + A = Matrix([[1, 2],[0, 3]]) + assert A.analytic_func(sin(x*t), x) == Matrix([[sin(t), sin(3*t) - sin(t)], [0, sin(3*t)]]) + + A = Matrix([[2, 1],[1, 2]]) + assert (pi * A / 6).analytic_func(cos(x), x) == Matrix([[sqrt(3)/4, -sqrt(3)/4], [-sqrt(3)/4, sqrt(3)/4]]) + + + raises(ValueError, lambda : zeros(5).analytic_func(log(x), x)) + raises(ValueError, lambda : (A*x).analytic_func(log(x), x)) + + A = Matrix([[0, -1, -2, 3], [0, -1, -2, 3], [0, 1, 0, -1], [0, 0, -1, 1]]) + assert A.analytic_func(exp(x), x) == A.exp() + raises(ValueError, lambda : A.analytic_func(sqrt(x), x)) + + A = Matrix([[41, 12],[12, 34]]) + assert simplify(A.analytic_func(sqrt(x), x)**2) == A + + A = Matrix([[3, -12, 4], [-1, 0, -2], [-1, 5, -1]]) + assert simplify(A.analytic_func(nthroot(x, 3), x)**3) == A + + A = Matrix([[2, 0, 0, 0], [1, 2, 0, 0], [0, 1, 3, 0], [0, 0, 1, 3]]) + assert A.analytic_func(exp(x), x) == A.exp() + + A = Matrix([[0, 2, 1, 6], [0, 0, 1, 2], [0, 0, 0, 3], [0, 0, 0, 0]]) + assert A.analytic_func(exp(x*t), x) == expand(simplify((A*t).exp())) + + +@skip_under_pyodide("Cannot create threads under pyodide.") +def test_issue_19809(): + + def f(): + assert _dotprodsimp_state.state == None + m = Matrix([[1]]) + m = m * m + return True + + with dotprodsimp(True): + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(f) + assert future.result() + + +def test_issue_23276(): + M = Matrix([x, y]) + assert integrate(M, (x, 0, 1), (y, 0, 1)) == Matrix([ + [S.Half], + [S.Half]]) + + +# SubspaceOnlyMatrix tests +def test_columnspace_one(): + m = SubspaceOnlyMatrix([[ 1, 2, 0, 2, 5], + [-2, -5, 1, -1, -8], + [ 0, -3, 3, 4, 1], + [ 3, 6, 0, -7, 2]]) + + basis = m.columnspace() + assert basis[0] == Matrix([1, -2, 0, 3]) + assert basis[1] == Matrix([2, -5, -3, 6]) + assert basis[2] == Matrix([2, -1, 4, -7]) + + assert len(basis) == 3 + assert Matrix.hstack(m, *basis).columnspace() == basis + + +def test_rowspace(): + m = SubspaceOnlyMatrix([[ 1, 2, 0, 2, 5], + [-2, -5, 1, -1, -8], + [ 0, -3, 3, 4, 1], + [ 3, 6, 0, -7, 2]]) + + basis = m.rowspace() + assert basis[0] == Matrix([[1, 2, 0, 2, 5]]) + assert basis[1] == Matrix([[0, -1, 1, 3, 2]]) + assert basis[2] == Matrix([[0, 0, 0, 5, 5]]) + + assert len(basis) == 3 + + +def test_nullspace_one(): + m = SubspaceOnlyMatrix([[ 1, 2, 0, 2, 5], + [-2, -5, 1, -1, -8], + [ 0, -3, 3, 4, 1], + [ 3, 6, 0, -7, 2]]) + + basis = m.nullspace() + assert basis[0] == Matrix([-2, 1, 1, 0, 0]) + assert basis[1] == Matrix([-1, -1, 0, -1, 1]) + # make sure the null space is really gets zeroed + assert all(e.is_zero for e in m*basis[0]) + assert all(e.is_zero for e in m*basis[1]) + + +# ReductionsOnlyMatrix tests +def test_row_op(): + e = eye_Reductions(3) + + raises(ValueError, lambda: e.elementary_row_op("abc")) + raises(ValueError, lambda: e.elementary_row_op()) + raises(ValueError, lambda: e.elementary_row_op('n->kn', row=5, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n->kn', row=-5, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n<->m', row1=1, row2=5)) + raises(ValueError, lambda: e.elementary_row_op('n<->m', row1=5, row2=1)) + raises(ValueError, lambda: e.elementary_row_op('n<->m', row1=-5, row2=1)) + raises(ValueError, lambda: e.elementary_row_op('n<->m', row1=1, row2=-5)) + raises(ValueError, lambda: e.elementary_row_op('n->n+km', row1=1, row2=5, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n->n+km', row1=5, row2=1, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n->n+km', row1=-5, row2=1, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n->n+km', row1=1, row2=-5, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n->n+km', row1=1, row2=1, k=5)) + + # test various ways to set arguments + assert e.elementary_row_op("n->kn", 0, 5) == Matrix([[5, 0, 0], [0, 1, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->kn", 1, 5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->kn", row=1, k=5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->kn", row1=1, k=5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_row_op("n<->m", 0, 1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_row_op("n<->m", row1=0, row2=1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_row_op("n<->m", row=0, row2=1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->n+km", 0, 5, 1) == Matrix([[1, 5, 0], [0, 1, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->n+km", row=0, k=5, row2=1) == Matrix([[1, 5, 0], [0, 1, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->n+km", row1=0, k=5, row2=1) == Matrix([[1, 5, 0], [0, 1, 0], [0, 0, 1]]) + + # make sure the matrix doesn't change size + a = ReductionsOnlyMatrix(2, 3, [0]*6) + assert a.elementary_row_op("n->kn", 1, 5) == Matrix(2, 3, [0]*6) + assert a.elementary_row_op("n<->m", 0, 1) == Matrix(2, 3, [0]*6) + assert a.elementary_row_op("n->n+km", 0, 5, 1) == Matrix(2, 3, [0]*6) + + +def test_col_op(): + e = eye_Reductions(3) + + raises(ValueError, lambda: e.elementary_col_op("abc")) + raises(ValueError, lambda: e.elementary_col_op()) + raises(ValueError, lambda: e.elementary_col_op('n->kn', col=5, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n->kn', col=-5, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n<->m', col1=1, col2=5)) + raises(ValueError, lambda: e.elementary_col_op('n<->m', col1=5, col2=1)) + raises(ValueError, lambda: e.elementary_col_op('n<->m', col1=-5, col2=1)) + raises(ValueError, lambda: e.elementary_col_op('n<->m', col1=1, col2=-5)) + raises(ValueError, lambda: e.elementary_col_op('n->n+km', col1=1, col2=5, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n->n+km', col1=5, col2=1, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n->n+km', col1=-5, col2=1, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n->n+km', col1=1, col2=-5, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n->n+km', col1=1, col2=1, k=5)) + + # test various ways to set arguments + assert e.elementary_col_op("n->kn", 0, 5) == Matrix([[5, 0, 0], [0, 1, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->kn", 1, 5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->kn", col=1, k=5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->kn", col1=1, k=5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_col_op("n<->m", 0, 1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_col_op("n<->m", col1=0, col2=1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_col_op("n<->m", col=0, col2=1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->n+km", 0, 5, 1) == Matrix([[1, 0, 0], [5, 1, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->n+km", col=0, k=5, col2=1) == Matrix([[1, 0, 0], [5, 1, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->n+km", col1=0, k=5, col2=1) == Matrix([[1, 0, 0], [5, 1, 0], [0, 0, 1]]) + + # make sure the matrix doesn't change size + a = ReductionsOnlyMatrix(2, 3, [0]*6) + assert a.elementary_col_op("n->kn", 1, 5) == Matrix(2, 3, [0]*6) + assert a.elementary_col_op("n<->m", 0, 1) == Matrix(2, 3, [0]*6) + assert a.elementary_col_op("n->n+km", 0, 5, 1) == Matrix(2, 3, [0]*6) + + +def test_is_echelon(): + zro = zeros_Reductions(3) + ident = eye_Reductions(3) + + assert zro.is_echelon + assert ident.is_echelon + + a = ReductionsOnlyMatrix(0, 0, []) + assert a.is_echelon + + a = ReductionsOnlyMatrix(2, 3, [3, 2, 1, 0, 0, 6]) + assert a.is_echelon + + a = ReductionsOnlyMatrix(2, 3, [0, 0, 6, 3, 2, 1]) + assert not a.is_echelon + + x = Symbol('x') + a = ReductionsOnlyMatrix(3, 1, [x, 0, 0]) + assert a.is_echelon + + a = ReductionsOnlyMatrix(3, 1, [x, x, 0]) + assert not a.is_echelon + + a = ReductionsOnlyMatrix(3, 3, [0, 0, 0, 1, 2, 3, 0, 0, 0]) + assert not a.is_echelon + + +def test_echelon_form(): + # echelon form is not unique, but the result + # must be row-equivalent to the original matrix + # and it must be in echelon form. + + a = zeros_Reductions(3) + e = eye_Reductions(3) + + # we can assume the zero matrix and the identity matrix shouldn't change + assert a.echelon_form() == a + assert e.echelon_form() == e + + a = ReductionsOnlyMatrix(0, 0, []) + assert a.echelon_form() == a + + a = ReductionsOnlyMatrix(1, 1, [5]) + assert a.echelon_form() == a + + # now we get to the real tests + + def verify_row_null_space(mat, rows, nulls): + for v in nulls: + assert all(t.is_zero for t in a_echelon*v) + for v in rows: + if not all(t.is_zero for t in v): + assert not all(t.is_zero for t in a_echelon*v.transpose()) + + a = ReductionsOnlyMatrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9]) + nulls = [Matrix([ + [ 1], + [-2], + [ 1]])] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + + a = ReductionsOnlyMatrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 8]) + nulls = [] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + a = ReductionsOnlyMatrix(3, 3, [2, 1, 3, 0, 0, 0, 2, 1, 3]) + nulls = [Matrix([ + [Rational(-1, 2)], + [ 1], + [ 0]]), + Matrix([ + [Rational(-3, 2)], + [ 0], + [ 1]])] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + # this one requires a row swap + a = ReductionsOnlyMatrix(3, 3, [2, 1, 3, 0, 0, 0, 1, 1, 3]) + nulls = [Matrix([ + [ 0], + [ -3], + [ 1]])] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + a = ReductionsOnlyMatrix(3, 3, [0, 3, 3, 0, 2, 2, 0, 1, 1]) + nulls = [Matrix([ + [1], + [0], + [0]]), + Matrix([ + [ 0], + [-1], + [ 1]])] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + a = ReductionsOnlyMatrix(2, 3, [2, 2, 3, 3, 3, 0]) + nulls = [Matrix([ + [-1], + [1], + [0]])] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + +def test_rref(): + e = ReductionsOnlyMatrix(0, 0, []) + assert e.rref(pivots=False) == e + + e = ReductionsOnlyMatrix(1, 1, [1]) + a = ReductionsOnlyMatrix(1, 1, [5]) + assert e.rref(pivots=False) == a.rref(pivots=False) == e + + a = ReductionsOnlyMatrix(3, 1, [1, 2, 3]) + assert a.rref(pivots=False) == Matrix([[1], [0], [0]]) + + a = ReductionsOnlyMatrix(1, 3, [1, 2, 3]) + assert a.rref(pivots=False) == Matrix([[1, 2, 3]]) + + a = ReductionsOnlyMatrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9]) + assert a.rref(pivots=False) == Matrix([ + [1, 0, -1], + [0, 1, 2], + [0, 0, 0]]) + + a = ReductionsOnlyMatrix(3, 3, [1, 2, 3, 1, 2, 3, 1, 2, 3]) + b = ReductionsOnlyMatrix(3, 3, [1, 2, 3, 0, 0, 0, 0, 0, 0]) + c = ReductionsOnlyMatrix(3, 3, [0, 0, 0, 1, 2, 3, 0, 0, 0]) + d = ReductionsOnlyMatrix(3, 3, [0, 0, 0, 0, 0, 0, 1, 2, 3]) + assert a.rref(pivots=False) == \ + b.rref(pivots=False) == \ + c.rref(pivots=False) == \ + d.rref(pivots=False) == b + + e = eye_Reductions(3) + z = zeros_Reductions(3) + assert e.rref(pivots=False) == e + assert z.rref(pivots=False) == z + + a = ReductionsOnlyMatrix([ + [ 0, 0, 1, 2, 2, -5, 3], + [-1, 5, 2, 2, 1, -7, 5], + [ 0, 0, -2, -3, -3, 8, -5], + [-1, 5, 0, -1, -2, 1, 0]]) + mat, pivot_offsets = a.rref() + assert mat == Matrix([ + [1, -5, 0, 0, 1, 1, -1], + [0, 0, 1, 0, 0, -1, 1], + [0, 0, 0, 1, 1, -2, 1], + [0, 0, 0, 0, 0, 0, 0]]) + assert pivot_offsets == (0, 2, 3) + + a = ReductionsOnlyMatrix([[Rational(1, 19), Rational(1, 5), 2, 3], + [ 4, 5, 6, 7], + [ 8, 9, 10, 11], + [ 12, 13, 14, 15]]) + assert a.rref(pivots=False) == Matrix([ + [1, 0, 0, Rational(-76, 157)], + [0, 1, 0, Rational(-5, 157)], + [0, 0, 1, Rational(238, 157)], + [0, 0, 0, 0]]) + + x = Symbol('x') + a = ReductionsOnlyMatrix(2, 3, [x, 1, 1, sqrt(x), x, 1]) + for i, j in zip(a.rref(pivots=False), + [1, 0, sqrt(x)*(-x + 1)/(-x**Rational(5, 2) + x), + 0, 1, 1/(sqrt(x) + x + 1)]): + assert simplify(i - j).is_zero + + +def test_rref_rhs(): + a, b, c, d = symbols('a b c d') + A = Matrix([[0, 0], [0, 0], [1, 2], [3, 4]]) + B = Matrix([a, b, c, d]) + assert A.rref_rhs(B) == (Matrix([ + [1, 0], + [0, 1], + [0, 0], + [0, 0]]), Matrix([ + [ -2*c + d], + [3*c/2 - d/2], + [ a], + [ b]])) + + +def test_issue_17827(): + C = Matrix([ + [3, 4, -1, 1], + [9, 12, -3, 3], + [0, 2, 1, 3], + [2, 3, 0, -2], + [0, 3, 3, -5], + [8, 15, 0, 6] + ]) + # Tests for row/col within valid range + D = C.elementary_row_op('n<->m', row1=2, row2=5) + E = C.elementary_row_op('n->n+km', row1=5, row2=3, k=-4) + F = C.elementary_row_op('n->kn', row=5, k=2) + assert(D[5, :] == Matrix([[0, 2, 1, 3]])) + assert(E[5, :] == Matrix([[0, 3, 0, 14]])) + assert(F[5, :] == Matrix([[16, 30, 0, 12]])) + # Tests for row/col out of range + raises(ValueError, lambda: C.elementary_row_op('n<->m', row1=2, row2=6)) + raises(ValueError, lambda: C.elementary_row_op('n->kn', row=7, k=2)) + raises(ValueError, lambda: C.elementary_row_op('n->n+km', row1=-1, row2=5, k=2)) + +def test_rank(): + m = Matrix([[1, 2], [x, 1 - 1/x]]) + assert m.rank() == 2 + n = Matrix(3, 3, range(1, 10)) + assert n.rank() == 2 + p = zeros(3) + assert p.rank() == 0 + +def test_issue_11434(): + ax, ay, bx, by, cx, cy, dx, dy, ex, ey, t0, t1 = \ + symbols('a_x a_y b_x b_y c_x c_y d_x d_y e_x e_y t_0 t_1') + M = Matrix([[ax, ay, ax*t0, ay*t0, 0], + [bx, by, bx*t0, by*t0, 0], + [cx, cy, cx*t0, cy*t0, 1], + [dx, dy, dx*t0, dy*t0, 1], + [ex, ey, 2*ex*t1 - ex*t0, 2*ey*t1 - ey*t0, 0]]) + assert M.rank() == 4 + +def test_rank_regression_from_so(): + # see: + # https://stackoverflow.com/questions/19072700/why-does-sympy-give-me-the-wrong-answer-when-i-row-reduce-a-symbolic-matrix + + nu, lamb = symbols('nu, lambda') + A = Matrix([[-3*nu, 1, 0, 0], + [ 3*nu, -2*nu - 1, 2, 0], + [ 0, 2*nu, (-1*nu) - lamb - 2, 3], + [ 0, 0, nu + lamb, -3]]) + expected_reduced = Matrix([[1, 0, 0, 1/(nu**2*(-lamb - nu))], + [0, 1, 0, 3/(nu*(-lamb - nu))], + [0, 0, 1, 3/(-lamb - nu)], + [0, 0, 0, 0]]) + expected_pivots = (0, 1, 2) + + reduced, pivots = A.rref() + + assert simplify(expected_reduced - reduced) == zeros(*A.shape) + assert pivots == expected_pivots + +def test_issue_15872(): + A = Matrix([[1, 1, 1, 0], [-2, -1, 0, -1], [0, 0, -1, -1], [0, 0, 2, 1]]) + B = A - Matrix.eye(4) * I + assert B.rank() == 3 + assert (B**2).rank() == 2 + assert (B**3).rank() == 2 diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_matrixbase.py b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_matrixbase.py new file mode 100644 index 0000000000000000000000000000000000000000..a77f51596c6622dc427feeeb9383214592fab632 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_matrixbase.py @@ -0,0 +1,3795 @@ +import concurrent.futures +import random +from collections.abc import Hashable + +from sympy import ( + Abs, Add, Array, DeferredVector, E, Expr, FiniteSet, Float, Function, + GramSchmidt, I, ImmutableDenseMatrix, ImmutableMatrix, + ImmutableSparseMatrix, Integer, KroneckerDelta, MatPow, Matrix, + MatrixSymbol, Max, Min, MutableDenseMatrix, MutableSparseMatrix, Poly, Pow, + PurePoly, Q, Quaternion, Rational, RootOf, S, SparseMatrix, Symbol, Tuple, + Wild, banded, casoratian, cos, diag, diff, exp, expand, eye, floor, hessian, + integrate, log, matrix_multiply_elementwise, nan, ones, oo, pi, randMatrix, + rot_axis1, rot_axis2, rot_axis3, rot_ccw_axis1, rot_ccw_axis2, + rot_ccw_axis3, signsimp, simplify, sin, sqrt, sstr, symbols, sympify, tan, + trigsimp, wronskian, zeros, cancel) +from sympy.abc import a, b, c, d, t, x, y, z +from sympy.core.kind import NumberKind, UndefinedKind +from sympy.matrices.determinant import _find_reasonable_pivot_naive +from sympy.matrices.exceptions import ( + MatrixError, NonSquareMatrixError, ShapeError) +from sympy.matrices.kind import MatrixKind +from sympy.matrices.utilities import _dotprodsimp_state, _simplify, dotprodsimp +from sympy.tensor.array.array_derivatives import ArrayDerivative +from sympy.testing.pytest import ( + ignore_warnings, raises, skip, skip_under_pyodide, slow, + warns_deprecated_sympy) +from sympy.utilities.iterables import capture, iterable +from importlib.metadata import version + +all_classes = (Matrix, SparseMatrix, ImmutableMatrix, ImmutableSparseMatrix) +mutable_classes = (Matrix, SparseMatrix) +immutable_classes = (ImmutableMatrix, ImmutableSparseMatrix) + + +def test__MinimalMatrix(): + x = Matrix(2, 3, [1, 2, 3, 4, 5, 6]) + assert x.rows == 2 + assert x.cols == 3 + assert x[2] == 3 + assert x[1, 1] == 5 + assert list(x) == [1, 2, 3, 4, 5, 6] + assert list(x[1, :]) == [4, 5, 6] + assert list(x[:, 1]) == [2, 5] + assert list(x[:, :]) == list(x) + assert x[:, :] == x + assert Matrix(x) == x + assert Matrix([[1, 2, 3], [4, 5, 6]]) == x + assert Matrix(([1, 2, 3], [4, 5, 6])) == x + assert Matrix([(1, 2, 3), (4, 5, 6)]) == x + assert Matrix(((1, 2, 3), (4, 5, 6))) == x + assert not (Matrix([[1, 2], [3, 4], [5, 6]]) == x) + + +def test_kind(): + assert Matrix([[1, 2], [3, 4]]).kind == MatrixKind(NumberKind) + assert Matrix([[0, 0], [0, 0]]).kind == MatrixKind(NumberKind) + assert Matrix(0, 0, []).kind == MatrixKind(NumberKind) + assert Matrix([[x]]).kind == MatrixKind(NumberKind) + assert Matrix([[1, Matrix([[1]])]]).kind == MatrixKind(UndefinedKind) + assert SparseMatrix([[1]]).kind == MatrixKind(NumberKind) + assert SparseMatrix([[1, Matrix([[1]])]]).kind == MatrixKind(UndefinedKind) + + +def test_todok(): + a, b, c, d = symbols('a:d') + m1 = MutableDenseMatrix([[a, b], [c, d]]) + m2 = ImmutableDenseMatrix([[a, b], [c, d]]) + m3 = MutableSparseMatrix([[a, b], [c, d]]) + m4 = ImmutableSparseMatrix([[a, b], [c, d]]) + assert m1.todok() == m2.todok() == m3.todok() == m4.todok() == \ + {(0, 0): a, (0, 1): b, (1, 0): c, (1, 1): d} + + +def test_tolist(): + lst = [[S.One, S.Half, x*y, S.Zero], [x, y, z, x**2], [y, -S.One, z*x, 3]] + flat_lst = [S.One, S.Half, x*y, S.Zero, x, y, z, x**2, y, -S.One, z*x, 3] + m = Matrix(3, 4, flat_lst) + assert m.tolist() == lst + + +def test_todod(): + m = Matrix([[S.One, 0], [0, S.Half], [x, 0]]) + dict = {0: {0: S.One}, 1: {1: S.Half}, 2: {0: x}} + assert m.todod() == dict + + +def test_row_col_del(): + e = ImmutableMatrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9]) + raises(IndexError, lambda: e.row_del(5)) + raises(IndexError, lambda: e.row_del(-5)) + raises(IndexError, lambda: e.col_del(5)) + raises(IndexError, lambda: e.col_del(-5)) + + assert e.row_del(2) == e.row_del(-1) == Matrix([[1, 2, 3], [4, 5, 6]]) + assert e.col_del(2) == e.col_del(-1) == Matrix([[1, 2], [4, 5], [7, 8]]) + + assert e.row_del(1) == e.row_del(-2) == Matrix([[1, 2, 3], [7, 8, 9]]) + assert e.col_del(1) == e.col_del(-2) == Matrix([[1, 3], [4, 6], [7, 9]]) + + +def test_get_diag_blocks1(): + a = Matrix([[1, 2], [2, 3]]) + b = Matrix([[3, x], [y, 3]]) + c = Matrix([[3, x, 3], [y, 3, z], [x, y, z]]) + assert a.get_diag_blocks() == [a] + assert b.get_diag_blocks() == [b] + assert c.get_diag_blocks() == [c] + + +def test_get_diag_blocks2(): + a = Matrix([[1, 2], [2, 3]]) + b = Matrix([[3, x], [y, 3]]) + c = Matrix([[3, x, 3], [y, 3, z], [x, y, z]]) + A, B, C, D = diag(a, b, b), diag(a, b, c), diag(a, c, b), diag(c, c, b) + A = Matrix(A.rows, A.cols, A) + B = Matrix(B.rows, B.cols, B) + C = Matrix(C.rows, C.cols, C) + D = Matrix(D.rows, D.cols, D) + + assert A.get_diag_blocks() == [a, b, b] + assert B.get_diag_blocks() == [a, b, c] + assert C.get_diag_blocks() == [a, c, b] + assert D.get_diag_blocks() == [c, c, b] + + +def test_row_col(): + m = Matrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9]) + assert m.row(0) == Matrix(1, 3, [1, 2, 3]) + assert m.col(0) == Matrix(3, 1, [1, 4, 7]) + + +def test_row_join(): + assert eye(3).row_join(Matrix([7, 7, 7])) == \ + Matrix([[1, 0, 0, 7], + [0, 1, 0, 7], + [0, 0, 1, 7]]) + + +def test_col_join(): + assert eye(3).col_join(Matrix([[7, 7, 7]])) == \ + Matrix([[1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [7, 7, 7]]) + + +def test_row_insert(): + r4 = Matrix([[4, 4, 4]]) + for i in range(-4, 5): + l = [1, 0, 0] + l.insert(i, 4) + assert eye(3).row_insert(i, r4).col(0).flat() == l + + +def test_col_insert(): + c4 = Matrix([4, 4, 4]) + for i in range(-4, 5): + l = [0, 0, 0] + l.insert(i, 4) + assert zeros(3).col_insert(i, c4).row(0).flat() == l + # issue 13643 + assert eye(6).col_insert(3, Matrix([[2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]])) == \ + Matrix([[1, 0, 0, 2, 2, 0, 0, 0], + [0, 1, 0, 2, 2, 0, 0, 0], + [0, 0, 1, 2, 2, 0, 0, 0], + [0, 0, 0, 2, 2, 1, 0, 0], + [0, 0, 0, 2, 2, 0, 1, 0], + [0, 0, 0, 2, 2, 0, 0, 1]]) + + +def test_extract(): + m = Matrix(4, 3, lambda i, j: i*3 + j) + assert m.extract([0, 1, 3], [0, 1]) == Matrix(3, 2, [0, 1, 3, 4, 9, 10]) + assert m.extract([0, 3], [0, 0, 2]) == Matrix(2, 3, [0, 0, 2, 9, 9, 11]) + assert m.extract(range(4), range(3)) == m + raises(IndexError, lambda: m.extract([4], [0])) + raises(IndexError, lambda: m.extract([0], [3])) + + +def test_hstack(): + m = Matrix(4, 3, lambda i, j: i*3 + j) + m2 = Matrix(3, 4, lambda i, j: i*3 + j) + assert m == m.hstack(m) + assert m.hstack(m, m, m) == Matrix.hstack(m, m, m) == Matrix([ + [0, 1, 2, 0, 1, 2, 0, 1, 2], + [3, 4, 5, 3, 4, 5, 3, 4, 5], + [6, 7, 8, 6, 7, 8, 6, 7, 8], + [9, 10, 11, 9, 10, 11, 9, 10, 11]]) + raises(ShapeError, lambda: m.hstack(m, m2)) + assert Matrix.hstack() == Matrix() + + # test regression #12938 + M1 = Matrix.zeros(0, 0) + M2 = Matrix.zeros(0, 1) + M3 = Matrix.zeros(0, 2) + M4 = Matrix.zeros(0, 3) + m = Matrix.hstack(M1, M2, M3, M4) + assert m.rows == 0 and m.cols == 6 + + +def test_vstack(): + m = Matrix(4, 3, lambda i, j: i*3 + j) + m2 = Matrix(3, 4, lambda i, j: i*3 + j) + assert m == m.vstack(m) + assert m.vstack(m, m, m) == Matrix.vstack(m, m, m) == Matrix([ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + [9, 10, 11], + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + [9, 10, 11], + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + [9, 10, 11]]) + raises(ShapeError, lambda: m.vstack(m, m2)) + assert Matrix.vstack() == Matrix() + + +def test_has(): + A = Matrix(((x, y), (2, 3))) + assert A.has(x) + assert not A.has(z) + assert A.has(Symbol) + + A = Matrix(((2, y), (2, 3))) + assert not A.has(x) + + +def test_is_anti_symmetric(): + x = symbols('x') + assert Matrix(2, 1, [1, 2]).is_anti_symmetric() is False + m = Matrix(3, 3, [0, x**2 + 2*x + 1, y, -(x + 1)**2, 0, x*y, -y, -x*y, 0]) + assert m.is_anti_symmetric() is True + assert m.is_anti_symmetric(simplify=False) is None + assert m.is_anti_symmetric(simplify=lambda x: x) is None + + m = Matrix(3, 3, [x.expand() for x in m]) + assert m.is_anti_symmetric(simplify=False) is True + m = Matrix(3, 3, [x.expand() for x in [S.One] + list(m)[1:]]) + assert m.is_anti_symmetric() is False + + +def test_is_hermitian(): + a = Matrix([[1, I], [-I, 1]]) + assert a.is_hermitian + a = Matrix([[2*I, I], [-I, 1]]) + assert a.is_hermitian is False + a = Matrix([[x, I], [-I, 1]]) + assert a.is_hermitian is None + a = Matrix([[x, 1], [-I, 1]]) + assert a.is_hermitian is False + + +def test_is_symbolic(): + a = Matrix([[x, x], [x, x]]) + assert a.is_symbolic() is True + a = Matrix([[1, 2, 3, 4], [5, 6, 7, 8]]) + assert a.is_symbolic() is False + a = Matrix([[1, 2, 3, 4], [5, 6, x, 8]]) + assert a.is_symbolic() is True + a = Matrix([[1, x, 3]]) + assert a.is_symbolic() is True + a = Matrix([[1, 2, 3]]) + assert a.is_symbolic() is False + a = Matrix([[1], [x], [3]]) + assert a.is_symbolic() is True + a = Matrix([[1], [2], [3]]) + assert a.is_symbolic() is False + + +def test_is_square(): + m = Matrix([[1], [1]]) + m2 = Matrix([[2, 2], [2, 2]]) + assert not m.is_square + assert m2.is_square + + +def test_is_symmetric(): + m = Matrix(2, 2, [0, 1, 1, 0]) + assert m.is_symmetric() + m = Matrix(2, 2, [0, 1, 0, 1]) + assert not m.is_symmetric() + + +def test_is_hessenberg(): + A = Matrix([[3, 4, 1], [2, 4, 5], [0, 1, 2]]) + assert A.is_upper_hessenberg + A = Matrix(3, 3, [3, 2, 0, 4, 4, 1, 1, 5, 2]) + assert A.is_lower_hessenberg + A = Matrix(3, 3, [3, 2, -1, 4, 4, 1, 1, 5, 2]) + assert A.is_lower_hessenberg is False + assert A.is_upper_hessenberg is False + + A = Matrix([[3, 4, 1], [2, 4, 5], [3, 1, 2]]) + assert not A.is_upper_hessenberg + + +def test_values(): + assert set(Matrix(2, 2, [0, 1, 2, 3] + ).values()) == {1, 2, 3} + x = Symbol('x', real=True) + assert set(Matrix(2, 2, [x, 0, 0, 1] + ).values()) == {x, 1} + + +def test_conjugate(): + M = Matrix([[0, I, 5], + [1, 2, 0]]) + + assert M.T == Matrix([[0, 1], + [I, 2], + [5, 0]]) + + assert M.C == Matrix([[0, -I, 5], + [1, 2, 0]]) + assert M.C == M.conjugate() + + assert M.H == M.T.C + assert M.H == Matrix([[ 0, 1], + [-I, 2], + [ 5, 0]]) + + +def test_doit(): + a = Matrix([[Add(x, x, evaluate=False)]]) + assert a[0] != 2*x + assert a.doit() == Matrix([[2*x]]) + + +def test_evalf(): + a = Matrix(2, 1, [sqrt(5), 6]) + assert all(a.evalf()[i] == a[i].evalf() for i in range(2)) + assert all(a.evalf(2)[i] == a[i].evalf(2) for i in range(2)) + assert all(a.n(2)[i] == a[i].n(2) for i in range(2)) + + +def test_replace(): + F, G = symbols('F, G', cls=Function) + K = Matrix(2, 2, lambda i, j: G(i+j)) + M = Matrix(2, 2, lambda i, j: F(i+j)) + N = M.replace(F, G) + assert N == K + + +def test_replace_map(): + F, G = symbols('F, G', cls=Function) + M = Matrix(2, 2, lambda i, j: F(i+j)) + N, d = M.replace(F, G, True) + assert N == Matrix(2, 2, lambda i, j: G(i+j)) + assert d == {F(0): G(0), F(1): G(1), F(2): G(2)} + +def test_numpy_conversion(): + try: + from numpy import array, array_equal + except ImportError: + skip('NumPy must be available to test creating matrices from ndarrays') + A = Matrix([[1,2], [3,4]]) + np_array = array([[1,2], [3,4]]) + assert array_equal(array(A), np_array) + assert array_equal(array(A, copy=True), np_array) + if(int(version('numpy').split('.')[0]) >= 2): #run this test only if numpy is new enough that copy variable is passed properly. + raises(TypeError, lambda: array(A, copy=False)) + +def test_rot90(): + A = Matrix([[1, 2], [3, 4]]) + assert A == A.rot90(0) == A.rot90(4) + assert A.rot90(2) == A.rot90(-2) == A.rot90(6) == Matrix(((4, 3), (2, 1))) + assert A.rot90(3) == A.rot90(-1) == A.rot90(7) == Matrix(((2, 4), (1, 3))) + assert A.rot90() == A.rot90(-7) == A.rot90(-3) == Matrix(((3, 1), (4, 2))) + + +def test_subs(): + assert Matrix([[1, x], [x, 4]]).subs(x, 5) == Matrix([[1, 5], [5, 4]]) + assert Matrix([[x, 2], [x + y, 4]]).subs([[x, -1], [y, -2]]) == \ + Matrix([[-1, 2], [-3, 4]]) + assert Matrix([[x, 2], [x + y, 4]]).subs([(x, -1), (y, -2)]) == \ + Matrix([[-1, 2], [-3, 4]]) + assert Matrix([[x, 2], [x + y, 4]]).subs({x: -1, y: -2}) == \ + Matrix([[-1, 2], [-3, 4]]) + assert Matrix([[x*y]]).subs({x: y - 1, y: x - 1}, simultaneous=True) == \ + Matrix([[(x - 1)*(y - 1)]]) + + +def test_permute(): + a = Matrix(3, 4, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) + + raises(IndexError, lambda: a.permute([[0, 5]])) + raises(ValueError, lambda: a.permute(Symbol('x'))) + b = a.permute_rows([[0, 2], [0, 1]]) + assert a.permute([[0, 2], [0, 1]]) == b == Matrix([ + [5, 6, 7, 8], + [9, 10, 11, 12], + [1, 2, 3, 4]]) + + b = a.permute_cols([[0, 2], [0, 1]]) + assert a.permute([[0, 2], [0, 1]], orientation='cols') == b ==\ + Matrix([ + [ 2, 3, 1, 4], + [ 6, 7, 5, 8], + [10, 11, 9, 12]]) + + b = a.permute_cols([[0, 2], [0, 1]], direction='backward') + assert a.permute([[0, 2], [0, 1]], orientation='cols', direction='backward') == b ==\ + Matrix([ + [ 3, 1, 2, 4], + [ 7, 5, 6, 8], + [11, 9, 10, 12]]) + + assert a.permute([1, 2, 0, 3]) == Matrix([ + [5, 6, 7, 8], + [9, 10, 11, 12], + [1, 2, 3, 4]]) + + from sympy.combinatorics import Permutation + assert a.permute(Permutation([1, 2, 0, 3])) == Matrix([ + [5, 6, 7, 8], + [9, 10, 11, 12], + [1, 2, 3, 4]]) + +def test_upper_triangular(): + + A = Matrix([ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] + ]) + + R = A.upper_triangular(2) + assert R == Matrix([ + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0], + [0, 0, 0, 0] + ]) + + R = A.upper_triangular(-2) + assert R == Matrix([ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [0, 1, 1, 1] + ]) + + R = A.upper_triangular() + assert R == Matrix([ + [1, 1, 1, 1], + [0, 1, 1, 1], + [0, 0, 1, 1], + [0, 0, 0, 1] + ]) + + +def test_lower_triangular(): + A = Matrix([ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] + ]) + + L = A.lower_triangular() + assert L == Matrix([ + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1]]) + + L = A.lower_triangular(2) + assert L == Matrix([ + [1, 1, 1, 0], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] + ]) + + L = A.lower_triangular(-2) + assert L == Matrix([ + [0, 0, 0, 0], + [0, 0, 0, 0], + [1, 0, 0, 0], + [1, 1, 0, 0] + ]) + + +def test_add(): + m = Matrix([[1, 2, 3], [x, y, x], [2*y, -50, z*x]]) + assert m + m == Matrix([[2, 4, 6], [2*x, 2*y, 2*x], [4*y, -100, 2*z*x]]) + n = Matrix(1, 2, [1, 2]) + raises(ShapeError, lambda: m + n) + + +def test_matmul(): + a = Matrix([[1, 2], [3, 4]]) + + assert a.__matmul__(2) == NotImplemented + + assert a.__rmatmul__(2) == NotImplemented + + #This is done this way because @ is only supported in Python 3.5+ + #To check 2@a case + try: + eval('2 @ a') + except SyntaxError: + pass + except TypeError: #TypeError is raised in case of NotImplemented is returned + pass + + #Check a@2 case + try: + eval('a @ 2') + except SyntaxError: + pass + except TypeError: #TypeError is raised in case of NotImplemented is returned + pass + + +def test_non_matmul(): + """ + Test that if explicitly specified as non-matrix, mul reverts + to scalar multiplication. + """ + class foo(Expr): + is_Matrix=False + is_MatrixLike=False + shape = (1, 1) + + A = Matrix([[1, 2], [3, 4]]) + b = foo() + assert b*A == Matrix([[b, 2*b], [3*b, 4*b]]) + assert A*b == Matrix([[b, 2*b], [3*b, 4*b]]) + + +def test_neg(): + n = Matrix(1, 2, [1, 2]) + assert -n == Matrix(1, 2, [-1, -2]) + + +def test_sub(): + n = Matrix(1, 2, [1, 2]) + assert n - n == Matrix(1, 2, [0, 0]) + + +def test_div(): + n = Matrix(1, 2, [1, 2]) + assert n/2 == Matrix(1, 2, [S.Half, S(2)/2]) + + +def test_eye(): + assert list(Matrix.eye(2, 2)) == [1, 0, 0, 1] + assert list(Matrix.eye(2)) == [1, 0, 0, 1] + assert type(Matrix.eye(2)) == Matrix + assert type(Matrix.eye(2, cls=Matrix)) == Matrix + + +def test_ones(): + assert list(Matrix.ones(2, 2)) == [1, 1, 1, 1] + assert list(Matrix.ones(2)) == [1, 1, 1, 1] + assert Matrix.ones(2, 3) == Matrix([[1, 1, 1], [1, 1, 1]]) + assert type(Matrix.ones(2)) == Matrix + assert type(Matrix.ones(2, cls=Matrix)) == Matrix + + +def test_zeros(): + assert list(Matrix.zeros(2, 2)) == [0, 0, 0, 0] + assert list(Matrix.zeros(2)) == [0, 0, 0, 0] + assert Matrix.zeros(2, 3) == Matrix([[0, 0, 0], [0, 0, 0]]) + assert type(Matrix.zeros(2)) == Matrix + assert type(Matrix.zeros(2, cls=Matrix)) == Matrix + + +def test_diag_make(): + diag = Matrix.diag + a = Matrix([[1, 2], [2, 3]]) + b = Matrix([[3, x], [y, 3]]) + c = Matrix([[3, x, 3], [y, 3, z], [x, y, z]]) + assert diag(a, b, b) == Matrix([ + [1, 2, 0, 0, 0, 0], + [2, 3, 0, 0, 0, 0], + [0, 0, 3, x, 0, 0], + [0, 0, y, 3, 0, 0], + [0, 0, 0, 0, 3, x], + [0, 0, 0, 0, y, 3], + ]) + assert diag(a, b, c) == Matrix([ + [1, 2, 0, 0, 0, 0, 0], + [2, 3, 0, 0, 0, 0, 0], + [0, 0, 3, x, 0, 0, 0], + [0, 0, y, 3, 0, 0, 0], + [0, 0, 0, 0, 3, x, 3], + [0, 0, 0, 0, y, 3, z], + [0, 0, 0, 0, x, y, z], + ]) + assert diag(a, c, b) == Matrix([ + [1, 2, 0, 0, 0, 0, 0], + [2, 3, 0, 0, 0, 0, 0], + [0, 0, 3, x, 3, 0, 0], + [0, 0, y, 3, z, 0, 0], + [0, 0, x, y, z, 0, 0], + [0, 0, 0, 0, 0, 3, x], + [0, 0, 0, 0, 0, y, 3], + ]) + a = Matrix([x, y, z]) + b = Matrix([[1, 2], [3, 4]]) + c = Matrix([[5, 6]]) + # this "wandering diagonal" is what makes this + # a block diagonal where each block is independent + # of the others + assert diag(a, 7, b, c) == Matrix([ + [x, 0, 0, 0, 0, 0], + [y, 0, 0, 0, 0, 0], + [z, 0, 0, 0, 0, 0], + [0, 7, 0, 0, 0, 0], + [0, 0, 1, 2, 0, 0], + [0, 0, 3, 4, 0, 0], + [0, 0, 0, 0, 5, 6]]) + raises(ValueError, lambda: diag(a, 7, b, c, rows=5)) + assert diag(1) == Matrix([[1]]) + assert diag(1, rows=2) == Matrix([[1, 0], [0, 0]]) + assert diag(1, cols=2) == Matrix([[1, 0], [0, 0]]) + assert diag(1, rows=3, cols=2) == Matrix([[1, 0], [0, 0], [0, 0]]) + assert diag(*[2, 3]) == Matrix([ + [2, 0], + [0, 3]]) + assert diag(Matrix([2, 3])) == Matrix([ + [2], + [3]]) + assert diag([1, [2, 3], 4], unpack=False) == \ + diag([[1], [2, 3], [4]], unpack=False) == Matrix([ + [1, 0], + [2, 3], + [4, 0]]) + assert type(diag(1)) == Matrix + assert type(diag(1, cls=Matrix)) == Matrix + assert Matrix.diag([1, 2, 3]) == Matrix.diag(1, 2, 3) + assert Matrix.diag([1, 2, 3], unpack=False).shape == (3, 1) + assert Matrix.diag([[1, 2, 3]]).shape == (3, 1) + assert Matrix.diag([[1, 2, 3]], unpack=False).shape == (1, 3) + assert Matrix.diag([[[1, 2, 3]]]).shape == (1, 3) + # kerning can be used to move the starting point + assert Matrix.diag(ones(0, 2), 1, 2) == Matrix([ + [0, 0, 1, 0], + [0, 0, 0, 2]]) + assert Matrix.diag(ones(2, 0), 1, 2) == Matrix([ + [0, 0], + [0, 0], + [1, 0], + [0, 2]]) + + +def test_diagonal(): + m = Matrix(3, 3, range(9)) + d = m.diagonal() + assert d == m.diagonal(0) + assert tuple(d) == (0, 4, 8) + assert tuple(m.diagonal(1)) == (1, 5) + assert tuple(m.diagonal(-1)) == (3, 7) + assert tuple(m.diagonal(2)) == (2,) + assert type(m.diagonal()) == type(m) + s = SparseMatrix(3, 3, {(1, 1): 1}) + assert type(s.diagonal()) == type(s) + assert type(m) != type(s) + raises(ValueError, lambda: m.diagonal(3)) + raises(ValueError, lambda: m.diagonal(-3)) + raises(ValueError, lambda: m.diagonal(pi)) + M = ones(2, 3) + assert banded({i: list(M.diagonal(i)) + for i in range(1-M.rows, M.cols)}) == M + + +def test_jordan_block(): + assert Matrix.jordan_block(3, 2) == Matrix.jordan_block(3, eigenvalue=2) \ + == Matrix.jordan_block(size=3, eigenvalue=2) \ + == Matrix.jordan_block(3, 2, band='upper') \ + == Matrix.jordan_block( + size=3, eigenval=2, eigenvalue=2) \ + == Matrix([ + [2, 1, 0], + [0, 2, 1], + [0, 0, 2]]) + + assert Matrix.jordan_block(3, 2, band='lower') == Matrix([ + [2, 0, 0], + [1, 2, 0], + [0, 1, 2]]) + # missing eigenvalue + raises(ValueError, lambda: Matrix.jordan_block(2)) + # non-integral size + raises(ValueError, lambda: Matrix.jordan_block(3.5, 2)) + # size not specified + raises(ValueError, lambda: Matrix.jordan_block(eigenvalue=2)) + # inconsistent eigenvalue + raises(ValueError, + lambda: Matrix.jordan_block( + eigenvalue=2, eigenval=4)) + + # Using alias keyword + assert Matrix.jordan_block(size=3, eigenvalue=2) == \ + Matrix.jordan_block(size=3, eigenval=2) + + +def test_orthogonalize(): + m = Matrix([[1, 2], [3, 4]]) + assert m.orthogonalize(Matrix([[2], [1]])) == [Matrix([[2], [1]])] + assert m.orthogonalize(Matrix([[2], [1]]), normalize=True) == \ + [Matrix([[2*sqrt(5)/5], [sqrt(5)/5]])] + assert m.orthogonalize(Matrix([[1], [2]]), Matrix([[-1], [4]])) == \ + [Matrix([[1], [2]]), Matrix([[Rational(-12, 5)], [Rational(6, 5)]])] + assert m.orthogonalize(Matrix([[0], [0]]), Matrix([[-1], [4]])) == \ + [Matrix([[-1], [4]])] + assert m.orthogonalize(Matrix([[0], [0]])) == [] + + n = Matrix([[9, 1, 9], [3, 6, 10], [8, 5, 2]]) + vecs = [Matrix([[-5], [1]]), Matrix([[-5], [2]]), Matrix([[-5], [-2]])] + assert n.orthogonalize(*vecs) == \ + [Matrix([[-5], [1]]), Matrix([[Rational(5, 26)], [Rational(25, 26)]])] + + vecs = [Matrix([0, 0, 0]), Matrix([1, 2, 3]), Matrix([1, 4, 5])] + raises(ValueError, lambda: Matrix.orthogonalize(*vecs, rankcheck=True)) + + vecs = [Matrix([1, 2, 3]), Matrix([4, 5, 6]), Matrix([7, 8, 9])] + raises(ValueError, lambda: Matrix.orthogonalize(*vecs, rankcheck=True)) + +def test_wilkinson(): + + wminus, wplus = Matrix.wilkinson(1) + assert wminus == Matrix([ + [-1, 1, 0], + [1, 0, 1], + [0, 1, 1]]) + assert wplus == Matrix([ + [1, 1, 0], + [1, 0, 1], + [0, 1, 1]]) + + wminus, wplus = Matrix.wilkinson(3) + assert wminus == Matrix([ + [-3, 1, 0, 0, 0, 0, 0], + [1, -2, 1, 0, 0, 0, 0], + [0, 1, -1, 1, 0, 0, 0], + [0, 0, 1, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 1, 2, 1], + + [0, 0, 0, 0, 0, 1, 3]]) + + assert wplus == Matrix([ + [3, 1, 0, 0, 0, 0, 0], + [1, 2, 1, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0], + [0, 0, 1, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 1, 2, 1], + [0, 0, 0, 0, 0, 1, 3]]) + + +def test_limit(): + x, y = symbols('x y') + m = Matrix(2, 1, [1/x, y]) + assert m.limit(x, 5) == Matrix(2, 1, [Rational(1, 5), y]) + A = Matrix(((1, 4, sin(x)/x), (y, 2, 4), (10, 5, x**2 + 1))) + assert A.limit(x, 0) == Matrix(((1, 4, 1), (y, 2, 4), (10, 5, 1))) + + +def test_issue_13774(): + M = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + v = [1, 1, 1] + raises(TypeError, lambda: M*v) + raises(TypeError, lambda: v*M) + + +def test_companion(): + x = Symbol('x') + y = Symbol('y') + raises(ValueError, lambda: Matrix.companion(1)) + raises(ValueError, lambda: Matrix.companion(Poly([1], x))) + raises(ValueError, lambda: Matrix.companion(Poly([2, 1], x))) + raises(ValueError, lambda: Matrix.companion(Poly(x*y, [x, y]))) + + c0, c1, c2 = symbols('c0:3') + assert Matrix.companion(Poly([1, c0], x)) == Matrix([-c0]) + assert Matrix.companion(Poly([1, c1, c0], x)) == \ + Matrix([[0, -c0], [1, -c1]]) + assert Matrix.companion(Poly([1, c2, c1, c0], x)) == \ + Matrix([[0, 0, -c0], [1, 0, -c1], [0, 1, -c2]]) + + +def test_issue_10589(): + x, y, z = symbols("x, y z") + M1 = Matrix([x, y, z]) + M1 = M1.subs(zip([x, y, z], [1, 2, 3])) + assert M1 == Matrix([[1], [2], [3]]) + + M2 = Matrix([[x, x, x, x, x], [x, x, x, x, x], [x, x, x, x, x]]) + M2 = M2.subs(zip([x], [1])) + assert M2 == Matrix([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]) + + +def test_rmul_pr19860(): + class Foo(ImmutableDenseMatrix): + _op_priority = MutableDenseMatrix._op_priority + 0.01 + + a = Matrix(2, 2, [1, 2, 3, 4]) + b = Foo(2, 2, [1, 2, 3, 4]) + + # This would throw a RecursionError: maximum recursion depth + # since b always has higher priority even after a.as_mutable() + c = a*b + + assert isinstance(c, Foo) + assert c == Matrix([[7, 10], [15, 22]]) + + +def test_issue_18956(): + A = Array([[1, 2], [3, 4]]) + B = Matrix([[1,2],[3,4]]) + raises(TypeError, lambda: B + A) + raises(TypeError, lambda: A + B) + + +def test__eq__(): + class My(object): + def __iter__(self): + yield 1 + yield 2 + return + def __getitem__(self, i): + return list(self)[i] + a = Matrix(2, 1, [1, 2]) + assert a != My() + class My_sympy(My): + def _sympy_(self): + return Matrix(self) + assert a == My_sympy() + + +def test_args(): + for n, cls in enumerate(all_classes): + m = cls.zeros(3, 2) + # all should give back the same type of arguments, e.g. ints for shape + assert m.shape == (3, 2) and all(type(i) is int for i in m.shape) + assert m.rows == 3 and type(m.rows) is int + assert m.cols == 2 and type(m.cols) is int + if not n % 2: + assert type(m.flat()) in (list, tuple, Tuple) + else: + assert type(m.todok()) is dict + + +def test_deprecated_mat_smat(): + for cls in Matrix, ImmutableMatrix: + m = cls.zeros(3, 2) + with warns_deprecated_sympy(): + mat = m._mat + assert mat == m.flat() + for cls in SparseMatrix, ImmutableSparseMatrix: + m = cls.zeros(3, 2) + with warns_deprecated_sympy(): + smat = m._smat + assert smat == m.todok() + + +def test_division(): + v = Matrix(1, 2, [x, y]) + assert v/z == Matrix(1, 2, [x/z, y/z]) + + +def test_sum(): + m = Matrix([[1, 2, 3], [x, y, x], [2*y, -50, z*x]]) + assert m + m == Matrix([[2, 4, 6], [2*x, 2*y, 2*x], [4*y, -100, 2*z*x]]) + n = Matrix(1, 2, [1, 2]) + raises(ShapeError, lambda: m + n) + + +def test_abs(): + m = Matrix([[1, -2], [x, y]]) + assert abs(m) == Matrix([[1, 2], [Abs(x), Abs(y)]]) + m = Matrix(1, 2, [-3, x]) + n = Matrix(1, 2, [3, Abs(x)]) + assert abs(m) == n + + +def test_addition(): + a = Matrix(( + (1, 2), + (3, 1), + )) + + b = Matrix(( + (1, 2), + (3, 0), + )) + + assert a + b == a.add(b) == Matrix([[2, 4], [6, 1]]) + + +def test_fancy_index_matrix(): + for M in (Matrix, SparseMatrix): + a = M(3, 3, range(9)) + assert a == a[:, :] + assert a[1, :] == Matrix(1, 3, [3, 4, 5]) + assert a[:, 1] == Matrix([1, 4, 7]) + assert a[[0, 1], :] == Matrix([[0, 1, 2], [3, 4, 5]]) + assert a[[0, 1], 2] == a[[0, 1], [2]] + assert a[2, [0, 1]] == a[[2], [0, 1]] + assert a[:, [0, 1]] == Matrix([[0, 1], [3, 4], [6, 7]]) + assert a[0, 0] == 0 + assert a[0:2, :] == Matrix([[0, 1, 2], [3, 4, 5]]) + assert a[:, 0:2] == Matrix([[0, 1], [3, 4], [6, 7]]) + assert a[::2, 1] == a[[0, 2], 1] + assert a[1, ::2] == a[1, [0, 2]] + a = M(3, 3, range(9)) + assert a[[0, 2, 1, 2, 1], :] == Matrix([ + [0, 1, 2], + [6, 7, 8], + [3, 4, 5], + [6, 7, 8], + [3, 4, 5]]) + assert a[:, [0,2,1,2,1]] == Matrix([ + [0, 2, 1, 2, 1], + [3, 5, 4, 5, 4], + [6, 8, 7, 8, 7]]) + + a = SparseMatrix.zeros(3) + a[1, 2] = 2 + a[0, 1] = 3 + a[2, 0] = 4 + assert a.extract([1, 1], [2]) == Matrix([ + [2], + [2]]) + assert a.extract([1, 0], [2, 2, 2]) == Matrix([ + [2, 2, 2], + [0, 0, 0]]) + assert a.extract([1, 0, 1, 2], [2, 0, 1, 0]) == Matrix([ + [2, 0, 0, 0], + [0, 0, 3, 0], + [2, 0, 0, 0], + [0, 4, 0, 4]]) + + +def test_multiplication(): + a = Matrix(( + (1, 2), + (3, 1), + (0, 6), + )) + + b = Matrix(( + (1, 2), + (3, 0), + )) + + raises(ShapeError, lambda: b*a) + raises(TypeError, lambda: a*{}) + + c = a*b + assert c[0, 0] == 7 + assert c[0, 1] == 2 + assert c[1, 0] == 6 + assert c[1, 1] == 6 + assert c[2, 0] == 18 + assert c[2, 1] == 0 + + c = a @ b + assert c[0, 0] == 7 + assert c[0, 1] == 2 + assert c[1, 0] == 6 + assert c[1, 1] == 6 + assert c[2, 0] == 18 + assert c[2, 1] == 0 + + h = matrix_multiply_elementwise(a, c) + assert h == a.multiply_elementwise(c) + assert h[0, 0] == 7 + assert h[0, 1] == 4 + assert h[1, 0] == 18 + assert h[1, 1] == 6 + assert h[2, 0] == 0 + assert h[2, 1] == 0 + raises(ShapeError, lambda: matrix_multiply_elementwise(a, b)) + + c = b * Symbol("x") + assert isinstance(c, Matrix) + assert c[0, 0] == x + assert c[0, 1] == 2*x + assert c[1, 0] == 3*x + assert c[1, 1] == 0 + + c2 = x * b + assert c == c2 + + c = 5 * b + assert isinstance(c, Matrix) + assert c[0, 0] == 5 + assert c[0, 1] == 2*5 + assert c[1, 0] == 3*5 + assert c[1, 1] == 0 + + M = Matrix([[oo, 0], [0, oo]]) + assert M ** 2 == M + + M = Matrix([[oo, oo], [0, 0]]) + assert M ** 2 == Matrix([[nan, nan], [nan, nan]]) + + # https://github.com/sympy/sympy/issues/22353 + A = Matrix(ones(3, 1)) + _h = -Rational(1, 2) + B = Matrix([_h, _h, _h]) + assert A.multiply_elementwise(B) == Matrix([ + [_h], + [_h], + [_h]]) + + +def test_power(): + raises(NonSquareMatrixError, lambda: Matrix((1, 2))**2) + + A = Matrix([[2, 3], [4, 5]]) + assert A**5 == Matrix([[6140, 8097], [10796, 14237]]) + A = Matrix([[2, 1, 3], [4, 2, 4], [6, 12, 1]]) + assert A**3 == Matrix([[290, 262, 251], [448, 440, 368], [702, 954, 433]]) + assert A**0 == eye(3) + assert A**1 == A + assert (Matrix([[2]]) ** 100)[0, 0] == 2**100 + assert Matrix([[1, 2], [3, 4]])**Integer(2) == Matrix([[7, 10], [15, 22]]) + A = Matrix([[1,2],[4,5]]) + assert A.pow(20, method='cayley') == A.pow(20, method='multiply') + assert A**Integer(2) == Matrix([[9, 12], [24, 33]]) + assert eye(2)**10000000 == eye(2) + + A = Matrix([[33, 24], [48, 57]]) + assert (A**S.Half)[:] == [5, 2, 4, 7] + A = Matrix([[0, 4], [-1, 5]]) + assert (A**S.Half)**2 == A + + assert Matrix([[1, 0], [1, 1]])**S.Half == Matrix([[1, 0], [S.Half, 1]]) + assert Matrix([[1, 0], [1, 1]])**0.5 == Matrix([[1, 0], [0.5, 1]]) + from sympy.abc import n + assert Matrix([[1, a], [0, 1]])**n == Matrix([[1, a*n], [0, 1]]) + assert Matrix([[b, a], [0, b]])**n == Matrix([[b**n, a*b**(n-1)*n], [0, b**n]]) + assert Matrix([ + [a**n, a**(n - 1)*n, (a**n*n**2 - a**n*n)/(2*a**2)], + [ 0, a**n, a**(n - 1)*n], + [ 0, 0, a**n]]) + assert Matrix([[a, 1, 0], [0, a, 0], [0, 0, b]])**n == Matrix([ + [a**n, a**(n-1)*n, 0], + [0, a**n, 0], + [0, 0, b**n]]) + + A = Matrix([[1, 0], [1, 7]]) + assert A._matrix_pow_by_jordan_blocks(S(3)) == A._eval_pow_by_recursion(3) + A = Matrix([[2]]) + assert A**10 == Matrix([[2**10]]) == A._matrix_pow_by_jordan_blocks(S(10)) == \ + A._eval_pow_by_recursion(10) + + # testing a matrix that cannot be jordan blocked issue 11766 + m = Matrix([[3, 0, 0, 0, -3], [0, -3, -3, 0, 3], [0, 3, 0, 3, 0], [0, 0, 3, 0, 3], [3, 0, 0, 3, 0]]) + raises(MatrixError, lambda: m._matrix_pow_by_jordan_blocks(S(10))) + + # test issue 11964 + raises(MatrixError, lambda: Matrix([[1, 1], [3, 3]])._matrix_pow_by_jordan_blocks(S(-10))) + A = Matrix([[0, 1, 0], [0, 0, 1], [0, 0, 0]]) # Nilpotent jordan block size 3 + assert A**10.0 == Matrix([[0, 0, 0], [0, 0, 0], [0, 0, 0]]) + raises(ValueError, lambda: A**2.1) + raises(ValueError, lambda: A**Rational(3, 2)) + A = Matrix([[8, 1], [3, 2]]) + assert A**10.0 == Matrix([[1760744107, 272388050], [817164150, 126415807]]) + A = Matrix([[0, 0, 1], [0, 0, 1], [0, 0, 1]]) # Nilpotent jordan block size 1 + assert A**10.0 == Matrix([[0, 0, 1], [0, 0, 1], [0, 0, 1]]) + A = Matrix([[0, 1, 0], [0, 0, 1], [0, 0, 1]]) # Nilpotent jordan block size 2 + assert A**10.0 == Matrix([[0, 0, 1], [0, 0, 1], [0, 0, 1]]) + n = Symbol('n', integer=True) + assert isinstance(A**n, MatPow) + n = Symbol('n', integer=True, negative=True) + raises(ValueError, lambda: A**n) + n = Symbol('n', integer=True, nonnegative=True) + assert A**n == Matrix([ + [KroneckerDelta(0, n), KroneckerDelta(1, n), -KroneckerDelta(0, n) - KroneckerDelta(1, n) + 1], + [ 0, KroneckerDelta(0, n), 1 - KroneckerDelta(0, n)], + [ 0, 0, 1]]) + assert A**(n + 2) == Matrix([[0, 0, 1], [0, 0, 1], [0, 0, 1]]) + raises(ValueError, lambda: A**Rational(3, 2)) + A = Matrix([[0, 0, 1], [3, 0, 1], [4, 3, 1]]) + assert A**5.0 == Matrix([[168, 72, 89], [291, 144, 161], [572, 267, 329]]) + assert A**5.0 == A**5 + A = Matrix([[0, 1, 0],[-1, 0, 0],[0, 0, 0]]) + n = Symbol("n") + An = A**n + assert An.subs(n, 2).doit() == A**2 + raises(ValueError, lambda: An.subs(n, -2).doit()) + assert An * An == A**(2*n) + + # concretizing behavior for non-integer and complex powers + A = Matrix([[0,0,0],[0,0,0],[0,0,0]]) + n = Symbol('n', integer=True, positive=True) + assert A**n == A + n = Symbol('n', integer=True, nonnegative=True) + assert A**n == diag(0**n, 0**n, 0**n) + assert (A**n).subs(n, 0) == eye(3) + assert (A**n).subs(n, 1) == zeros(3) + A = Matrix ([[2,0,0],[0,2,0],[0,0,2]]) + assert A**2.1 == diag (2**2.1, 2**2.1, 2**2.1) + assert A**I == diag (2**I, 2**I, 2**I) + A = Matrix([[0, 1, 0], [0, 0, 1], [0, 0, 1]]) + raises(ValueError, lambda: A**2.1) + raises(ValueError, lambda: A**I) + A = Matrix([[S.Half, S.Half], [S.Half, S.Half]]) + assert A**S.Half == A + A = Matrix([[1, 1],[3, 3]]) + assert A**S.Half == Matrix ([[S.Half, S.Half], [3*S.Half, 3*S.Half]]) + + +def test_issue_17247_expression_blowup_1(): + M = Matrix([[1+x, 1-x], [1-x, 1+x]]) + with dotprodsimp(True): + assert M.exp().expand() == Matrix([ + [ (exp(2*x) + exp(2))/2, (-exp(2*x) + exp(2))/2], + [(-exp(2*x) + exp(2))/2, (exp(2*x) + exp(2))/2]]) + + +def test_issue_17247_expression_blowup_2(): + M = Matrix([[1+x, 1-x], [1-x, 1+x]]) + with dotprodsimp(True): + P, J = M.jordan_form () + assert P*J*P.inv() + + +def test_issue_17247_expression_blowup_3(): + M = Matrix([[1+x, 1-x], [1-x, 1+x]]) + with dotprodsimp(True): + assert M**100 == Matrix([ + [633825300114114700748351602688*x**100 + 633825300114114700748351602688, 633825300114114700748351602688 - 633825300114114700748351602688*x**100], + [633825300114114700748351602688 - 633825300114114700748351602688*x**100, 633825300114114700748351602688*x**100 + 633825300114114700748351602688]]) + + +def test_issue_17247_expression_blowup_4(): +# This matrix takes extremely long on current master even with intermediate simplification so an abbreviated version is used. It is left here for test in case of future optimizations. +# M = Matrix(S('''[ +# [ -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64, -9/32 - I/16, 183/256 - 97*I/128, 3/64 + 13*I/64, -23/32 - 59*I/256, 15/128 - 3*I/32, 19/256 + 551*I/1024], +# [-149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512, -219/128 + 115*I/256, 6301/4096 - 6609*I/1024, 119/128 + 143*I/128, -10879/2048 + 4343*I/4096, 129/256 - 549*I/512, 42533/16384 + 29103*I/8192], +# [ 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64, -9/32 - I/16, 183/256 - 97*I/128, 3/64 + 13*I/64, -23/32 - 59*I/256], +# [ -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512, -219/128 + 115*I/256, 6301/4096 - 6609*I/1024, 119/128 + 143*I/128, -10879/2048 + 4343*I/4096], +# [ 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64, -9/32 - I/16, 183/256 - 97*I/128], +# [ 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512, -219/128 + 115*I/256, 6301/4096 - 6609*I/1024], +# [ -2, 17/4 - 13*I/2, 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64], +# [ 1/4 + 13*I/4, -825/64 - 147*I/32, 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512], +# [ -4*I, 27/2 + 6*I, -2, 17/4 - 13*I/2, 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64], +# [ 1/4 + 5*I/2, -23/8 - 57*I/16, 1/4 + 13*I/4, -825/64 - 147*I/32, 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128], +# [ -4, 9 - 5*I, -4*I, 27/2 + 6*I, -2, 17/4 - 13*I/2, 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16], +# [ -2*I, 119/8 + 29*I/4, 1/4 + 5*I/2, -23/8 - 57*I/16, 1/4 + 13*I/4, -825/64 - 147*I/32, 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128]]''')) +# assert M**10 == Matrix([ +# [ 7*(-221393644768594642173548179825793834595 - 1861633166167425978847110897013541127952*I)/9671406556917033397649408, 15*(31670992489131684885307005100073928751695 + 10329090958303458811115024718207404523808*I)/77371252455336267181195264, 7*(-3710978679372178839237291049477017392703 + 1377706064483132637295566581525806894169*I)/19342813113834066795298816, (9727707023582419994616144751727760051598 - 59261571067013123836477348473611225724433*I)/9671406556917033397649408, (31896723509506857062605551443641668183707 + 54643444538699269118869436271152084599580*I)/38685626227668133590597632, (-2024044860947539028275487595741003997397402 + 130959428791783397562960461903698670485863*I)/309485009821345068724781056, 3*(26190251453797590396533756519358368860907 - 27221191754180839338002754608545400941638*I)/77371252455336267181195264, (1154643595139959842768960128434994698330461 + 3385496216250226964322872072260446072295634*I)/618970019642690137449562112, 3*(-31849347263064464698310044805285774295286 - 11877437776464148281991240541742691164309*I)/77371252455336267181195264, (4661330392283532534549306589669150228040221 - 4171259766019818631067810706563064103956871*I)/1237940039285380274899124224, (9598353794289061833850770474812760144506 + 358027153990999990968244906482319780943983*I)/309485009821345068724781056, (-9755135335127734571547571921702373498554177 - 4837981372692695195747379349593041939686540*I)/2475880078570760549798248448], +# [(-379516731607474268954110071392894274962069 - 422272153179747548473724096872271700878296*I)/77371252455336267181195264, (41324748029613152354787280677832014263339501 - 12715121258662668420833935373453570749288074*I)/1237940039285380274899124224, (-339216903907423793947110742819264306542397 + 494174755147303922029979279454787373566517*I)/77371252455336267181195264, (-18121350839962855576667529908850640619878381 - 37413012454129786092962531597292531089199003*I)/1237940039285380274899124224, (2489661087330511608618880408199633556675926 + 1137821536550153872137379935240732287260863*I)/309485009821345068724781056, (-136644109701594123227587016790354220062972119 + 110130123468183660555391413889600443583585272*I)/4951760157141521099596496896, (1488043981274920070468141664150073426459593 - 9691968079933445130866371609614474474327650*I)/1237940039285380274899124224, 27*(4636797403026872518131756991410164760195942 + 3369103221138229204457272860484005850416533*I)/4951760157141521099596496896, (-8534279107365915284081669381642269800472363 + 2241118846262661434336333368511372725482742*I)/1237940039285380274899124224, (60923350128174260992536531692058086830950875 - 263673488093551053385865699805250505661590126*I)/9903520314283042199192993792, (18520943561240714459282253753348921824172569 + 24846649186468656345966986622110971925703604*I)/4951760157141521099596496896, (-232781130692604829085973604213529649638644431 + 35981505277760667933017117949103953338570617*I)/9903520314283042199192993792], +# [ (8742968295129404279528270438201520488950 + 3061473358639249112126847237482570858327*I)/4835703278458516698824704, (-245657313712011778432792959787098074935273 + 253113767861878869678042729088355086740856*I)/38685626227668133590597632, (1947031161734702327107371192008011621193 - 19462330079296259148177542369999791122762*I)/9671406556917033397649408, (552856485625209001527688949522750288619217 + 392928441196156725372494335248099016686580*I)/77371252455336267181195264, (-44542866621905323121630214897126343414629 + 3265340021421335059323962377647649632959*I)/19342813113834066795298816, (136272594005759723105646069956434264218730 - 330975364731707309489523680957584684763587*I)/38685626227668133590597632, (27392593965554149283318732469825168894401 + 75157071243800133880129376047131061115278*I)/38685626227668133590597632, 7*(-357821652913266734749960136017214096276154 - 45509144466378076475315751988405961498243*I)/309485009821345068724781056, (104485001373574280824835174390219397141149 - 99041000529599568255829489765415726168162*I)/77371252455336267181195264, (1198066993119982409323525798509037696321291 + 4249784165667887866939369628840569844519936*I)/618970019642690137449562112, (-114985392587849953209115599084503853611014 - 52510376847189529234864487459476242883449*I)/77371252455336267181195264, (6094620517051332877965959223269600650951573 - 4683469779240530439185019982269137976201163*I)/1237940039285380274899124224], +# [ (611292255597977285752123848828590587708323 - 216821743518546668382662964473055912169502*I)/77371252455336267181195264, (-1144023204575811464652692396337616594307487 + 12295317806312398617498029126807758490062855*I)/309485009821345068724781056, (-374093027769390002505693378578475235158281 - 573533923565898290299607461660384634333639*I)/77371252455336267181195264, (47405570632186659000138546955372796986832987 - 2837476058950808941605000274055970055096534*I)/1237940039285380274899124224, (-571573207393621076306216726219753090535121 + 533381457185823100878764749236639320783831*I)/77371252455336267181195264, (-7096548151856165056213543560958582513797519 - 24035731898756040059329175131592138642195366*I)/618970019642690137449562112, (2396762128833271142000266170154694033849225 + 1448501087375679588770230529017516492953051*I)/309485009821345068724781056, (-150609293845161968447166237242456473262037053 + 92581148080922977153207018003184520294188436*I)/4951760157141521099596496896, 5*(270278244730804315149356082977618054486347 - 1997830155222496880429743815321662710091562*I)/1237940039285380274899124224, (62978424789588828258068912690172109324360330 + 44803641177219298311493356929537007630129097*I)/2475880078570760549798248448, 19*(-451431106327656743945775812536216598712236 + 114924966793632084379437683991151177407937*I)/1237940039285380274899124224, (63417747628891221594106738815256002143915995 - 261508229397507037136324178612212080871150958*I)/9903520314283042199192993792], +# [ (-2144231934021288786200752920446633703357 + 2305614436009705803670842248131563850246*I)/1208925819614629174706176, (-90720949337459896266067589013987007078153 - 221951119475096403601562347412753844534569*I)/19342813113834066795298816, (11590973613116630788176337262688659880376 + 6514520676308992726483494976339330626159*I)/4835703278458516698824704, 3*(-131776217149000326618649542018343107657237 + 79095042939612668486212006406818285287004*I)/38685626227668133590597632, (10100577916793945997239221374025741184951 - 28631383488085522003281589065994018550748*I)/9671406556917033397649408, 67*(10090295594251078955008130473573667572549 + 10449901522697161049513326446427839676762*I)/77371252455336267181195264, (-54270981296988368730689531355811033930513 - 3413683117592637309471893510944045467443*I)/19342813113834066795298816, (440372322928679910536575560069973699181278 - 736603803202303189048085196176918214409081*I)/77371252455336267181195264, (33220374714789391132887731139763250155295 + 92055083048787219934030779066298919603554*I)/38685626227668133590597632, 5*(-594638554579967244348856981610805281527116 - 82309245323128933521987392165716076704057*I)/309485009821345068724781056, (128056368815300084550013708313312073721955 - 114619107488668120303579745393765245911404*I)/77371252455336267181195264, 21*(59839959255173222962789517794121843393573 + 241507883613676387255359616163487405826334*I)/618970019642690137449562112], +# [ (-13454485022325376674626653802541391955147 + 184471402121905621396582628515905949793486*I)/19342813113834066795298816, (-6158730123400322562149780662133074862437105 - 3416173052604643794120262081623703514107476*I)/154742504910672534362390528, (770558003844914708453618983120686116100419 - 127758381209767638635199674005029818518766*I)/77371252455336267181195264, (-4693005771813492267479835161596671660631703 + 12703585094750991389845384539501921531449948*I)/309485009821345068724781056, (-295028157441149027913545676461260860036601 - 841544569970643160358138082317324743450770*I)/77371252455336267181195264, (56716442796929448856312202561538574275502893 + 7216818824772560379753073185990186711454778*I)/1237940039285380274899124224, 15*(-87061038932753366532685677510172566368387 + 61306141156647596310941396434445461895538*I)/154742504910672534362390528, (-3455315109680781412178133042301025723909347 - 24969329563196972466388460746447646686670670*I)/618970019642690137449562112, (2453418854160886481106557323699250865361849 + 1497886802326243014471854112161398141242514*I)/309485009821345068724781056, (-151343224544252091980004429001205664193082173 + 90471883264187337053549090899816228846836628*I)/4951760157141521099596496896, (1652018205533026103358164026239417416432989 - 9959733619236515024261775397109724431400162*I)/1237940039285380274899124224, 3*(40676374242956907656984876692623172736522006 + 31023357083037817469535762230872667581366205*I)/4951760157141521099596496896], +# [ (-1226990509403328460274658603410696548387 - 4131739423109992672186585941938392788458*I)/1208925819614629174706176, (162392818524418973411975140074368079662703 + 23706194236915374831230612374344230400704*I)/9671406556917033397649408, (-3935678233089814180000602553655565621193 + 2283744757287145199688061892165659502483*I)/1208925819614629174706176, (-2400210250844254483454290806930306285131 - 315571356806370996069052930302295432758205*I)/19342813113834066795298816, (13365917938215281056563183751673390817910 + 15911483133819801118348625831132324863881*I)/4835703278458516698824704, 3*(-215950551370668982657516660700301003897855 + 51684341999223632631602864028309400489378*I)/38685626227668133590597632, (20886089946811765149439844691320027184765 - 30806277083146786592790625980769214361844*I)/9671406556917033397649408, (562180634592713285745940856221105667874855 + 1031543963988260765153550559766662245114916*I)/77371252455336267181195264, (-65820625814810177122941758625652476012867 - 12429918324787060890804395323920477537595*I)/19342813113834066795298816, (319147848192012911298771180196635859221089 - 402403304933906769233365689834404519960394*I)/38685626227668133590597632, (23035615120921026080284733394359587955057 + 115351677687031786114651452775242461310624*I)/38685626227668133590597632, (-3426830634881892756966440108592579264936130 - 1022954961164128745603407283836365128598559*I)/309485009821345068724781056], +# [ (-192574788060137531023716449082856117537757 - 69222967328876859586831013062387845780692*I)/19342813113834066795298816, (2736383768828013152914815341491629299773262 - 2773252698016291897599353862072533475408743*I)/77371252455336267181195264, (-23280005281223837717773057436155921656805 + 214784953368021840006305033048142888879224*I)/19342813113834066795298816, (-3035247484028969580570400133318947903462326 - 2195168903335435855621328554626336958674325*I)/77371252455336267181195264, (984552428291526892214541708637840971548653 - 64006622534521425620714598573494988589378*I)/77371252455336267181195264, (-3070650452470333005276715136041262898509903 + 7286424705750810474140953092161794621989080*I)/154742504910672534362390528, (-147848877109756404594659513386972921139270 - 416306113044186424749331418059456047650861*I)/38685626227668133590597632, (55272118474097814260289392337160619494260781 + 7494019668394781211907115583302403519488058*I)/1237940039285380274899124224, (-581537886583682322424771088996959213068864 + 542191617758465339135308203815256798407429*I)/77371252455336267181195264, (-6422548983676355789975736799494791970390991 - 23524183982209004826464749309156698827737702*I)/618970019642690137449562112, 7*(180747195387024536886923192475064903482083 + 84352527693562434817771649853047924991804*I)/154742504910672534362390528, (-135485179036717001055310712747643466592387031 + 102346575226653028836678855697782273460527608*I)/4951760157141521099596496896], +# [ (3384238362616083147067025892852431152105 + 156724444932584900214919898954874618256*I)/604462909807314587353088, (-59558300950677430189587207338385764871866 + 114427143574375271097298201388331237478857*I)/4835703278458516698824704, (-1356835789870635633517710130971800616227 - 7023484098542340388800213478357340875410*I)/1208925819614629174706176, (234884918567993750975181728413524549575881 + 79757294640629983786895695752733890213506*I)/9671406556917033397649408, (-7632732774935120473359202657160313866419 + 2905452608512927560554702228553291839465*I)/1208925819614629174706176, (52291747908702842344842889809762246649489 - 520996778817151392090736149644507525892649*I)/19342813113834066795298816, (17472406829219127839967951180375981717322 + 23464704213841582137898905375041819568669*I)/4835703278458516698824704, (-911026971811893092350229536132730760943307 + 150799318130900944080399439626714846752360*I)/38685626227668133590597632, (26234457233977042811089020440646443590687 - 45650293039576452023692126463683727692890*I)/9671406556917033397649408, 3*(288348388717468992528382586652654351121357 + 454526517721403048270274049572136109264668*I)/77371252455336267181195264, (-91583492367747094223295011999405657956347 - 12704691128268298435362255538069612411331*I)/19342813113834066795298816, (411208730251327843849027957710164064354221 - 569898526380691606955496789378230959965898*I)/38685626227668133590597632], +# [ (27127513117071487872628354831658811211795 - 37765296987901990355760582016892124833857*I)/4835703278458516698824704, (1741779916057680444272938534338833170625435 + 3083041729779495966997526404685535449810378*I)/77371252455336267181195264, 3*(-60642236251815783728374561836962709533401 - 24630301165439580049891518846174101510744*I)/19342813113834066795298816, 3*(445885207364591681637745678755008757483408 - 350948497734812895032502179455610024541643*I)/38685626227668133590597632, (-47373295621391195484367368282471381775684 + 219122969294089357477027867028071400054973*I)/19342813113834066795298816, (-2801565819673198722993348253876353741520438 - 2250142129822658548391697042460298703335701*I)/77371252455336267181195264, (801448252275607253266997552356128790317119 - 50890367688077858227059515894356594900558*I)/77371252455336267181195264, (-5082187758525931944557763799137987573501207 + 11610432359082071866576699236013484487676124*I)/309485009821345068724781056, (-328925127096560623794883760398247685166830 - 643447969697471610060622160899409680422019*I)/77371252455336267181195264, 15*(2954944669454003684028194956846659916299765 + 33434406416888505837444969347824812608566*I)/1237940039285380274899124224, (-415749104352001509942256567958449835766827 + 479330966144175743357171151440020955412219*I)/77371252455336267181195264, 3*(-4639987285852134369449873547637372282914255 - 11994411888966030153196659207284951579243273*I)/1237940039285380274899124224], +# [ (-478846096206269117345024348666145495601 + 1249092488629201351470551186322814883283*I)/302231454903657293676544, (-17749319421930878799354766626365926894989 - 18264580106418628161818752318217357231971*I)/1208925819614629174706176, (2801110795431528876849623279389579072819 + 363258850073786330770713557775566973248*I)/604462909807314587353088, (-59053496693129013745775512127095650616252 + 78143588734197260279248498898321500167517*I)/4835703278458516698824704, (-283186724922498212468162690097101115349 - 6443437753863179883794497936345437398276*I)/1208925819614629174706176, (188799118826748909206887165661384998787543 + 84274736720556630026311383931055307398820*I)/9671406556917033397649408, (-5482217151670072904078758141270295025989 + 1818284338672191024475557065444481298568*I)/1208925819614629174706176, (56564463395350195513805521309731217952281 - 360208541416798112109946262159695452898431*I)/19342813113834066795298816, 11*(1259539805728870739006416869463689438068 + 1409136581547898074455004171305324917387*I)/4835703278458516698824704, 5*(-123701190701414554945251071190688818343325 + 30997157322590424677294553832111902279712*I)/38685626227668133590597632, (16130917381301373033736295883982414239781 - 32752041297570919727145380131926943374516*I)/9671406556917033397649408, (650301385108223834347093740500375498354925 + 899526407681131828596801223402866051809258*I)/77371252455336267181195264], +# [ (9011388245256140876590294262420614839483 + 8167917972423946282513000869327525382672*I)/1208925819614629174706176, (-426393174084720190126376382194036323028924 + 180692224825757525982858693158209545430621*I)/9671406556917033397649408, (24588556702197802674765733448108154175535 - 45091766022876486566421953254051868331066*I)/4835703278458516698824704, (1872113939365285277373877183750416985089691 + 3030392393733212574744122057679633775773130*I)/77371252455336267181195264, (-222173405538046189185754954524429864167549 - 75193157893478637039381059488387511299116*I)/19342813113834066795298816, (2670821320766222522963689317316937579844558 - 2645837121493554383087981511645435472169191*I)/77371252455336267181195264, 5*(-2100110309556476773796963197283876204940 + 41957457246479840487980315496957337371937*I)/19342813113834066795298816, (-5733743755499084165382383818991531258980593 - 3328949988392698205198574824396695027195732*I)/154742504910672534362390528, (707827994365259025461378911159398206329247 - 265730616623227695108042528694302299777294*I)/77371252455336267181195264, (-1442501604682933002895864804409322823788319 + 11504137805563265043376405214378288793343879*I)/309485009821345068724781056, (-56130472299445561499538726459719629522285 - 61117552419727805035810982426639329818864*I)/9671406556917033397649408, (39053692321126079849054272431599539429908717 - 10209127700342570953247177602860848130710666*I)/1237940039285380274899124224]]) + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512], + [ 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64], + [ -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128], + [ 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16], + [ 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M**10 == Matrix(S('''[ + [ 7369525394972778926719607798014571861/604462909807314587353088 - 229284202061790301477392339912557559*I/151115727451828646838272, -19704281515163975949388435612632058035/1208925819614629174706176 + 14319858347987648723768698170712102887*I/302231454903657293676544, -3623281909451783042932142262164941211/604462909807314587353088 - 6039240602494288615094338643452320495*I/604462909807314587353088, 109260497799140408739847239685705357695/2417851639229258349412352 - 7427566006564572463236368211555511431*I/2417851639229258349412352, -16095803767674394244695716092817006641/2417851639229258349412352 + 10336681897356760057393429626719177583*I/1208925819614629174706176, -42207883340488041844332828574359769743/2417851639229258349412352 - 182332262671671273188016400290188468499*I/4835703278458516698824704], + [50566491050825573392726324995779608259/1208925819614629174706176 - 90047007594468146222002432884052362145*I/2417851639229258349412352, 74273703462900000967697427843983822011/1208925819614629174706176 + 265947522682943571171988741842776095421*I/1208925819614629174706176, -116900341394390200556829767923360888429/2417851639229258349412352 - 53153263356679268823910621474478756845*I/2417851639229258349412352, 195407378023867871243426523048612490249/1208925819614629174706176 - 1242417915995360200584837585002906728929*I/9671406556917033397649408, -863597594389821970177319682495878193/302231454903657293676544 + 476936100741548328800725360758734300481*I/9671406556917033397649408, -3154451590535653853562472176601754835575/19342813113834066795298816 - 232909875490506237386836489998407329215*I/2417851639229258349412352], + [ -1715444997702484578716037230949868543/302231454903657293676544 + 5009695651321306866158517287924120777*I/302231454903657293676544, -30551582497996879620371947949342101301/604462909807314587353088 - 7632518367986526187139161303331519629*I/151115727451828646838272, 312680739924495153190604170938220575/18889465931478580854784 - 108664334509328818765959789219208459*I/75557863725914323419136, -14693696966703036206178521686918865509/604462909807314587353088 + 72345386220900843930147151999899692401*I/1208925819614629174706176, -8218872496728882299722894680635296519/1208925819614629174706176 - 16776782833358893712645864791807664983*I/1208925819614629174706176, 143237839169380078671242929143670635137/2417851639229258349412352 + 2883817094806115974748882735218469447*I/2417851639229258349412352], + [ 3087979417831061365023111800749855987/151115727451828646838272 + 34441942370802869368851419102423997089*I/604462909807314587353088, -148309181940158040917731426845476175667/604462909807314587353088 - 263987151804109387844966835369350904919*I/9671406556917033397649408, 50259518594816377378747711930008883165/1208925819614629174706176 - 95713974916869240305450001443767979653*I/2417851639229258349412352, 153466447023875527996457943521467271119/2417851639229258349412352 + 517285524891117105834922278517084871349*I/2417851639229258349412352, -29184653615412989036678939366291205575/604462909807314587353088 - 27551322282526322041080173287022121083*I/1208925819614629174706176, 196404220110085511863671393922447671649/1208925819614629174706176 - 1204712019400186021982272049902206202145*I/9671406556917033397649408], + [ -2632581805949645784625606590600098779/151115727451828646838272 - 589957435912868015140272627522612771*I/37778931862957161709568, 26727850893953715274702844733506310247/302231454903657293676544 - 10825791956782128799168209600694020481*I/302231454903657293676544, -1036348763702366164044671908440791295/151115727451828646838272 + 3188624571414467767868303105288107375*I/151115727451828646838272, -36814959939970644875593411585393242449/604462909807314587353088 - 18457555789119782404850043842902832647*I/302231454903657293676544, 12454491297984637815063964572803058647/604462909807314587353088 - 340489532842249733975074349495329171*I/302231454903657293676544, -19547211751145597258386735573258916681/604462909807314587353088 + 87299583775782199663414539883938008933*I/1208925819614629174706176], + [ -40281994229560039213253423262678393183/604462909807314587353088 - 2939986850065527327299273003299736641*I/604462909807314587353088, 331940684638052085845743020267462794181/2417851639229258349412352 - 284574901963624403933361315517248458969*I/1208925819614629174706176, 6453843623051745485064693628073010961/302231454903657293676544 + 36062454107479732681350914931391590957*I/604462909807314587353088, -147665869053634695632880753646441962067/604462909807314587353088 - 305987938660447291246597544085345123927*I/9671406556917033397649408, 107821369195275772166593879711259469423/2417851639229258349412352 - 11645185518211204108659001435013326687*I/302231454903657293676544, 64121228424717666402009446088588091619/1208925819614629174706176 + 265557133337095047883844369272389762133*I/1208925819614629174706176]]''')) + + +def test_issue_17247_expression_blowup_5(): + M = Matrix(6, 6, lambda i, j: 1 + (-1)**(i+j)*I) + with dotprodsimp(True): + assert M.charpoly('x') == PurePoly(x**6 + (-6 - 6*I)*x**5 + 36*I*x**4, x, domain='EX') + + +def test_issue_17247_expression_blowup_6(): + M = Matrix(8, 8, [x+i for i in range (64)]) + with dotprodsimp(True): + assert M.det('bareiss') == 0 + + +def test_issue_17247_expression_blowup_7(): + M = Matrix(6, 6, lambda i, j: 1 + (-1)**(i+j)*I) + with dotprodsimp(True): + assert M.det('berkowitz') == 0 + + +def test_issue_17247_expression_blowup_8(): + M = Matrix(8, 8, [x+i for i in range (64)]) + with dotprodsimp(True): + assert M.det('lu') == 0 + + +def test_issue_17247_expression_blowup_9(): + M = Matrix(8, 8, [x+i for i in range (64)]) + with dotprodsimp(True): + assert M.rref() == (Matrix([ + [1, 0, -1, -2, -3, -4, -5, -6], + [0, 1, 2, 3, 4, 5, 6, 7], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0]]), (0, 1)) + + +def test_issue_17247_expression_blowup_10(): + M = Matrix(6, 6, lambda i, j: 1 + (-1)**(i+j)*I) + with dotprodsimp(True): + assert M.cofactor(0, 0) == 0 + + +def test_issue_17247_expression_blowup_11(): + M = Matrix(6, 6, lambda i, j: 1 + (-1)**(i+j)*I) + with dotprodsimp(True): + assert M.cofactor_matrix() == Matrix(6, 6, [0]*36) + + +def test_issue_17247_expression_blowup_12(): + M = Matrix(6, 6, lambda i, j: 1 + (-1)**(i+j)*I) + with dotprodsimp(True): + assert M.eigenvals() == {6: 1, 6*I: 1, 0: 4} + + +def test_issue_17247_expression_blowup_13(): + M = Matrix([ + [ 0, 1 - x, x + 1, 1 - x], + [1 - x, x + 1, 0, x + 1], + [ 0, 1 - x, x + 1, 1 - x], + [ 0, 0, 1 - x, 0]]) + + ev = M.eigenvects() + assert ev[0] == (0, 2, [Matrix([0, -1, 0, 1])]) + assert ev[1][0] == x - sqrt(2)*(x - 1) + 1 + assert ev[1][1] == 1 + assert ev[1][2][0].expand(deep=False, numer=True) == Matrix([ + [(-x + sqrt(2)*(x - 1) - 1)/(x - 1)], + [-4*x/(x**2 - 2*x + 1) + (x + 1)*(x - sqrt(2)*(x - 1) + 1)/(x**2 - 2*x + 1)], + [(-x + sqrt(2)*(x - 1) - 1)/(x - 1)], + [1] + ]) + + assert ev[2][0] == x + sqrt(2)*(x - 1) + 1 + assert ev[2][1] == 1 + assert ev[2][2][0].expand(deep=False, numer=True) == Matrix([ + [(-x - sqrt(2)*(x - 1) - 1)/(x - 1)], + [-4*x/(x**2 - 2*x + 1) + (x + 1)*(x + sqrt(2)*(x - 1) + 1)/(x**2 - 2*x + 1)], + [(-x - sqrt(2)*(x - 1) - 1)/(x - 1)], + [1] + ]) + + +def test_issue_17247_expression_blowup_14(): + M = Matrix(8, 8, ([1+x, 1-x]*4 + [1-x, 1+x]*4)*4) + with dotprodsimp(True): + assert M.echelon_form() == Matrix([ + [x + 1, 1 - x, x + 1, 1 - x, x + 1, 1 - x, x + 1, 1 - x], + [ 0, 4*x, 0, 4*x, 0, 4*x, 0, 4*x], + [ 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0]]) + + +def test_issue_17247_expression_blowup_15(): + M = Matrix(8, 8, ([1+x, 1-x]*4 + [1-x, 1+x]*4)*4) + with dotprodsimp(True): + assert M.rowspace() == [Matrix([[x + 1, 1 - x, x + 1, 1 - x, x + 1, 1 - x, x + 1, 1 - x]]), Matrix([[0, 4*x, 0, 4*x, 0, 4*x, 0, 4*x]])] + + +def test_issue_17247_expression_blowup_16(): + M = Matrix(8, 8, ([1+x, 1-x]*4 + [1-x, 1+x]*4)*4) + with dotprodsimp(True): + assert M.columnspace() == [Matrix([[x + 1],[1 - x],[x + 1],[1 - x],[x + 1],[1 - x],[x + 1],[1 - x]]), Matrix([[1 - x],[x + 1],[1 - x],[x + 1],[1 - x],[x + 1],[1 - x],[x + 1]])] + + +def test_issue_17247_expression_blowup_17(): + M = Matrix(8, 8, [x+i for i in range (64)]) + with dotprodsimp(True): + assert M.nullspace() == [ + Matrix([[1],[-2],[1],[0],[0],[0],[0],[0]]), + Matrix([[2],[-3],[0],[1],[0],[0],[0],[0]]), + Matrix([[3],[-4],[0],[0],[1],[0],[0],[0]]), + Matrix([[4],[-5],[0],[0],[0],[1],[0],[0]]), + Matrix([[5],[-6],[0],[0],[0],[0],[1],[0]]), + Matrix([[6],[-7],[0],[0],[0],[0],[0],[1]])] + + +def test_issue_17247_expression_blowup_18(): + M = Matrix(6, 6, ([1+x, 1-x]*3 + [1-x, 1+x]*3)*3) + with dotprodsimp(True): + assert not M.is_nilpotent() + + +def test_issue_17247_expression_blowup_19(): + M = Matrix(S('''[ + [ -3/4, 0, 1/4 + I/2, 0], + [ 0, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 1/2 - I, 0, 0, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert not M.is_diagonalizable() + + +def test_issue_17247_expression_blowup_20(): + M = Matrix([ + [x + 1, 1 - x, 0, 0], + [1 - x, x + 1, 0, x + 1], + [ 0, 1 - x, x + 1, 0], + [ 0, 0, 0, x + 1]]) + with dotprodsimp(True): + assert M.diagonalize() == (Matrix([ + [1, 1, 0, (x + 1)/(x - 1)], + [1, -1, 0, 0], + [1, 1, 1, 0], + [0, 0, 0, 1]]), + Matrix([ + [2, 0, 0, 0], + [0, 2*x, 0, 0], + [0, 0, x + 1, 0], + [0, 0, 0, x + 1]])) + + +def test_issue_17247_expression_blowup_21(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.inv(method='GE') == Matrix(S('''[ + [-26194832/3470993 - 31733264*I/3470993, 156352/3470993 + 10325632*I/3470993, 0, -7741283181072/3306971225785 + 2999007604624*I/3306971225785], + [4408224/3470993 - 9675328*I/3470993, -2422272/3470993 + 1523712*I/3470993, 0, -1824666489984/3306971225785 - 1401091949952*I/3306971225785], + [-26406945676288/22270005630769 + 10245925485056*I/22270005630769, 7453523312640/22270005630769 + 1601616519168*I/22270005630769, 633088/6416033 - 140288*I/6416033, 872209227109521408/21217636514687010905 + 6066405081802389504*I/21217636514687010905], + [0, 0, 0, -11328/952745 + 87616*I/952745]]''')) + + +def test_issue_17247_expression_blowup_22(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.inv(method='LU') == Matrix(S('''[ + [-26194832/3470993 - 31733264*I/3470993, 156352/3470993 + 10325632*I/3470993, 0, -7741283181072/3306971225785 + 2999007604624*I/3306971225785], + [4408224/3470993 - 9675328*I/3470993, -2422272/3470993 + 1523712*I/3470993, 0, -1824666489984/3306971225785 - 1401091949952*I/3306971225785], + [-26406945676288/22270005630769 + 10245925485056*I/22270005630769, 7453523312640/22270005630769 + 1601616519168*I/22270005630769, 633088/6416033 - 140288*I/6416033, 872209227109521408/21217636514687010905 + 6066405081802389504*I/21217636514687010905], + [0, 0, 0, -11328/952745 + 87616*I/952745]]''')) + + +def test_issue_17247_expression_blowup_23(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.inv(method='ADJ').expand() == Matrix(S('''[ + [-26194832/3470993 - 31733264*I/3470993, 156352/3470993 + 10325632*I/3470993, 0, -7741283181072/3306971225785 + 2999007604624*I/3306971225785], + [4408224/3470993 - 9675328*I/3470993, -2422272/3470993 + 1523712*I/3470993, 0, -1824666489984/3306971225785 - 1401091949952*I/3306971225785], + [-26406945676288/22270005630769 + 10245925485056*I/22270005630769, 7453523312640/22270005630769 + 1601616519168*I/22270005630769, 633088/6416033 - 140288*I/6416033, 872209227109521408/21217636514687010905 + 6066405081802389504*I/21217636514687010905], + [0, 0, 0, -11328/952745 + 87616*I/952745]]''')) + + +def test_issue_17247_expression_blowup_24(): + M = SparseMatrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.inv(method='CH') == Matrix(S('''[ + [-26194832/3470993 - 31733264*I/3470993, 156352/3470993 + 10325632*I/3470993, 0, -7741283181072/3306971225785 + 2999007604624*I/3306971225785], + [4408224/3470993 - 9675328*I/3470993, -2422272/3470993 + 1523712*I/3470993, 0, -1824666489984/3306971225785 - 1401091949952*I/3306971225785], + [-26406945676288/22270005630769 + 10245925485056*I/22270005630769, 7453523312640/22270005630769 + 1601616519168*I/22270005630769, 633088/6416033 - 140288*I/6416033, 872209227109521408/21217636514687010905 + 6066405081802389504*I/21217636514687010905], + [0, 0, 0, -11328/952745 + 87616*I/952745]]''')) + + +def test_issue_17247_expression_blowup_25(): + M = SparseMatrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.inv(method='LDL') == Matrix(S('''[ + [-26194832/3470993 - 31733264*I/3470993, 156352/3470993 + 10325632*I/3470993, 0, -7741283181072/3306971225785 + 2999007604624*I/3306971225785], + [4408224/3470993 - 9675328*I/3470993, -2422272/3470993 + 1523712*I/3470993, 0, -1824666489984/3306971225785 - 1401091949952*I/3306971225785], + [-26406945676288/22270005630769 + 10245925485056*I/22270005630769, 7453523312640/22270005630769 + 1601616519168*I/22270005630769, 633088/6416033 - 140288*I/6416033, 872209227109521408/21217636514687010905 + 6066405081802389504*I/21217636514687010905], + [0, 0, 0, -11328/952745 + 87616*I/952745]]''')) + + +def test_issue_17247_expression_blowup_26(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64, -9/32 - I/16, 183/256 - 97*I/128], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512, -219/128 + 115*I/256, 6301/4096 - 6609*I/1024], + [ 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64], + [ -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512], + [ 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64], + [ 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128], + [ -2, 17/4 - 13*I/2, 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16], + [ 1/4 + 13*I/4, -825/64 - 147*I/32, 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.rank() == 4 + + +def test_issue_17247_expression_blowup_27(): + M = Matrix([ + [ 0, 1 - x, x + 1, 1 - x], + [1 - x, x + 1, 0, x + 1], + [ 0, 1 - x, x + 1, 1 - x], + [ 0, 0, 1 - x, 0]]) + with dotprodsimp(True): + P, J = M.jordan_form() + assert P.expand() == Matrix(S('''[ + [ 0, 4*x/(x**2 - 2*x + 1), -(-17*x**4 + 12*sqrt(2)*x**4 - 4*sqrt(2)*x**3 + 6*x**3 - 6*x - 4*sqrt(2)*x + 12*sqrt(2) + 17)/(-7*x**4 + 5*sqrt(2)*x**4 - 6*sqrt(2)*x**3 + 8*x**3 - 2*x**2 + 8*x + 6*sqrt(2)*x - 5*sqrt(2) - 7), -(12*sqrt(2)*x**4 + 17*x**4 - 6*x**3 - 4*sqrt(2)*x**3 - 4*sqrt(2)*x + 6*x - 17 + 12*sqrt(2))/(7*x**4 + 5*sqrt(2)*x**4 - 6*sqrt(2)*x**3 - 8*x**3 + 2*x**2 - 8*x + 6*sqrt(2)*x - 5*sqrt(2) + 7)], + [x - 1, x/(x - 1) + 1/(x - 1), (-7*x**3 + 5*sqrt(2)*x**3 - x**2 + sqrt(2)*x**2 - sqrt(2)*x - x - 5*sqrt(2) - 7)/(-3*x**3 + 2*sqrt(2)*x**3 - 2*sqrt(2)*x**2 + 3*x**2 + 2*sqrt(2)*x + 3*x - 3 - 2*sqrt(2)), (7*x**3 + 5*sqrt(2)*x**3 + x**2 + sqrt(2)*x**2 - sqrt(2)*x + x - 5*sqrt(2) + 7)/(2*sqrt(2)*x**3 + 3*x**3 - 3*x**2 - 2*sqrt(2)*x**2 - 3*x + 2*sqrt(2)*x - 2*sqrt(2) + 3)], + [ 0, 1, -(-3*x**2 + 2*sqrt(2)*x**2 + 2*x - 3 - 2*sqrt(2))/(-x**2 + sqrt(2)*x**2 - 2*sqrt(2)*x + 1 + sqrt(2)), -(2*sqrt(2)*x**2 + 3*x**2 - 2*x - 2*sqrt(2) + 3)/(x**2 + sqrt(2)*x**2 - 2*sqrt(2)*x - 1 + sqrt(2))], + [1 - x, 0, 1, 1]]''')).expand() + assert J == Matrix(S('''[ + [0, 1, 0, 0], + [0, 0, 0, 0], + [0, 0, x - sqrt(2)*(x - 1) + 1, 0], + [0, 0, 0, x + sqrt(2)*(x - 1) + 1]]''')) + + +def test_issue_17247_expression_blowup_28(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.singular_values() == S('''[ + sqrt(14609315/131072 + sqrt(64789115132571/2147483648 - 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3) + 76627253330829751075/(35184372088832*sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))) - 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)))/2 + sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))/2), + sqrt(14609315/131072 - sqrt(64789115132571/2147483648 - 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3) + 76627253330829751075/(35184372088832*sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))) - 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)))/2 + sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))/2), + sqrt(14609315/131072 - sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))/2 + sqrt(64789115132571/2147483648 - 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3) - 76627253330829751075/(35184372088832*sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))) - 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)))/2), + sqrt(14609315/131072 - sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))/2 - sqrt(64789115132571/2147483648 - 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3) - 76627253330829751075/(35184372088832*sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))) - 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)))/2)]''') + + +def test_issue_16823(): + # This still needs to be fixed if not using dotprodsimp. + M = Matrix(S('''[ + [1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I,1/4+1/2*I,-129/64-9/64*I,1/4-5/16*I,65/128+87/64*I,-9/32-1/16*I,183/256-97/128*I,3/64+13/64*I,-23/32-59/256*I,15/128-3/32*I,19/256+551/1024*I], + [21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I,125/64+87/64*I,-2063/256+541/128*I,85/256-33/16*I,805/128+2415/512*I,-219/128+115/256*I,6301/4096-6609/1024*I,119/128+143/128*I,-10879/2048+4343/4096*I,129/256-549/512*I,42533/16384+29103/8192*I], + [-2,17/4-13/2*I,1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I,1/4+1/2*I,-129/64-9/64*I,1/4-5/16*I,65/128+87/64*I,-9/32-1/16*I,183/256-97/128*I,3/64+13/64*I,-23/32-59/256*I], + [1/4+13/4*I,-825/64-147/32*I,21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I,125/64+87/64*I,-2063/256+541/128*I,85/256-33/16*I,805/128+2415/512*I,-219/128+115/256*I,6301/4096-6609/1024*I,119/128+143/128*I,-10879/2048+4343/4096*I], + [-4*I,27/2+6*I,-2,17/4-13/2*I,1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I,1/4+1/2*I,-129/64-9/64*I,1/4-5/16*I,65/128+87/64*I,-9/32-1/16*I,183/256-97/128*I], + [1/4+5/2*I,-23/8-57/16*I,1/4+13/4*I,-825/64-147/32*I,21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I,125/64+87/64*I,-2063/256+541/128*I,85/256-33/16*I,805/128+2415/512*I,-219/128+115/256*I,6301/4096-6609/1024*I], + [-4,9-5*I,-4*I,27/2+6*I,-2,17/4-13/2*I,1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I,1/4+1/2*I,-129/64-9/64*I,1/4-5/16*I,65/128+87/64*I], + [-2*I,119/8+29/4*I,1/4+5/2*I,-23/8-57/16*I,1/4+13/4*I,-825/64-147/32*I,21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I,125/64+87/64*I,-2063/256+541/128*I,85/256-33/16*I,805/128+2415/512*I], + [0,-6,-4,9-5*I,-4*I,27/2+6*I,-2,17/4-13/2*I,1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I,1/4+1/2*I,-129/64-9/64*I], + [1,-9/4+3*I,-2*I,119/8+29/4*I,1/4+5/2*I,-23/8-57/16*I,1/4+13/4*I,-825/64-147/32*I,21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I,125/64+87/64*I,-2063/256+541/128*I], + [0,-4*I,0,-6,-4,9-5*I,-4*I,27/2+6*I,-2,17/4-13/2*I,1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I], + [0,1/4+1/2*I,1,-9/4+3*I,-2*I,119/8+29/4*I,1/4+5/2*I,-23/8-57/16*I,1/4+13/4*I,-825/64-147/32*I,21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I]]''')) + with dotprodsimp(True): + assert M.rank() == 8 + + +def test_issue_18531(): + # solve_linear_system still needs fixing but the rref works. + M = Matrix([ + [1, 1, 1, 1, 1, 0, 1, 0, 0], + [1 + sqrt(2), -1 + sqrt(2), 1 - sqrt(2), -sqrt(2) - 1, 1, 1, -1, 1, 1], + [-5 + 2*sqrt(2), -5 - 2*sqrt(2), -5 - 2*sqrt(2), -5 + 2*sqrt(2), -7, 2, -7, -2, 0], + [-3*sqrt(2) - 1, 1 - 3*sqrt(2), -1 + 3*sqrt(2), 1 + 3*sqrt(2), -7, -5, 7, -5, 3], + [7 - 4*sqrt(2), 4*sqrt(2) + 7, 4*sqrt(2) + 7, 7 - 4*sqrt(2), 7, -12, 7, 12, 0], + [-1 + 3*sqrt(2), 1 + 3*sqrt(2), -3*sqrt(2) - 1, 1 - 3*sqrt(2), 7, -5, -7, -5, 3], + [-3 + 2*sqrt(2), -3 - 2*sqrt(2), -3 - 2*sqrt(2), -3 + 2*sqrt(2), -1, 2, -1, -2, 0], + [1 - sqrt(2), -sqrt(2) - 1, 1 + sqrt(2), -1 + sqrt(2), -1, 1, 1, 1, 1] + ]) + with dotprodsimp(True): + assert M.rref() == (Matrix([ + [1, 0, 0, 0, 0, 0, 0, 0, S(1)/2], + [0, 1, 0, 0, 0, 0, 0, 0, -S(1)/2], + [0, 0, 1, 0, 0, 0, 0, 0, S(1)/2], + [0, 0, 0, 1, 0, 0, 0, 0, -S(1)/2], + [0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, -S(1)/2], + [0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, -S(1)/2]]), (0, 1, 2, 3, 4, 5, 6, 7)) + + +def test_creation(): + raises(ValueError, lambda: Matrix(5, 5, range(20))) + raises(ValueError, lambda: Matrix(5, -1, [])) + raises(IndexError, lambda: Matrix((1, 2))[2]) + with raises(IndexError): + Matrix((1, 2))[3] = 5 + + assert Matrix() == Matrix([]) == Matrix(0, 0, []) + assert Matrix([[]]) == Matrix(1, 0, []) + assert Matrix([[], []]) == Matrix(2, 0, []) + + # anything used to be allowed in a matrix + with warns_deprecated_sympy(): + assert Matrix([[[1], (2,)]]).tolist() == [[[1], (2,)]] + with warns_deprecated_sympy(): + assert Matrix([[[1], (2,)]]).T.tolist() == [[[1]], [(2,)]] + M = Matrix([[0]]) + with warns_deprecated_sympy(): + M[0, 0] = S.EmptySet + + a = Matrix([[x, 0], [0, 0]]) + m = a + assert m.cols == m.rows + assert m.cols == 2 + assert m[:] == [x, 0, 0, 0] + + b = Matrix(2, 2, [x, 0, 0, 0]) + m = b + assert m.cols == m.rows + assert m.cols == 2 + assert m[:] == [x, 0, 0, 0] + + assert a == b + + assert Matrix(b) == b + + c23 = Matrix(2, 3, range(1, 7)) + c13 = Matrix(1, 3, range(7, 10)) + c = Matrix([c23, c13]) + assert c.cols == 3 + assert c.rows == 3 + assert c[:] == [1, 2, 3, 4, 5, 6, 7, 8, 9] + + assert Matrix(eye(2)) == eye(2) + assert ImmutableMatrix(ImmutableMatrix(eye(2))) == ImmutableMatrix(eye(2)) + assert ImmutableMatrix(c) == c.as_immutable() + assert Matrix(ImmutableMatrix(c)) == ImmutableMatrix(c).as_mutable() + + assert c is not Matrix(c) + + dat = [[ones(3,2), ones(3,3)*2], [ones(2,3)*3, ones(2,2)*4]] + M = Matrix(dat) + assert M == Matrix([ + [1, 1, 2, 2, 2], + [1, 1, 2, 2, 2], + [1, 1, 2, 2, 2], + [3, 3, 3, 4, 4], + [3, 3, 3, 4, 4]]) + assert M.tolist() != dat + # keep block form if evaluate=False + assert Matrix(dat, evaluate=False).tolist() == dat + A = MatrixSymbol("A", 2, 2) + dat = [ones(2), A] + assert Matrix(dat) == Matrix([ + [ 1, 1], + [ 1, 1], + [A[0, 0], A[0, 1]], + [A[1, 0], A[1, 1]]]) + with warns_deprecated_sympy(): + assert Matrix(dat, evaluate=False).tolist() == [[i] for i in dat] + + # 0-dim tolerance + assert Matrix([ones(2), ones(0)]) == Matrix([ones(2)]) + raises(ValueError, lambda: Matrix([ones(2), ones(0, 3)])) + raises(ValueError, lambda: Matrix([ones(2), ones(3, 0)])) + + # mix of Matrix and iterable + M = Matrix([[1, 2], [3, 4]]) + M2 = Matrix([M, (5, 6)]) + assert M2 == Matrix([[1, 2], [3, 4], [5, 6]]) + + +def test_irregular_block(): + assert Matrix.irregular(3, ones(2,1), ones(3,3)*2, ones(2,2)*3, + ones(1,1)*4, ones(2,2)*5, ones(1,2)*6, ones(1,2)*7) == Matrix([ + [1, 2, 2, 2, 3, 3], + [1, 2, 2, 2, 3, 3], + [4, 2, 2, 2, 5, 5], + [6, 6, 7, 7, 5, 5]]) + + +def test_slicing(): + m0 = eye(4) + assert m0[:3, :3] == eye(3) + assert m0[2:4, 0:2] == zeros(2) + + m1 = Matrix(3, 3, lambda i, j: i + j) + assert m1[0, :] == Matrix(1, 3, (0, 1, 2)) + assert m1[1:3, 1] == Matrix(2, 1, (2, 3)) + + m2 = Matrix([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]) + assert m2[:, -1] == Matrix(4, 1, [3, 7, 11, 15]) + assert m2[-2:, :] == Matrix([[8, 9, 10, 11], [12, 13, 14, 15]]) + + +def test_submatrix_assignment(): + m = zeros(4) + m[2:4, 2:4] = eye(2) + assert m == Matrix(((0, 0, 0, 0), + (0, 0, 0, 0), + (0, 0, 1, 0), + (0, 0, 0, 1))) + m[:2, :2] = eye(2) + assert m == eye(4) + m[:, 0] = Matrix(4, 1, (1, 2, 3, 4)) + assert m == Matrix(((1, 0, 0, 0), + (2, 1, 0, 0), + (3, 0, 1, 0), + (4, 0, 0, 1))) + m[:, :] = zeros(4) + assert m == zeros(4) + m[:, :] = [(1, 2, 3, 4), (5, 6, 7, 8), (9, 10, 11, 12), (13, 14, 15, 16)] + assert m == Matrix(((1, 2, 3, 4), + (5, 6, 7, 8), + (9, 10, 11, 12), + (13, 14, 15, 16))) + m[:2, 0] = [0, 0] + assert m == Matrix(((0, 2, 3, 4), + (0, 6, 7, 8), + (9, 10, 11, 12), + (13, 14, 15, 16))) + + +def test_reshape(): + m0 = eye(3) + assert m0.reshape(1, 9) == Matrix(1, 9, (1, 0, 0, 0, 1, 0, 0, 0, 1)) + m1 = Matrix(3, 4, lambda i, j: i + j) + assert m1.reshape( + 4, 3) == Matrix(((0, 1, 2), (3, 1, 2), (3, 4, 2), (3, 4, 5))) + assert m1.reshape(2, 6) == Matrix(((0, 1, 2, 3, 1, 2), (3, 4, 2, 3, 4, 5))) + + +def test_applyfunc(): + m0 = eye(3) + assert m0.applyfunc(lambda x: 2*x) == eye(3)*2 + assert m0.applyfunc(lambda x: 0) == zeros(3) + + +def test_expand(): + m0 = Matrix([[x*(x + y), 2], [((x + y)*y)*x, x*(y + x*(x + y))]]) + # Test if expand() returns a matrix + m1 = m0.expand() + assert m1 == Matrix( + [[x*y + x**2, 2], [x*y**2 + y*x**2, x*y + y*x**2 + x**3]]) + + a = Symbol('a', real=True) + + assert Matrix([exp(I*a)]).expand(complex=True) == \ + Matrix([cos(a) + I*sin(a)]) + + assert Matrix([[0, 1, 2], [0, 0, -1], [0, 0, 0]]).exp() == Matrix([ + [1, 1, Rational(3, 2)], + [0, 1, -1], + [0, 0, 1]] + ) + + +def test_refine(): + m0 = Matrix([[Abs(x)**2, sqrt(x**2)], + [sqrt(x**2)*Abs(y)**2, sqrt(y**2)*Abs(x)**2]]) + m1 = m0.refine(Q.real(x) & Q.real(y)) + assert m1 == Matrix([[x**2, Abs(x)], [y**2*Abs(x), x**2*Abs(y)]]) + + m1 = m0.refine(Q.positive(x) & Q.positive(y)) + assert m1 == Matrix([[x**2, x], [x*y**2, x**2*y]]) + + m1 = m0.refine(Q.negative(x) & Q.negative(y)) + assert m1 == Matrix([[x**2, -x], [-x*y**2, -x**2*y]]) + + +def test_random(): + M = randMatrix(3, 3) + M = randMatrix(3, 3, seed=3) + assert M == randMatrix(3, 3, seed=3) + + M = randMatrix(3, 4, 0, 150) + M = randMatrix(3, seed=4, symmetric=True) + assert M == randMatrix(3, seed=4, symmetric=True) + + S = M.copy() + S.simplify() + assert S == M # doesn't fail when elements are Numbers, not int + + rng = random.Random(4) + assert M == randMatrix(3, symmetric=True, prng=rng) + + # Ensure symmetry + for size in (10, 11): # Test odd and even + for percent in (100, 70, 30): + M = randMatrix(size, symmetric=True, percent=percent, prng=rng) + assert M == M.T + + M = randMatrix(10, min=1, percent=70) + zero_count = 0 + for i in range(M.shape[0]): + for j in range(M.shape[1]): + if M[i, j] == 0: + zero_count += 1 + assert zero_count == 30 + + +def test_inverse(): + A = eye(4) + assert A.inv() == eye(4) + assert A.inv(method="LU") == eye(4) + assert A.inv(method="ADJ") == eye(4) + assert A.inv(method="CH") == eye(4) + assert A.inv(method="LDL") == eye(4) + assert A.inv(method="QR") == eye(4) + A = Matrix([[2, 3, 5], + [3, 6, 2], + [8, 3, 6]]) + Ainv = A.inv() + assert A*Ainv == eye(3) + assert A.inv(method="LU") == Ainv + assert A.inv(method="ADJ") == Ainv + assert A.inv(method="CH") == Ainv + assert A.inv(method="LDL") == Ainv + assert A.inv(method="QR") == Ainv + + AA = Matrix([[0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0], + [1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0], + [1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1], + [1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0], + [1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1], + [0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0], + [1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1], + [0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1], + [1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0], + [0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0], + [1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0], + [0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1], + [1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0], + [0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0], + [1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1], + [0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1], + [1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1], + [0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1], + [0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1], + [0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0], + [0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0]]) + assert AA.inv(method="BLOCK") * AA == eye(AA.shape[0]) + # test that immutability is not a problem + cls = ImmutableMatrix + m = cls([[48, 49, 31], + [ 9, 71, 94], + [59, 28, 65]]) + assert all(type(m.inv(s)) is cls for s in 'GE ADJ LU CH LDL QR'.split()) + cls = ImmutableSparseMatrix + m = cls([[48, 49, 31], + [ 9, 71, 94], + [59, 28, 65]]) + assert all(type(m.inv(s)) is cls for s in 'GE ADJ LU CH LDL QR'.split()) + + +def test_inverse_symbolic_float_issue_26821(): + Tau, Tau_syn_in, Tau_syn_ex, C_m, Tau_syn_gap = symbols("Tau Tau_syn_in Tau_syn_ex C_m Tau_syn_gap") + __h = symbols("__h") + + M = Matrix([ + [0,0,0,0,0,(1.0*Tau*__h-1.0*Tau_syn_in*__h)/(2.0*Tau-1.0*Tau_syn_in),-1.0*Tau*Tau_syn_in/(2.0*Tau-1.0*Tau_syn_in)], + [0,0,0,0,0,(-1.0*Tau*__h+1.0*Tau_syn_in*__h)/(2.0*Tau*Tau_syn_in-1.0*Tau_syn_in**2),1.0], + [0,(1.0*Tau*__h-1.0*Tau_syn_ex*__h)/(2.0*Tau-1.0*Tau_syn_ex),-1.0*Tau*Tau_syn_ex/(2.0*Tau-1.0*Tau_syn_ex),0,0,0,0], + [0,(-1.0*Tau*__h+1.0*Tau_syn_ex*__h)/(2.0*Tau*Tau_syn_ex-1.0*Tau_syn_ex**2),1.0,0,0,0,0], + [0,0,0,(1.0*Tau*__h-1.0*Tau_syn_gap*__h)/(2.0*Tau-1.0*Tau_syn_gap),-1.0*Tau*Tau_syn_gap/(2.0*Tau-1.0*Tau_syn_gap),0,0], + [0,0,0,(-1.0*Tau*__h+1.0*Tau_syn_gap*__h)/(2.0*Tau*Tau_syn_gap-1.0*Tau_syn_gap**2),1.0,0,0], + [1.0,-1.0*Tau*Tau_syn_ex*__h/(2.0*C_m*Tau-1.0*C_m*Tau_syn_ex),0,-1.0*Tau*Tau_syn_gap*__h/(2.0*C_m*Tau-1.0*C_m*Tau_syn_gap),0,-1.0*Tau*Tau_syn_in*__h/(2.0*C_m*Tau-1.0*C_m*Tau_syn_in),0] + ]) + + Mi = M.inv() + + assert (M*Mi - eye(7)).applyfunc(cancel) == zeros(7) + + # https://github.com/sympy/sympy/issues/26821 + # Previously very large floats were in the result. + assert max(abs(f) for f in Mi.atoms(Float)) < 1e3 + + +@slow +def test_matrix_exponential_issue_26821(): + # The symbol names matter in the original bug... + a, b, c, d, e = symbols("Tau, Tau_syn_in, Tau_syn_ex, C_m, Tau_syn_gap") + t = symbols("__h") + M = Matrix([ + [ 0, 1.0, 0, 0, 0, 0, 0], + [-1/b**2, -2/b, 0, 0, 0, 0, 0], + [ 0, 0, 0, 1.0, 0, 0, 0], + [ 0, 0, -1/c**2, -2/c, 0, 0, 0], + [ 0, 0, 0, 0, 0, 1, 0], + [ 0, 0, 0, 0, -1/e**2, -2/e, 0], + [ 1/d, 0, 1/d, 0, 1/d, 0, -1/a] + ]) + + Me = (t*M).exp() + assert (Me.diff(t) - M*Me).applyfunc(cancel) == zeros(7) + # https://github.com/sympy/sympy/issues/26821 + # Previously very large floats were in the result. + assert max(abs(f) for f in Me.atoms(Float)) < 1e3 + + +def test_jacobian_hessian(): + L = Matrix(1, 2, [x**2*y, 2*y**2 + x*y]) + syms = [x, y] + assert L.jacobian(syms) == Matrix([[2*x*y, x**2], [y, 4*y + x]]) + + L = Matrix(1, 2, [x, x**2*y**3]) + assert L.jacobian(syms) == Matrix([[1, 0], [2*x*y**3, x**2*3*y**2]]) + + f = x**2*y + syms = [x, y] + assert hessian(f, syms) == Matrix([[2*y, 2*x], [2*x, 0]]) + + f = x**2*y**3 + assert hessian(f, syms) == \ + Matrix([[2*y**3, 6*x*y**2], [6*x*y**2, 6*x**2*y]]) + + f = z + x*y**2 + g = x**2 + 2*y**3 + ans = Matrix([[0, 2*y], + [2*y, 2*x]]) + assert ans == hessian(f, Matrix([x, y])) + assert ans == hessian(f, Matrix([x, y]).T) + assert hessian(f, (y, x), [g]) == Matrix([ + [ 0, 6*y**2, 2*x], + [6*y**2, 2*x, 2*y], + [ 2*x, 2*y, 0]]) + + +def test_wronskian(): + assert wronskian([cos(x), sin(x)], x) == cos(x)**2 + sin(x)**2 + assert wronskian([exp(x), exp(2*x)], x) == exp(3*x) + assert wronskian([exp(x), x], x) == exp(x) - x*exp(x) + assert wronskian([1, x, x**2], x) == 2 + w1 = -6*exp(x)*sin(x)*x + 6*cos(x)*exp(x)*x**2 - 6*exp(x)*cos(x)*x - \ + exp(x)*cos(x)*x**3 + exp(x)*sin(x)*x**3 + assert wronskian([exp(x), cos(x), x**3], x).expand() == w1 + assert wronskian([exp(x), cos(x), x**3], x, method='berkowitz').expand() \ + == w1 + w2 = -x**3*cos(x)**2 - x**3*sin(x)**2 - 6*x*cos(x)**2 - 6*x*sin(x)**2 + assert wronskian([sin(x), cos(x), x**3], x).expand() == w2 + assert wronskian([sin(x), cos(x), x**3], x, method='berkowitz').expand() \ + == w2 + assert wronskian([], x) == 1 + + +def test_xreplace(): + assert Matrix([[1, x], [x, 4]]).xreplace({x: 5}) == \ + Matrix([[1, 5], [5, 4]]) + assert Matrix([[x, 2], [x + y, 4]]).xreplace({x: -1, y: -2}) == \ + Matrix([[-1, 2], [-3, 4]]) + for cls in all_classes: + assert Matrix([[2, 0], [0, 2]]) == cls.eye(2).xreplace({1: 2}) + + +def test_simplify(): + n = Symbol('n') + f = Function('f') + + M = Matrix([[ 1/x + 1/y, (x + x*y) / x ], + [ (f(x) + y*f(x))/f(x), 2 * (1/n - cos(n * pi)/n) / pi ]]) + M.simplify() + assert M == Matrix([[ (x + y)/(x * y), 1 + y ], + [ 1 + y, 2*((1 - 1*cos(pi*n))/(pi*n)) ]]) + eq = (1 + x)**2 + M = Matrix([[eq]]) + M.simplify() + assert M == Matrix([[eq]]) + M.simplify(ratio=oo) + assert M == Matrix([[eq.simplify(ratio=oo)]]) + + n = Symbol('n') + f = Function('f') + + M = ImmutableMatrix([ + [ 1/x + 1/y, (x + x*y) / x ], + [ (f(x) + y*f(x))/f(x), 2 * (1/n - cos(n * pi)/n) / pi ] + ]) + assert M.simplify() == Matrix([ + [ (x + y)/(x * y), 1 + y ], + [ 1 + y, 2*((1 - 1*cos(pi*n))/(pi*n)) ] + ]) + + eq = (1 + x)**2 + M = ImmutableMatrix([[eq]]) + assert M.simplify() == Matrix([[eq]]) + assert M.simplify(ratio=oo) == Matrix([[eq.simplify(ratio=oo)]]) + + assert simplify(ImmutableMatrix([[sin(x)**2 + cos(x)**2]])) == \ + ImmutableMatrix([[1]]) + + # https://github.com/sympy/sympy/issues/19353 + m = Matrix([[30, 2], [3, 4]]) + assert (1/(m.trace())).simplify() == Rational(1, 34) + +def test_transpose(): + M = Matrix([[1, 2, 3, 4, 5, 6, 7, 8, 9, 0], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]]) + assert M.T == Matrix( [ [1, 1], + [2, 2], + [3, 3], + [4, 4], + [5, 5], + [6, 6], + [7, 7], + [8, 8], + [9, 9], + [0, 0] ]) + assert M.T.T == M + assert M.T == M.transpose() + + +def test_conj_dirac(): + raises(AttributeError, lambda: eye(3).D) + + M = Matrix([[1, I, I, I], + [0, 1, I, I], + [0, 0, 1, I], + [0, 0, 0, 1]]) + + assert M.D == Matrix([[ 1, 0, 0, 0], + [-I, 1, 0, 0], + [-I, -I, -1, 0], + [-I, -I, I, -1]]) + + +def test_trace(): + M = Matrix([[1, 0, 0], + [0, 5, 0], + [0, 0, 8]]) + assert M.trace() == 14 + + +def test_shape(): + m = Matrix(1, 2, [0, 0]) + assert m.shape == (1, 2) + M = Matrix([[x, 0, 0], + [0, y, 0]]) + assert M.shape == (2, 3) + + +def test_col_row_op(): + M = Matrix([[x, 0, 0], + [0, y, 0]]) + M.row_op(1, lambda r, j: r + j + 1) + assert M == Matrix([[x, 0, 0], + [1, y + 2, 3]]) + + M.col_op(0, lambda c, j: c + y**j) + assert M == Matrix([[x + 1, 0, 0], + [1 + y, y + 2, 3]]) + + # neither row nor slice give copies that allow the original matrix to + # be changed + assert M.row(0) == Matrix([[x + 1, 0, 0]]) + r1 = M.row(0) + r1[0] = 42 + assert M[0, 0] == x + 1 + r1 = M[0, :-1] # also testing negative slice + r1[0] = 42 + assert M[0, 0] == x + 1 + c1 = M.col(0) + assert c1 == Matrix([x + 1, 1 + y]) + c1[0] = 0 + assert M[0, 0] == x + 1 + c1 = M[:, 0] + c1[0] = 42 + assert M[0, 0] == x + 1 + + +def test_row_mult(): + M = Matrix([[1,2,3], + [4,5,6]]) + M.row_mult(1,3) + assert M[1,0] == 12 + assert M[0,0] == 1 + assert M[1,2] == 18 + + +def test_row_add(): + M = Matrix([[1,2,3], + [4,5,6], + [1,1,1]]) + M.row_add(2,0,5) + assert M[0,0] == 6 + assert M[1,0] == 4 + assert M[0,2] == 8 + + +def test_zip_row_op(): + for cls in mutable_classes: # XXX: immutable matrices don't support row ops + M = cls.eye(3) + M.zip_row_op(1, 0, lambda v, u: v + 2*u) + assert M == cls([[1, 0, 0], + [2, 1, 0], + [0, 0, 1]]) + + M = cls.eye(3)*2 + M[0, 1] = -1 + M.zip_row_op(1, 0, lambda v, u: v + 2*u); M + assert M == cls([[2, -1, 0], + [4, 0, 0], + [0, 0, 2]]) + + +def test_issue_3950(): + m = Matrix([1, 2, 3]) + a = Matrix([1, 2, 3]) + b = Matrix([2, 2, 3]) + assert not (m in []) + assert not (m in [1]) + assert m != 1 + assert m == a + assert m != b + + +def test_issue_3981(): + class Index1: + def __index__(self): + return 1 + + class Index2: + def __index__(self): + return 2 + index1 = Index1() + index2 = Index2() + + m = Matrix([1, 2, 3]) + + assert m[index2] == 3 + + m[index2] = 5 + assert m[2] == 5 + + m = Matrix([[1, 2, 3], [4, 5, 6]]) + assert m[index1, index2] == 6 + assert m[1, index2] == 6 + assert m[index1, 2] == 6 + + m[index1, index2] = 4 + assert m[1, 2] == 4 + m[1, index2] = 6 + assert m[1, 2] == 6 + m[index1, 2] = 8 + assert m[1, 2] == 8 + + +def test_is_upper(): + a = Matrix([[1, 2, 3]]) + assert a.is_upper is True + a = Matrix([[1], [2], [3]]) + assert a.is_upper is False + a = zeros(4, 2) + assert a.is_upper is True + + +def test_is_lower(): + a = Matrix([[1, 2, 3]]) + assert a.is_lower is False + a = Matrix([[1], [2], [3]]) + assert a.is_lower is True + + +def test_is_nilpotent(): + a = Matrix(4, 4, [0, 2, 1, 6, 0, 0, 1, 2, 0, 0, 0, 3, 0, 0, 0, 0]) + assert a.is_nilpotent() + a = Matrix([[1, 0], [0, 1]]) + assert not a.is_nilpotent() + a = Matrix([]) + assert a.is_nilpotent() + + +def test_zeros_ones_fill(): + n, m = 3, 5 + + a = zeros(n, m) + a.fill( 5 ) + + b = 5 * ones(n, m) + + assert a == b + assert a.rows == b.rows == 3 + assert a.cols == b.cols == 5 + assert a.shape == b.shape == (3, 5) + assert zeros(2) == zeros(2, 2) + assert ones(2) == ones(2, 2) + assert zeros(2, 3) == Matrix(2, 3, [0]*6) + assert ones(2, 3) == Matrix(2, 3, [1]*6) + + a.fill(0) + assert a == zeros(n, m) + + +def test_empty_zeros(): + a = zeros(0) + assert a == Matrix() + a = zeros(0, 2) + assert a.rows == 0 + assert a.cols == 2 + a = zeros(2, 0) + assert a.rows == 2 + assert a.cols == 0 + + +def test_issue_3749(): + a = Matrix([[x**2, x*y], [x*sin(y), x*cos(y)]]) + assert a.diff(x) == Matrix([[2*x, y], [sin(y), cos(y)]]) + assert Matrix([ + [x, -x, x**2], + [exp(x), 1/x - exp(-x), x + 1/x]]).limit(x, oo) == \ + Matrix([[oo, -oo, oo], [oo, 0, oo]]) + assert Matrix([ + [(exp(x) - 1)/x, 2*x + y*x, x**x ], + [1/x, abs(x), abs(sin(x + 1))]]).limit(x, 0) == \ + Matrix([[1, 0, 1], [oo, 0, sin(1)]]) + assert a.integrate(x) == Matrix([ + [Rational(1, 3)*x**3, y*x**2/2], + [x**2*sin(y)/2, x**2*cos(y)/2]]) + + +def test_inv_iszerofunc(): + A = eye(4) + A.col_swap(0, 1) + for method in "GE", "LU": + assert A.inv(method=method, iszerofunc=lambda x: x == 0) == \ + A.inv(method="ADJ") + + +def test_jacobian_metrics(): + rho, phi = symbols("rho,phi") + X = Matrix([rho*cos(phi), rho*sin(phi)]) + Y = Matrix([rho, phi]) + J = X.jacobian(Y) + assert J == X.jacobian(Y.T) + assert J == (X.T).jacobian(Y) + assert J == (X.T).jacobian(Y.T) + g = J.T*eye(J.shape[0])*J + g = g.applyfunc(trigsimp) + assert g == Matrix([[1, 0], [0, rho**2]]) + + +def test_jacobian2(): + rho, phi = symbols("rho,phi") + X = Matrix([rho*cos(phi), rho*sin(phi), rho**2]) + Y = Matrix([rho, phi]) + J = Matrix([ + [cos(phi), -rho*sin(phi)], + [sin(phi), rho*cos(phi)], + [ 2*rho, 0], + ]) + assert X.jacobian(Y) == J + + +def test_issue_4564(): + X = Matrix([exp(x + y + z), exp(x + y + z), exp(x + y + z)]) + Y = Matrix([x, y, z]) + for i in range(1, 3): + for j in range(1, 3): + X_slice = X[:i, :] + Y_slice = Y[:j, :] + J = X_slice.jacobian(Y_slice) + assert J.rows == i + assert J.cols == j + for k in range(j): + assert J[:, k] == X_slice + + +def test_nonvectorJacobian(): + X = Matrix([[exp(x + y + z), exp(x + y + z)], + [exp(x + y + z), exp(x + y + z)]]) + raises(TypeError, lambda: X.jacobian(Matrix([x, y, z]))) + X = X[0, :] + Y = Matrix([[x, y], [x, z]]) + raises(TypeError, lambda: X.jacobian(Y)) + raises(TypeError, lambda: X.jacobian(Matrix([ [x, y], [x, z] ]))) + + +def test_vec(): + m = Matrix([[1, 3], [2, 4]]) + m_vec = m.vec() + assert m_vec.cols == 1 + for i in range(4): + assert m_vec[i] == i + 1 + + +def test_vech(): + m = Matrix([[1, 2], [2, 3]]) + m_vech = m.vech() + assert m_vech.cols == 1 + for i in range(3): + assert m_vech[i] == i + 1 + m_vech = m.vech(diagonal=False) + assert m_vech[0] == 2 + + m = Matrix([[1, x*(x + y)], [y*x + x**2, 1]]) + m_vech = m.vech(diagonal=False) + assert m_vech[0] == y*x + x**2 + + m = Matrix([[1, x*(x + y)], [y*x, 1]]) + m_vech = m.vech(diagonal=False, check_symmetry=False) + assert m_vech[0] == y*x + + raises(ShapeError, lambda: Matrix([[1, 3]]).vech()) + raises(ValueError, lambda: Matrix([[1, 3], [2, 4]]).vech()) + raises(ShapeError, lambda: Matrix([[1, 3]]).vech()) + raises(ValueError, lambda: Matrix([[1, 3], [2, 4]]).vech()) + + +def test_diag(): + # mostly tested in testcommonmatrix.py + assert diag([1, 2, 3]) == Matrix([1, 2, 3]) + m = [1, 2, [3]] + raises(ValueError, lambda: diag(m)) + assert diag(m, strict=False) == Matrix([1, 2, 3]) + + +def test_inv_block(): + a = Matrix([[1, 2], [2, 3]]) + b = Matrix([[3, x], [y, 3]]) + c = Matrix([[3, x, 3], [y, 3, z], [x, y, z]]) + A = diag(a, b, b) + assert A.inv(try_block_diag=True) == diag(a.inv(), b.inv(), b.inv()) + A = diag(a, b, c) + assert A.inv(try_block_diag=True) == diag(a.inv(), b.inv(), c.inv()) + A = diag(a, c, b) + assert A.inv(try_block_diag=True) == diag(a.inv(), c.inv(), b.inv()) + A = diag(a, a, b, a, c, a) + assert A.inv(try_block_diag=True) == diag( + a.inv(), a.inv(), b.inv(), a.inv(), c.inv(), a.inv()) + assert A.inv(try_block_diag=True, method="ADJ") == diag( + a.inv(method="ADJ"), a.inv(method="ADJ"), b.inv(method="ADJ"), + a.inv(method="ADJ"), c.inv(method="ADJ"), a.inv(method="ADJ")) + + +def test_creation_args(): + """ + Check that matrix dimensions can be specified using any reasonable type + (see issue 4614). + """ + raises(ValueError, lambda: zeros(3, -1)) + raises(TypeError, lambda: zeros(1, 2, 3, 4)) + assert zeros(int(3)) == zeros(3) + assert zeros(Integer(3)) == zeros(3) + raises(ValueError, lambda: zeros(3.)) + assert eye(int(3)) == eye(3) + assert eye(Integer(3)) == eye(3) + raises(ValueError, lambda: eye(3.)) + assert ones(int(3), Integer(4)) == ones(3, 4) + raises(TypeError, lambda: Matrix(5)) + raises(TypeError, lambda: Matrix(1, 2)) + raises(ValueError, lambda: Matrix([1, [2]])) + + +def test_diagonal_symmetrical(): + m = Matrix(2, 2, [0, 1, 1, 0]) + assert not m.is_diagonal() + assert m.is_symmetric() + assert m.is_symmetric(simplify=False) + + m = Matrix(2, 2, [1, 0, 0, 1]) + assert m.is_diagonal() + + m = diag(1, 2, 3) + assert m.is_diagonal() + assert m.is_symmetric() + + m = Matrix(3, 3, [1, 0, 0, 0, 2, 0, 0, 0, 3]) + assert m == diag(1, 2, 3) + + m = Matrix(2, 3, zeros(2, 3)) + assert not m.is_symmetric() + assert m.is_diagonal() + + m = Matrix(((5, 0), (0, 6), (0, 0))) + assert m.is_diagonal() + + m = Matrix(((5, 0, 0), (0, 6, 0))) + assert m.is_diagonal() + + m = Matrix(3, 3, [1, x**2 + 2*x + 1, y, (x + 1)**2, 2, 0, y, 0, 3]) + assert m.is_symmetric() + assert not m.is_symmetric(simplify=False) + assert m.expand().is_symmetric(simplify=False) + + +def test_diagonalization(): + m = Matrix([[1, 2+I], [2-I, 3]]) + assert m.is_diagonalizable() + + m = Matrix(3, 2, [-3, 1, -3, 20, 3, 10]) + assert not m.is_diagonalizable() + assert not m.is_symmetric() + raises(NonSquareMatrixError, lambda: m.diagonalize()) + + # diagonalizable + m = diag(1, 2, 3) + (P, D) = m.diagonalize() + assert P == eye(3) + assert D == m + + m = Matrix(2, 2, [0, 1, 1, 0]) + assert m.is_symmetric() + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + + m = Matrix(2, 2, [1, 0, 0, 3]) + assert m.is_symmetric() + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + assert P == eye(2) + assert D == m + + m = Matrix(2, 2, [1, 1, 0, 0]) + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + + m = Matrix(3, 3, [1, 2, 0, 0, 3, 0, 2, -4, 2]) + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + for i in P: + assert i.as_numer_denom()[1] == 1 + + m = Matrix(2, 2, [1, 0, 0, 0]) + assert m.is_diagonal() + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + assert P == Matrix([[0, 1], [1, 0]]) + + # diagonalizable, complex only + m = Matrix(2, 2, [0, 1, -1, 0]) + assert not m.is_diagonalizable(True) + raises(MatrixError, lambda: m.diagonalize(True)) + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + + # not diagonalizable + m = Matrix(2, 2, [0, 1, 0, 0]) + assert not m.is_diagonalizable() + raises(MatrixError, lambda: m.diagonalize()) + + m = Matrix(3, 3, [-3, 1, -3, 20, 3, 10, 2, -2, 4]) + assert not m.is_diagonalizable() + raises(MatrixError, lambda: m.diagonalize()) + + # symbolic + a, b, c, d = symbols('a b c d') + m = Matrix(2, 2, [a, c, c, b]) + assert m.is_symmetric() + assert m.is_diagonalizable() + + +def test_issue_15887(): + # Mutable matrix should not use cache + a = MutableDenseMatrix([[0, 1], [1, 0]]) + assert a.is_diagonalizable() is True + a[1, 0] = 0 + assert a.is_diagonalizable() is False + + a = MutableDenseMatrix([[0, 1], [1, 0]]) + a.diagonalize() + a[1, 0] = 0 + raises(MatrixError, lambda: a.diagonalize()) + + +def test_jordan_form(): + + m = Matrix(3, 2, [-3, 1, -3, 20, 3, 10]) + raises(NonSquareMatrixError, lambda: m.jordan_form()) + + # diagonalizable + m = Matrix(3, 3, [7, -12, 6, 10, -19, 10, 12, -24, 13]) + Jmust = Matrix(3, 3, [-1, 0, 0, 0, 1, 0, 0, 0, 1]) + P, J = m.jordan_form() + assert Jmust == J + assert Jmust == m.diagonalize()[1] + + # m = Matrix(3, 3, [0, 6, 3, 1, 3, 1, -2, 2, 1]) + # m.jordan_form() # very long + # m.jordan_form() # + + # diagonalizable, complex only + + # Jordan cells + # complexity: one of eigenvalues is zero + m = Matrix(3, 3, [0, 1, 0, -4, 4, 0, -2, 1, 2]) + # The blocks are ordered according to the value of their eigenvalues, + # in order to make the matrix compatible with .diagonalize() + Jmust = Matrix(3, 3, [2, 1, 0, 0, 2, 0, 0, 0, 2]) + P, J = m.jordan_form() + assert Jmust == J + + # complexity: all of eigenvalues are equal + m = Matrix(3, 3, [2, 6, -15, 1, 1, -5, 1, 2, -6]) + # Jmust = Matrix(3, 3, [-1, 0, 0, 0, -1, 1, 0, 0, -1]) + # same here see 1456ff + Jmust = Matrix(3, 3, [-1, 1, 0, 0, -1, 0, 0, 0, -1]) + P, J = m.jordan_form() + assert Jmust == J + + # complexity: two of eigenvalues are zero + m = Matrix(3, 3, [4, -5, 2, 5, -7, 3, 6, -9, 4]) + Jmust = Matrix(3, 3, [0, 1, 0, 0, 0, 0, 0, 0, 1]) + P, J = m.jordan_form() + assert Jmust == J + + m = Matrix(4, 4, [6, 5, -2, -3, -3, -1, 3, 3, 2, 1, -2, -3, -1, 1, 5, 5]) + Jmust = Matrix(4, 4, [2, 1, 0, 0, + 0, 2, 0, 0, + 0, 0, 2, 1, + 0, 0, 0, 2] + ) + P, J = m.jordan_form() + assert Jmust == J + + m = Matrix(4, 4, [6, 2, -8, -6, -3, 2, 9, 6, 2, -2, -8, -6, -1, 0, 3, 4]) + # Jmust = Matrix(4, 4, [2, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0, -2]) + # same here see 1456ff + Jmust = Matrix(4, 4, [-2, 0, 0, 0, + 0, 2, 1, 0, + 0, 0, 2, 0, + 0, 0, 0, 2]) + P, J = m.jordan_form() + assert Jmust == J + + m = Matrix(4, 4, [5, 4, 2, 1, 0, 1, -1, -1, -1, -1, 3, 0, 1, 1, -1, 2]) + assert not m.is_diagonalizable() + Jmust = Matrix(4, 4, [1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 4, 1, 0, 0, 0, 4]) + P, J = m.jordan_form() + assert Jmust == J + + # checking for maximum precision to remain unchanged + m = Matrix([[Float('1.0', precision=110), Float('2.0', precision=110)], + [Float('3.14159265358979323846264338327', precision=110), Float('4.0', precision=110)]]) + P, J = m.jordan_form() + for term in J.values(): + if isinstance(term, Float): + assert term._prec == 110 + + +def test_jordan_form_complex_issue_9274(): + A = Matrix([[ 2, 4, 1, 0], + [-4, 2, 0, 1], + [ 0, 0, 2, 4], + [ 0, 0, -4, 2]]) + p = 2 - 4*I + q = 2 + 4*I + Jmust1 = Matrix([[p, 1, 0, 0], + [0, p, 0, 0], + [0, 0, q, 1], + [0, 0, 0, q]]) + Jmust2 = Matrix([[q, 1, 0, 0], + [0, q, 0, 0], + [0, 0, p, 1], + [0, 0, 0, p]]) + P, J = A.jordan_form() + assert J == Jmust1 or J == Jmust2 + assert simplify(P*J*P.inv()) == A + + +def test_issue_10220(): + # two non-orthogonal Jordan blocks with eigenvalue 1 + M = Matrix([[1, 0, 0, 1], + [0, 1, 1, 0], + [0, 0, 1, 1], + [0, 0, 0, 1]]) + P, J = M.jordan_form() + assert P == Matrix([[0, 1, 0, 1], + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0]]) + assert J == Matrix([ + [1, 1, 0, 0], + [0, 1, 1, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]) + + +def test_jordan_form_issue_15858(): + A = Matrix([ + [1, 1, 1, 0], + [-2, -1, 0, -1], + [0, 0, -1, -1], + [0, 0, 2, 1]]) + (P, J) = A.jordan_form() + assert P.expand() == Matrix([ + [ -I, -I/2, I, I/2], + [-1 + I, 0, -1 - I, 0], + [ 0, -S(1)/2 - I/2, 0, -S(1)/2 + I/2], + [ 0, 1, 0, 1]]) + assert J == Matrix([ + [-I, 1, 0, 0], + [0, -I, 0, 0], + [0, 0, I, 1], + [0, 0, 0, I]]) + + +def test_Matrix_berkowitz_charpoly(): + UA, K_i, K_w = symbols('UA K_i K_w') + + A = Matrix([[-K_i - UA + K_i**2/(K_i + K_w), K_i*K_w/(K_i + K_w)], + [ K_i*K_w/(K_i + K_w), -K_w + K_w**2/(K_i + K_w)]]) + + charpoly = A.charpoly(x) + + assert charpoly == \ + Poly(x**2 + (K_i*UA + K_w*UA + 2*K_i*K_w)/(K_i + K_w)*x + + K_i*K_w*UA/(K_i + K_w), x, domain='ZZ(K_i,K_w,UA)') + + assert type(charpoly) is PurePoly + + A = Matrix([[1, 3], [2, 0]]) + assert A.charpoly() == A.charpoly(x) == PurePoly(x**2 - x - 6) + + A = Matrix([[1, 2], [x, 0]]) + p = A.charpoly(x) + assert p.gen != x + assert p.as_expr().subs(p.gen, x) == x**2 - 3*x + + +def test_exp_jordan_block(): + l = Symbol('lamda') + + m = Matrix.jordan_block(1, l) + assert m._eval_matrix_exp_jblock() == Matrix([[exp(l)]]) + + m = Matrix.jordan_block(3, l) + assert m._eval_matrix_exp_jblock() == \ + Matrix([ + [exp(l), exp(l), exp(l)/2], + [0, exp(l), exp(l)], + [0, 0, exp(l)]]) + + +def test_exp(): + m = Matrix([[3, 4], [0, -2]]) + m_exp = Matrix([[exp(3), -4*exp(-2)/5 + 4*exp(3)/5], [0, exp(-2)]]) + assert m.exp() == m_exp + assert exp(m) == m_exp + + m = Matrix([[1, 0], [0, 1]]) + assert m.exp() == Matrix([[E, 0], [0, E]]) + assert exp(m) == Matrix([[E, 0], [0, E]]) + + m = Matrix([[1, -1], [1, 1]]) + assert m.exp() == Matrix([[E*cos(1), -E*sin(1)], [E*sin(1), E*cos(1)]]) + + +def test_log(): + l = Symbol('lamda') + + m = Matrix.jordan_block(1, l) + assert m._eval_matrix_log_jblock() == Matrix([[log(l)]]) + + m = Matrix.jordan_block(4, l) + assert m._eval_matrix_log_jblock() == \ + Matrix( + [ + [log(l), 1/l, -1/(2*l**2), 1/(3*l**3)], + [0, log(l), 1/l, -1/(2*l**2)], + [0, 0, log(l), 1/l], + [0, 0, 0, log(l)] + ] + ) + + m = Matrix( + [[0, 0, 1], + [0, 0, 0], + [-1, 0, 0]] + ) + raises(MatrixError, lambda: m.log()) + + +def test_find_reasonable_pivot_naive_finds_guaranteed_nonzero1(): + # Test if matrices._find_reasonable_pivot_naive() + # finds a guaranteed non-zero pivot when the + # some of the candidate pivots are symbolic expressions. + # Keyword argument: simpfunc=None indicates that no simplifications + # should be performed during the search. + x = Symbol('x') + column = Matrix(3, 1, [x, cos(x)**2 + sin(x)**2, S.Half]) + pivot_offset, pivot_val, pivot_assumed_nonzero, simplified =\ + _find_reasonable_pivot_naive(column) + assert pivot_val == S.Half + + +def test_find_reasonable_pivot_naive_finds_guaranteed_nonzero2(): + # Test if matrices._find_reasonable_pivot_naive() + # finds a guaranteed non-zero pivot when the + # some of the candidate pivots are symbolic expressions. + # Keyword argument: simpfunc=_simplify indicates that the search + # should attempt to simplify candidate pivots. + x = Symbol('x') + column = Matrix(3, 1, + [x, + cos(x)**2+sin(x)**2+x**2, + cos(x)**2+sin(x)**2]) + pivot_offset, pivot_val, pivot_assumed_nonzero, simplified =\ + _find_reasonable_pivot_naive(column, simpfunc=_simplify) + assert pivot_val == 1 + + +def test_find_reasonable_pivot_naive_simplifies(): + # Test if matrices._find_reasonable_pivot_naive() + # simplifies candidate pivots, and reports + # their offsets correctly. + x = Symbol('x') + column = Matrix(3, 1, + [x, + cos(x)**2+sin(x)**2+x, + cos(x)**2+sin(x)**2]) + pivot_offset, pivot_val, pivot_assumed_nonzero, simplified =\ + _find_reasonable_pivot_naive(column, simpfunc=_simplify) + + assert len(simplified) == 2 + assert simplified[0][0] == 1 + assert simplified[0][1] == 1+x + assert simplified[1][0] == 2 + assert simplified[1][1] == 1 + + +def test_errors(): + raises(ValueError, lambda: Matrix([[1, 2], [1]])) + raises(IndexError, lambda: Matrix([[1, 2]])[1.2, 5]) + raises(IndexError, lambda: Matrix([[1, 2]])[1, 5.2]) + raises(ValueError, lambda: randMatrix(3, c=4, symmetric=True)) + raises(ValueError, lambda: Matrix([1, 2]).reshape(4, 6)) + raises(ShapeError, + lambda: Matrix([[1, 2], [3, 4]]).copyin_matrix([1, 0], Matrix([1, 2]))) + raises(TypeError, lambda: Matrix([[1, 2], [3, 4]]).copyin_list([0, + 1], set())) + raises(NonSquareMatrixError, lambda: Matrix([[1, 2, 3], [2, 3, 0]]).inv()) + raises(ShapeError, + lambda: Matrix(1, 2, [1, 2]).row_join(Matrix([[1, 2], [3, 4]]))) + raises( + ShapeError, lambda: Matrix([1, 2]).col_join(Matrix([[1, 2], [3, 4]]))) + raises(ShapeError, lambda: Matrix([1]).row_insert(1, Matrix([[1, + 2], [3, 4]]))) + raises(ShapeError, lambda: Matrix([1]).col_insert(1, Matrix([[1, + 2], [3, 4]]))) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).trace()) + raises(TypeError, lambda: Matrix([1]).applyfunc(1)) + raises(ValueError, lambda: Matrix([[1, 2], [3, 4]]).minor(4, 5)) + raises(ValueError, lambda: Matrix([[1, 2], [3, 4]]).minor_submatrix(4, 5)) + raises(TypeError, lambda: Matrix([1, 2, 3]).cross(1)) + raises(TypeError, lambda: Matrix([1, 2, 3]).dot(1)) + raises(ShapeError, lambda: Matrix([1, 2, 3]).dot(Matrix([1, 2]))) + raises(ShapeError, lambda: Matrix([1, 2]).dot([])) + raises(TypeError, lambda: Matrix([1, 2]).dot('a')) + raises(ShapeError, lambda: Matrix([1, 2]).dot([1, 2, 3])) + raises(NonSquareMatrixError, lambda: Matrix([1, 2, 3]).exp()) + raises(ShapeError, lambda: Matrix([[1, 2], [3, 4]]).normalized()) + raises(ValueError, lambda: Matrix([1, 2]).inv(method='not a method')) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).inverse_GE()) + raises(ValueError, lambda: Matrix([[1, 2], [1, 2]]).inverse_GE()) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).inverse_ADJ()) + raises(ValueError, lambda: Matrix([[1, 2], [1, 2]]).inverse_ADJ()) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).inverse_LU()) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).is_nilpotent()) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).det()) + raises(ValueError, + lambda: Matrix([[1, 2], [3, 4]]).det(method='Not a real method')) + raises(ValueError, + lambda: Matrix([[1, 2, 3, 4], [5, 6, 7, 8], + [9, 10, 11, 12], [13, 14, 15, 16]]).det(iszerofunc="Not function")) + raises(ValueError, + lambda: Matrix([[1, 2, 3, 4], [5, 6, 7, 8], + [9, 10, 11, 12], [13, 14, 15, 16]]).det(iszerofunc=False)) + raises(ValueError, + lambda: hessian(Matrix([[1, 2], [3, 4]]), Matrix([[1, 2], [2, 1]]))) + raises(ValueError, lambda: hessian(Matrix([[1, 2], [3, 4]]), [])) + raises(ValueError, lambda: hessian(Symbol('x')**2, 'a')) + raises(IndexError, lambda: eye(3)[5, 2]) + raises(IndexError, lambda: eye(3)[2, 5]) + M = Matrix(((1, 2, 3, 4), (5, 6, 7, 8), (9, 10, 11, 12), (13, 14, 15, 16))) + raises(ValueError, lambda: M.det('method=LU_decomposition()')) + V = Matrix([[10, 10, 10]]) + M = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(ValueError, lambda: M.row_insert(4.7, V)) + M = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(ValueError, lambda: M.col_insert(-4.2, V)) + + +def test_len(): + assert len(Matrix()) == 0 + assert len(Matrix([[1, 2]])) == len(Matrix([[1], [2]])) == 2 + assert len(Matrix(0, 2, lambda i, j: 0)) == \ + len(Matrix(2, 0, lambda i, j: 0)) == 0 + assert len(Matrix([[0, 1, 2], [3, 4, 5]])) == 6 + assert Matrix([1]) == Matrix([[1]]) + assert not Matrix() + assert Matrix() == Matrix([]) + + +def test_integrate(): + A = Matrix(((1, 4, x), (y, 2, 4), (10, 5, x**2))) + assert A.integrate(x) == \ + Matrix(((x, 4*x, x**2/2), (x*y, 2*x, 4*x), (10*x, 5*x, x**3/3))) + assert A.integrate(y) == \ + Matrix(((y, 4*y, x*y), (y**2/2, 2*y, 4*y), (10*y, 5*y, y*x**2))) + m = Matrix(2, 1, [x, y]) + assert m.integrate(x) == Matrix(2, 1, [x**2/2, y*x]) + + +def test_diff(): + A = MutableDenseMatrix(((1, 4, x), (y, 2, 4), (10, 5, x**2 + 1))) + assert isinstance(A.diff(x), type(A)) + assert A.diff(x) == MutableDenseMatrix(((0, 0, 1), (0, 0, 0), (0, 0, 2*x))) + assert A.diff(y) == MutableDenseMatrix(((0, 0, 0), (1, 0, 0), (0, 0, 0))) + + assert diff(A, x) == MutableDenseMatrix(((0, 0, 1), (0, 0, 0), (0, 0, 2*x))) + assert diff(A, y) == MutableDenseMatrix(((0, 0, 0), (1, 0, 0), (0, 0, 0))) + + A_imm = A.as_immutable() + assert isinstance(A_imm.diff(x), type(A_imm)) + assert A_imm.diff(x) == ImmutableDenseMatrix(((0, 0, 1), (0, 0, 0), (0, 0, 2*x))) + assert A_imm.diff(y) == ImmutableDenseMatrix(((0, 0, 0), (1, 0, 0), (0, 0, 0))) + + assert diff(A_imm, x) == ImmutableDenseMatrix(((0, 0, 1), (0, 0, 0), (0, 0, 2*x))) + assert diff(A_imm, y) == ImmutableDenseMatrix(((0, 0, 0), (1, 0, 0), (0, 0, 0))) + + assert A.diff(x, evaluate=False) == ArrayDerivative(A, x, evaluate=False) + assert diff(A, x, evaluate=False) == ArrayDerivative(A, x, evaluate=False) + + +def test_diff_by_matrix(): + + # Derive matrix by matrix: + + A = MutableDenseMatrix([[x, y], [z, t]]) + assert A.diff(A) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + assert diff(A, A) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + + A_imm = A.as_immutable() + assert A_imm.diff(A_imm) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + assert diff(A_imm, A_imm) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + + # Derive a constant matrix: + assert A.diff(a) == MutableDenseMatrix([[0, 0], [0, 0]]) + + B = ImmutableDenseMatrix([a, b]) + assert A.diff(B) == Array.zeros(2, 1, 2, 2) + assert A.diff(A) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + + # Test diff with tuples: + + dB = B.diff([[a, b]]) + assert dB.shape == (2, 2, 1) + assert dB == Array([[[1], [0]], [[0], [1]]]) + + f = Function("f") + fxyz = f(x, y, z) + assert fxyz.diff([[x, y, z]]) == Array([fxyz.diff(x), fxyz.diff(y), fxyz.diff(z)]) + assert fxyz.diff(([x, y, z], 2)) == Array([ + [fxyz.diff(x, 2), fxyz.diff(x, y), fxyz.diff(x, z)], + [fxyz.diff(x, y), fxyz.diff(y, 2), fxyz.diff(y, z)], + [fxyz.diff(x, z), fxyz.diff(z, y), fxyz.diff(z, 2)], + ]) + + expr = sin(x)*exp(y) + assert expr.diff([[x, y]]) == Array([cos(x)*exp(y), sin(x)*exp(y)]) + assert expr.diff(y, ((x, y),)) == Array([cos(x)*exp(y), sin(x)*exp(y)]) + assert expr.diff(x, ((x, y),)) == Array([-sin(x)*exp(y), cos(x)*exp(y)]) + assert expr.diff(((y, x),), [[x, y]]) == Array([[cos(x)*exp(y), -sin(x)*exp(y)], [sin(x)*exp(y), cos(x)*exp(y)]]) + + # Test different notations: + + assert fxyz.diff(x).diff(y).diff(x) == fxyz.diff(((x, y, z),), 3)[0, 1, 0] + assert fxyz.diff(z).diff(y).diff(x) == fxyz.diff(((x, y, z),), 3)[2, 1, 0] + assert fxyz.diff([[x, y, z]], ((z, y, x),)) == Array([[fxyz.diff(i).diff(j) for i in (x, y, z)] for j in (z, y, x)]) + + # Test scalar derived by matrix remains matrix: + res = x.diff(Matrix([[x, y]])) + assert isinstance(res, ImmutableDenseMatrix) + assert res == Matrix([[1, 0]]) + res = (x**3).diff(Matrix([[x, y]])) + assert isinstance(res, ImmutableDenseMatrix) + assert res == Matrix([[3*x**2, 0]]) + + +def test_getattr(): + A = Matrix(((1, 4, x), (y, 2, 4), (10, 5, x**2 + 1))) + raises(AttributeError, lambda: A.nonexistantattribute) + assert getattr(A, 'diff')(x) == Matrix(((0, 0, 1), (0, 0, 0), (0, 0, 2*x))) + + +def test_hessenberg(): + A = Matrix([[3, 4, 1], [2, 4, 5], [0, 1, 2]]) + assert A.is_upper_hessenberg + A = A.T + assert A.is_lower_hessenberg + A[0, -1] = 1 + assert A.is_lower_hessenberg is False + + A = Matrix([[3, 4, 1], [2, 4, 5], [3, 1, 2]]) + assert not A.is_upper_hessenberg + + A = zeros(5, 2) + assert A.is_upper_hessenberg + + +def test_cholesky(): + raises(NonSquareMatrixError, lambda: Matrix((1, 2)).cholesky()) + raises(ValueError, lambda: Matrix(((1, 2), (3, 4))).cholesky()) + raises(ValueError, lambda: Matrix(((5 + I, 0), (0, 1))).cholesky()) + raises(ValueError, lambda: Matrix(((1, 5), (5, 1))).cholesky()) + raises(ValueError, lambda: Matrix(((1, 2), (3, 4))).cholesky(hermitian=False)) + assert Matrix(((5 + I, 0), (0, 1))).cholesky(hermitian=False) == Matrix([ + [sqrt(5 + I), 0], [0, 1]]) + A = Matrix(((1, 5), (5, 1))) + L = A.cholesky(hermitian=False) + assert L == Matrix([[1, 0], [5, 2*sqrt(6)*I]]) + assert L*L.T == A + A = Matrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11))) + L = A.cholesky() + assert L * L.T == A + assert L.is_lower + assert L == Matrix([[5, 0, 0], [3, 3, 0], [-1, 1, 3]]) + A = Matrix(((4, -2*I, 2 + 2*I), (2*I, 2, -1 + I), (2 - 2*I, -1 - I, 11))) + assert A.cholesky().expand() == Matrix(((2, 0, 0), (I, 1, 0), (1 - I, 0, 3))) + + raises(NonSquareMatrixError, lambda: SparseMatrix((1, 2)).cholesky()) + raises(ValueError, lambda: SparseMatrix(((1, 2), (3, 4))).cholesky()) + raises(ValueError, lambda: SparseMatrix(((5 + I, 0), (0, 1))).cholesky()) + raises(ValueError, lambda: SparseMatrix(((1, 5), (5, 1))).cholesky()) + raises(ValueError, lambda: SparseMatrix(((1, 2), (3, 4))).cholesky(hermitian=False)) + assert SparseMatrix(((5 + I, 0), (0, 1))).cholesky(hermitian=False) == Matrix([ + [sqrt(5 + I), 0], [0, 1]]) + A = SparseMatrix(((1, 5), (5, 1))) + L = A.cholesky(hermitian=False) + assert L == Matrix([[1, 0], [5, 2*sqrt(6)*I]]) + assert L*L.T == A + A = SparseMatrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11))) + L = A.cholesky() + assert L * L.T == A + assert L.is_lower + assert L == Matrix([[5, 0, 0], [3, 3, 0], [-1, 1, 3]]) + A = SparseMatrix(((4, -2*I, 2 + 2*I), (2*I, 2, -1 + I), (2 - 2*I, -1 - I, 11))) + assert A.cholesky() == Matrix(((2, 0, 0), (I, 1, 0), (1 - I, 0, 3))) + + +def test_matrix_norm(): + # Vector Tests + # Test columns and symbols + x = Symbol('x', real=True) + v = Matrix([cos(x), sin(x)]) + assert trigsimp(v.norm(2)) == 1 + assert v.norm(10) == Pow(cos(x)**10 + sin(x)**10, Rational(1, 10)) + + # Test Rows + A = Matrix([[5, Rational(3, 2)]]) + assert A.norm() == Pow(25 + Rational(9, 4), S.Half) + assert A.norm(oo) == max(A) + assert A.norm(-oo) == min(A) + + # Matrix Tests + # Intuitive test + A = Matrix([[1, 1], [1, 1]]) + assert A.norm(2) == 2 + assert A.norm(-2) == 0 + assert A.norm('frobenius') == 2 + assert eye(10).norm(2) == eye(10).norm(-2) == 1 + assert A.norm(oo) == 2 + + # Test with Symbols and more complex entries + A = Matrix([[3, y, y], [x, S.Half, -pi]]) + assert (A.norm('fro') + == sqrt(Rational(37, 4) + 2*abs(y)**2 + pi**2 + x**2)) + + # Check non-square + A = Matrix([[1, 2, -3], [4, 5, Rational(13, 2)]]) + assert A.norm(2) == sqrt(Rational(389, 8) + sqrt(78665)/8) + assert A.norm(-2) is S.Zero + assert A.norm('frobenius') == sqrt(389)/2 + + # Test properties of matrix norms + # https://en.wikipedia.org/wiki/Matrix_norm#Definition + # Two matrices + A = Matrix([[1, 2], [3, 4]]) + B = Matrix([[5, 5], [-2, 2]]) + C = Matrix([[0, -I], [I, 0]]) + D = Matrix([[1, 0], [0, -1]]) + L = [A, B, C, D] + alpha = Symbol('alpha', real=True) + + for order in ['fro', 2, -2]: + # Zero Check + assert zeros(3).norm(order) is S.Zero + # Check Triangle Inequality for all Pairs of Matrices + for X in L: + for Y in L: + dif = (X.norm(order) + Y.norm(order) - + (X + Y).norm(order)) + assert (dif >= 0) + # Scalar multiplication linearity + for M in [A, B, C, D]: + dif = simplify((alpha*M).norm(order) - + abs(alpha) * M.norm(order)) + assert dif == 0 + + # Test Properties of Vector Norms + # https://en.wikipedia.org/wiki/Vector_norm + # Two column vectors + a = Matrix([1, 1 - 1*I, -3]) + b = Matrix([S.Half, 1*I, 1]) + c = Matrix([-1, -1, -1]) + d = Matrix([3, 2, I]) + e = Matrix([Integer(1e2), Rational(1, 1e2), 1]) + L = [a, b, c, d, e] + alpha = Symbol('alpha', real=True) + + for order in [1, 2, -1, -2, S.Infinity, S.NegativeInfinity, pi]: + # Zero Check + if order > 0: + assert Matrix([0, 0, 0]).norm(order) is S.Zero + # Triangle inequality on all pairs + if order >= 1: # Triangle InEq holds only for these norms + for X in L: + for Y in L: + dif = (X.norm(order) + Y.norm(order) - + (X + Y).norm(order)) + assert simplify(dif >= 0) is S.true + # Linear to scalar multiplication + if order in [1, 2, -1, -2, S.Infinity, S.NegativeInfinity]: + for X in L: + dif = simplify((alpha*X).norm(order) - + (abs(alpha) * X.norm(order))) + assert dif == 0 + + # ord=1 + M = Matrix(3, 3, [1, 3, 0, -2, -1, 0, 3, 9, 6]) + assert M.norm(1) == 13 + + +def test_condition_number(): + x = Symbol('x', real=True) + A = eye(3) + A[0, 0] = 10 + A[2, 2] = Rational(1, 10) + assert A.condition_number() == 100 + + A[1, 1] = x + assert A.condition_number() == Max(10, Abs(x)) / Min(Rational(1, 10), Abs(x)) + + M = Matrix([[cos(x), sin(x)], [-sin(x), cos(x)]]) + Mc = M.condition_number() + assert all(Float(1.).epsilon_eq(Mc.subs(x, val).evalf()) for val in + [Rational(1, 5), S.Half, Rational(1, 10), pi/2, pi, pi*Rational(7, 4) ]) + + #issue 10782 + assert Matrix([]).condition_number() == 0 + + +def test_equality(): + A = Matrix(((1, 2, 3), (4, 5, 6), (7, 8, 9))) + B = Matrix(((9, 8, 7), (6, 5, 4), (3, 2, 1))) + assert A == A[:, :] + assert not A != A[:, :] + assert not A == B + assert A != B + assert A != 10 + assert not A == 10 + + # A SparseMatrix can be equal to a Matrix + C = SparseMatrix(((1, 0, 0), (0, 1, 0), (0, 0, 1))) + D = Matrix(((1, 0, 0), (0, 1, 0), (0, 0, 1))) + assert C == D + assert not C != D + + +def test_normalized(): + assert Matrix([3, 4]).normalized() == \ + Matrix([Rational(3, 5), Rational(4, 5)]) + + # Zero vector trivial cases + assert Matrix([0, 0, 0]).normalized() == Matrix([0, 0, 0]) + + # Machine precision error truncation trivial cases + m = Matrix([0,0,1.e-100]) + assert m.normalized( + iszerofunc=lambda x: x.evalf(n=10, chop=True).is_zero + ) == Matrix([0, 0, 0]) + + +def test_print_nonzero(): + assert capture(lambda: eye(3).print_nonzero()) == \ + '[X ]\n[ X ]\n[ X]\n' + assert capture(lambda: eye(3).print_nonzero('.')) == \ + '[. ]\n[ . ]\n[ .]\n' + + +def test_zeros_eye(): + assert Matrix.eye(3) == eye(3) + assert Matrix.zeros(3) == zeros(3) + assert ones(3, 4) == Matrix(3, 4, [1]*12) + + i = Matrix([[1, 0], [0, 1]]) + z = Matrix([[0, 0], [0, 0]]) + for cls in all_classes: + m = cls.eye(2) + assert i == m # but m == i will fail if m is immutable + assert i == eye(2, cls=cls) + assert type(m) == cls + m = cls.zeros(2) + assert z == m + assert z == zeros(2, cls=cls) + assert type(m) == cls + + +def test_is_zero(): + assert Matrix().is_zero_matrix + assert Matrix([[0, 0], [0, 0]]).is_zero_matrix + assert zeros(3, 4).is_zero_matrix + assert not eye(3).is_zero_matrix + assert Matrix([[x, 0], [0, 0]]).is_zero_matrix == None + assert SparseMatrix([[x, 0], [0, 0]]).is_zero_matrix == None + assert ImmutableMatrix([[x, 0], [0, 0]]).is_zero_matrix == None + assert ImmutableSparseMatrix([[x, 0], [0, 0]]).is_zero_matrix == None + assert Matrix([[x, 1], [0, 0]]).is_zero_matrix == False + a = Symbol('a', nonzero=True) + assert Matrix([[a, 0], [0, 0]]).is_zero_matrix == False + + +def test_rotation_matrices(): + # This tests the rotation matrices by rotating about an axis and back. + theta = pi/3 + r3_plus = rot_axis3(theta) + r3_minus = rot_axis3(-theta) + r2_plus = rot_axis2(theta) + r2_minus = rot_axis2(-theta) + r1_plus = rot_axis1(theta) + r1_minus = rot_axis1(-theta) + assert r3_minus*r3_plus*eye(3) == eye(3) + assert r2_minus*r2_plus*eye(3) == eye(3) + assert r1_minus*r1_plus*eye(3) == eye(3) + + # Check the correctness of the trace of the rotation matrix + assert r1_plus.trace() == 1 + 2*cos(theta) + assert r2_plus.trace() == 1 + 2*cos(theta) + assert r3_plus.trace() == 1 + 2*cos(theta) + + # Check that a rotation with zero angle doesn't change anything. + assert rot_axis1(0) == eye(3) + assert rot_axis2(0) == eye(3) + assert rot_axis3(0) == eye(3) + + # Check left-hand convention + # see Issue #24529 + q1 = Quaternion.from_axis_angle([1, 0, 0], pi / 2) + q2 = Quaternion.from_axis_angle([0, 1, 0], pi / 2) + q3 = Quaternion.from_axis_angle([0, 0, 1], pi / 2) + assert rot_axis1(- pi / 2) == q1.to_rotation_matrix() + assert rot_axis2(- pi / 2) == q2.to_rotation_matrix() + assert rot_axis3(- pi / 2) == q3.to_rotation_matrix() + # Check right-hand convention + assert rot_ccw_axis1(+ pi / 2) == q1.to_rotation_matrix() + assert rot_ccw_axis2(+ pi / 2) == q2.to_rotation_matrix() + assert rot_ccw_axis3(+ pi / 2) == q3.to_rotation_matrix() + + +def test_DeferredVector(): + assert str(DeferredVector("vector")[4]) == "vector[4]" + assert sympify(DeferredVector("d")) == DeferredVector("d") + raises(IndexError, lambda: DeferredVector("d")[-1]) + assert str(DeferredVector("d")) == "d" + assert repr(DeferredVector("test")) == "DeferredVector('test')" + + +def test_DeferredVector_not_iterable(): + assert not iterable(DeferredVector('X')) + + +def test_DeferredVector_Matrix(): + raises(TypeError, lambda: Matrix(DeferredVector("V"))) + + +def test_GramSchmidt(): + R = Rational + m1 = Matrix(1, 2, [1, 2]) + m2 = Matrix(1, 2, [2, 3]) + assert GramSchmidt([m1, m2]) == \ + [Matrix(1, 2, [1, 2]), Matrix(1, 2, [R(2)/5, R(-1)/5])] + assert GramSchmidt([m1.T, m2.T]) == \ + [Matrix(2, 1, [1, 2]), Matrix(2, 1, [R(2)/5, R(-1)/5])] + # from wikipedia + assert GramSchmidt([Matrix([3, 1]), Matrix([2, 2])], True) == [ + Matrix([3*sqrt(10)/10, sqrt(10)/10]), + Matrix([-sqrt(10)/10, 3*sqrt(10)/10])] + # https://github.com/sympy/sympy/issues/9488 + L = FiniteSet(Matrix([1])) + assert GramSchmidt(L) == [Matrix([[1]])] + + +def test_casoratian(): + assert casoratian([1, 2, 3, 4], 1) == 0 + assert casoratian([1, 2, 3, 4], 1, zero=False) == 0 + + +def test_zero_dimension_multiply(): + assert (Matrix()*zeros(0, 3)).shape == (0, 3) + assert zeros(3, 0)*zeros(0, 3) == zeros(3, 3) + assert zeros(0, 3)*zeros(3, 0) == Matrix() + + +def test_slice_issue_2884(): + m = Matrix(2, 2, range(4)) + assert m[1, :] == Matrix([[2, 3]]) + assert m[-1, :] == Matrix([[2, 3]]) + assert m[:, 1] == Matrix([[1, 3]]).T + assert m[:, -1] == Matrix([[1, 3]]).T + raises(IndexError, lambda: m[2, :]) + raises(IndexError, lambda: m[2, 2]) + + +def test_slice_issue_3401(): + assert zeros(0, 3)[:, -1].shape == (0, 1) + assert zeros(3, 0)[0, :] == Matrix(1, 0, []) + + +def test_copyin(): + s = zeros(3, 3) + s[3] = 1 + assert s[:, 0] == Matrix([0, 1, 0]) + assert s[3] == 1 + assert s[3: 4] == [1] + s[1, 1] = 42 + assert s[1, 1] == 42 + assert s[1, 1:] == Matrix([[42, 0]]) + s[1, 1:] = Matrix([[5, 6]]) + assert s[1, :] == Matrix([[1, 5, 6]]) + s[1, 1:] = [[42, 43]] + assert s[1, :] == Matrix([[1, 42, 43]]) + s[0, 0] = 17 + assert s[:, :1] == Matrix([17, 1, 0]) + s[0, 0] = [1, 1, 1] + assert s[:, 0] == Matrix([1, 1, 1]) + s[0, 0] = Matrix([1, 1, 1]) + assert s[:, 0] == Matrix([1, 1, 1]) + s[0, 0] = SparseMatrix([1, 1, 1]) + assert s[:, 0] == Matrix([1, 1, 1]) + + +def test_invertible_check(): + # sometimes a singular matrix will have a pivot vector shorter than + # the number of rows in a matrix... + assert Matrix([[1, 2], [1, 2]]).rref() == (Matrix([[1, 2], [0, 0]]), (0,)) + raises(ValueError, lambda: Matrix([[1, 2], [1, 2]]).inv()) + m = Matrix([ + [-1, -1, 0], + [ x, 1, 1], + [ 1, x, -1], + ]) + assert len(m.rref()[1]) != m.rows + # in addition, unless simplify=True in the call to rref, the identity + # matrix will be returned even though m is not invertible + assert m.rref()[0] != eye(3) + assert m.rref(simplify=signsimp)[0] != eye(3) + raises(ValueError, lambda: m.inv(method="ADJ")) + raises(ValueError, lambda: m.inv(method="GE")) + raises(ValueError, lambda: m.inv(method="LU")) + + +def test_issue_3959(): + x, y = symbols('x, y') + e = x*y + assert e.subs(x, Matrix([3, 5, 3])) == Matrix([3, 5, 3])*y + + +def test_issue_5964(): + assert str(Matrix([[1, 2], [3, 4]])) == 'Matrix([[1, 2], [3, 4]])' + + +def test_issue_7604(): + x, y = symbols("x y") + assert sstr(Matrix([[x, 2*y], [y**2, x + 3]])) == \ + 'Matrix([\n[ x, 2*y],\n[y**2, x + 3]])' + + +def test_is_Identity(): + assert eye(3).is_Identity + assert eye(3).as_immutable().is_Identity + assert not zeros(3).is_Identity + assert not ones(3).is_Identity + # issue 6242 + assert not Matrix([[1, 0, 0]]).is_Identity + # issue 8854 + assert SparseMatrix(3,3, {(0,0):1, (1,1):1, (2,2):1}).is_Identity + assert not SparseMatrix(2,3, range(6)).is_Identity + assert not SparseMatrix(3,3, {(0,0):1, (1,1):1}).is_Identity + assert not SparseMatrix(3,3, {(0,0):1, (1,1):1, (2,2):1, (0,1):2, (0,2):3}).is_Identity + + +def test_dot(): + assert ones(1, 3).dot(ones(3, 1)) == 3 + assert ones(1, 3).dot([1, 1, 1]) == 3 + assert Matrix([1, 2, 3]).dot(Matrix([1, 2, 3])) == 14 + assert Matrix([1, 2, 3*I]).dot(Matrix([I, 2, 3*I])) == -5 + I + assert Matrix([1, 2, 3*I]).dot(Matrix([I, 2, 3*I]), hermitian=False) == -5 + I + assert Matrix([1, 2, 3*I]).dot(Matrix([I, 2, 3*I]), hermitian=True) == 13 + I + assert Matrix([1, 2, 3*I]).dot(Matrix([I, 2, 3*I]), hermitian=True, conjugate_convention="physics") == 13 - I + assert Matrix([1, 2, 3*I]).dot(Matrix([4, 5*I, 6]), hermitian=True, conjugate_convention="right") == 4 + 8*I + assert Matrix([1, 2, 3*I]).dot(Matrix([4, 5*I, 6]), hermitian=True, conjugate_convention="left") == 4 - 8*I + assert Matrix([I, 2*I]).dot(Matrix([I, 2*I]), hermitian=False, conjugate_convention="left") == -5 + assert Matrix([I, 2*I]).dot(Matrix([I, 2*I]), conjugate_convention="left") == 5 + raises(ValueError, lambda: Matrix([1, 2]).dot(Matrix([3, 4]), hermitian=True, conjugate_convention="test")) + + +def test_dual(): + B_x, B_y, B_z, E_x, E_y, E_z = symbols( + 'B_x B_y B_z E_x E_y E_z', real=True) + F = Matrix(( + ( 0, E_x, E_y, E_z), + (-E_x, 0, B_z, -B_y), + (-E_y, -B_z, 0, B_x), + (-E_z, B_y, -B_x, 0) + )) + Fd = Matrix(( + ( 0, -B_x, -B_y, -B_z), + (B_x, 0, E_z, -E_y), + (B_y, -E_z, 0, E_x), + (B_z, E_y, -E_x, 0) + )) + assert F.dual().equals(Fd) + assert eye(3).dual().equals(zeros(3)) + assert F.dual().dual().equals(-F) + + +def test_anti_symmetric(): + assert Matrix([1, 2]).is_anti_symmetric() is False + m = Matrix(3, 3, [0, x**2 + 2*x + 1, y, -(x + 1)**2, 0, x*y, -y, -x*y, 0]) + assert m.is_anti_symmetric() is True + assert m.is_anti_symmetric(simplify=False) is None + assert m.is_anti_symmetric(simplify=lambda x: x) is None + + # tweak to fail + m[2, 1] = -m[2, 1] + assert m.is_anti_symmetric() is None + # untweak + m[2, 1] = -m[2, 1] + + m = m.expand() + assert m.is_anti_symmetric(simplify=False) is True + m[0, 0] = 1 + assert m.is_anti_symmetric() is False + + +def test_normalize_sort_diogonalization(): + A = Matrix(((1, 2), (2, 1))) + P, Q = A.diagonalize(normalize=True) + assert P*P.T == P.T*P == eye(P.cols) + P, Q = A.diagonalize(normalize=True, sort=True) + assert P*P.T == P.T*P == eye(P.cols) + assert P*Q*P.inv() == A + + +def test_issue_5321(): + raises(ValueError, lambda: Matrix([[1, 2, 3], Matrix(0, 1, [])])) + + +def test_issue_5320(): + assert Matrix.hstack(eye(2), 2*eye(2)) == Matrix([ + [1, 0, 2, 0], + [0, 1, 0, 2] + ]) + assert Matrix.vstack(eye(2), 2*eye(2)) == Matrix([ + [1, 0], + [0, 1], + [2, 0], + [0, 2] + ]) + cls = SparseMatrix + assert cls.hstack(cls(eye(2)), cls(2*eye(2))) == Matrix([ + [1, 0, 2, 0], + [0, 1, 0, 2] + ]) + + +def test_issue_11944(): + A = Matrix([[1]]) + AIm = sympify(A) + assert Matrix.hstack(AIm, A) == Matrix([[1, 1]]) + assert Matrix.vstack(AIm, A) == Matrix([[1], [1]]) + + +def test_cross(): + a = [1, 2, 3] + b = [3, 4, 5] + col = Matrix([-2, 4, -2]) + row = col.T + + def test(M, ans): + assert ans == M + assert type(M) == cls + for cls in all_classes: + A = cls(a) + B = cls(b) + test(A.cross(B), col) + test(A.cross(B.T), col) + test(A.T.cross(B.T), row) + test(A.T.cross(B), row) + raises(ShapeError, lambda: + Matrix(1, 2, [1, 1]).cross(Matrix(1, 2, [1, 1]))) + + +def test_hat_vee(): + v1 = Matrix([x, y, z]) + v2 = Matrix([a, b, c]) + assert v1.hat() * v2 == v1.cross(v2) + assert v1.hat().is_anti_symmetric() + assert v1.hat().vee() == v1 + + +def test_hash(): + for cls in immutable_classes: + s = {cls.eye(1), cls.eye(1)} + assert len(s) == 1 and s.pop() == cls.eye(1) + # issue 3979 + for cls in mutable_classes: + assert not isinstance(cls.eye(1), Hashable) + + +def test_adjoint(): + dat = [[0, I], [1, 0]] + ans = Matrix([[0, 1], [-I, 0]]) + for cls in all_classes: + assert ans == cls(dat).adjoint() + + +def test_atoms(): + m = Matrix([[1, 2], [x, 1 - 1/x]]) + assert m.atoms() == {S.One,S(2),S.NegativeOne, x} + assert m.atoms(Symbol) == {x} + + +def test_pinv(): + # Pseudoinverse of an invertible matrix is the inverse. + A1 = Matrix([[a, b], [c, d]]) + assert simplify(A1.pinv(method="RD")) == simplify(A1.inv()) + + # Test the four properties of the pseudoinverse for various matrices. + As = [Matrix([[13, 104], [2212, 3], [-3, 5]]), + Matrix([[1, 7, 9], [11, 17, 19]]), + Matrix([a, b])] + + for A in As: + A_pinv = A.pinv(method="RD") + AAp = A * A_pinv + ApA = A_pinv * A + assert simplify(AAp * A) == A + assert simplify(ApA * A_pinv) == A_pinv + assert AAp.H == AAp + assert ApA.H == ApA + + # XXX Pinv with diagonalization makes expression too complicated. + for A in As: + A_pinv = simplify(A.pinv(method="ED")) + AAp = A * A_pinv + ApA = A_pinv * A + assert simplify(AAp * A) == A + assert simplify(ApA * A_pinv) == A_pinv + assert AAp.H == AAp + assert ApA.H == ApA + + # XXX Computing pinv using diagonalization makes an expression that + # is too complicated to simplify. + # A1 = Matrix([[a, b], [c, d]]) + # assert simplify(A1.pinv(method="ED")) == simplify(A1.inv()) + # so this is tested numerically at a fixed random point + + from sympy.core.numbers import comp + q = A1.pinv(method="ED") + w = A1.inv() + reps = {a: -73633, b: 11362, c: 55486, d: 62570} + assert all( + comp(i.n(), j.n()) + for i, j in zip(q.subs(reps), w.subs(reps)) + ) + + +@slow +def test_pinv_rank_deficient_when_diagonalization_fails(): + # Test the four properties of the pseudoinverse for matrices when + # diagonalization of A.H*A fails. + As = [ + Matrix([ + [61, 89, 55, 20, 71, 0], + [62, 96, 85, 85, 16, 0], + [69, 56, 17, 4, 54, 0], + [10, 54, 91, 41, 71, 0], + [ 7, 30, 10, 48, 90, 0], + [0, 0, 0, 0, 0, 0]]) + ] + for A in As: + A_pinv = A.pinv(method="ED") + AAp = A * A_pinv + ApA = A_pinv * A + assert AAp.H == AAp + + # Here ApA.H and ApA are equivalent expressions but they are very + # complicated expressions involving RootOfs. Using simplify would be + # too slow and so would evalf so we substitute approximate values for + # the RootOfs and then evalf which is less accurate but good enough to + # confirm that these two matrices are equivalent. + # + # assert ApA.H == ApA # <--- would fail (structural equality) + # assert simplify(ApA.H - ApA).is_zero_matrix # <--- too slow + # (ApA.H - ApA).evalf() # <--- too slow + + def allclose(M1, M2): + rootofs = M1.atoms(RootOf) + rootofs_approx = {r: r.evalf() for r in rootofs} + diff_approx = (M1 - M2).xreplace(rootofs_approx).evalf() + return all(abs(e) < 1e-10 for e in diff_approx) + + assert allclose(ApA.H, ApA) + + +def test_issue_7201(): + assert ones(0, 1) + ones(0, 1) == Matrix(0, 1, []) + assert ones(1, 0) + ones(1, 0) == Matrix(1, 0, []) + + +def test_free_symbols(): + for M in ImmutableMatrix, ImmutableSparseMatrix, Matrix, SparseMatrix: + assert M([[x], [0]]).free_symbols == {x} + + +def test_from_ndarray(): + """See issue 7465.""" + try: + from numpy import array + except ImportError: + skip('NumPy must be available to test creating matrices from ndarrays') + + assert Matrix(array([1, 2, 3])) == Matrix([1, 2, 3]) + assert Matrix(array([[1, 2, 3]])) == Matrix([[1, 2, 3]]) + assert Matrix(array([[1, 2, 3], [4, 5, 6]])) == \ + Matrix([[1, 2, 3], [4, 5, 6]]) + assert Matrix(array([x, y, z])) == Matrix([x, y, z]) + raises(NotImplementedError, + lambda: Matrix(array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]))) + assert Matrix([array([1, 2]), array([3, 4])]) == Matrix([[1, 2], [3, 4]]) + assert Matrix([array([1, 2]), [3, 4]]) == Matrix([[1, 2], [3, 4]]) + assert Matrix([array([]), array([])]) == Matrix(2, 0, []) != Matrix([]) + + +def test_17522_numpy(): + from sympy.matrices.common import _matrixify + try: + from numpy import array, matrix + except ImportError: + skip('NumPy must be available to test indexing matrixified NumPy ndarrays and matrices') + + m = _matrixify(array([[1, 2], [3, 4]])) + assert m[3] == 4 + assert list(m) == [1, 2, 3, 4] + + with ignore_warnings(PendingDeprecationWarning): + m = _matrixify(matrix([[1, 2], [3, 4]])) + assert m[3] == 4 + assert list(m) == [1, 2, 3, 4] + + +def test_17522_mpmath(): + from sympy.matrices.common import _matrixify + try: + from mpmath import matrix + except ImportError: + skip('mpmath must be available to test indexing matrixified mpmath matrices') + + m = _matrixify(matrix([[1, 2], [3, 4]])) + assert m[3] == 4.0 + assert list(m) == [1.0, 2.0, 3.0, 4.0] + + +def test_17522_scipy(): + from sympy.matrices.common import _matrixify + try: + from scipy.sparse import csr_matrix + except ImportError: + skip('SciPy must be available to test indexing matrixified SciPy sparse matrices') + + m = _matrixify(csr_matrix([[1, 2], [3, 4]])) + assert m[3] == 4 + assert list(m) == [1, 2, 3, 4] + + +def test_hermitian(): + a = Matrix([[1, I], [-I, 1]]) + assert a.is_hermitian + a[0, 0] = 2*I + assert a.is_hermitian is False + a[0, 0] = x + assert a.is_hermitian is None + a[0, 1] = a[1, 0]*I + assert a.is_hermitian is False + + +def test_issue_9457_9467_9876(): + # for row_del(index) + M = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + M.row_del(1) + assert M == Matrix([[1, 2, 3], [3, 4, 5]]) + N = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + N.row_del(-2) + assert N == Matrix([[1, 2, 3], [3, 4, 5]]) + O = Matrix([[1, 2, 3], [5, 6, 7], [9, 10, 11]]) + O.row_del(-1) + assert O == Matrix([[1, 2, 3], [5, 6, 7]]) + P = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(IndexError, lambda: P.row_del(10)) + Q = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(IndexError, lambda: Q.row_del(-10)) + + # for col_del(index) + M = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + M.col_del(1) + assert M == Matrix([[1, 3], [2, 4], [3, 5]]) + N = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + N.col_del(-2) + assert N == Matrix([[1, 3], [2, 4], [3, 5]]) + P = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(IndexError, lambda: P.col_del(10)) + Q = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(IndexError, lambda: Q.col_del(-10)) + + +def test_issue_9422(): + x, y = symbols('x y', commutative=False) + a, b = symbols('a b') + M = eye(2) + M1 = Matrix(2, 2, [x, y, y, z]) + assert y*x*M != x*y*M + assert b*a*M == a*b*M + assert x*M1 != M1*x + assert a*M1 == M1*a + assert y*x*M == Matrix([[y*x, 0], [0, y*x]]) + + +def test_issue_10770(): + M = Matrix([]) + a = ['col_insert', 'row_join'], Matrix([9, 6, 3]) + b = ['row_insert', 'col_join'], a[1].T + c = ['row_insert', 'col_insert'], Matrix([[1, 2], [3, 4]]) + for ops, m in (a, b, c): + for op in ops: + f = getattr(M, op) + new = f(m) if 'join' in op else f(42, m) + assert new == m and id(new) != id(m) + + +def test_issue_10658(): + A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + assert A.extract([0, 1, 2], [True, True, False]) == \ + Matrix([[1, 2], [4, 5], [7, 8]]) + assert A.extract([0, 1, 2], [True, False, False]) == Matrix([[1], [4], [7]]) + assert A.extract([True, False, False], [0, 1, 2]) == Matrix([[1, 2, 3]]) + assert A.extract([True, False, True], [0, 1, 2]) == \ + Matrix([[1, 2, 3], [7, 8, 9]]) + assert A.extract([0, 1, 2], [False, False, False]) == Matrix(3, 0, []) + assert A.extract([False, False, False], [0, 1, 2]) == Matrix(0, 3, []) + assert A.extract([True, False, True], [False, True, False]) == \ + Matrix([[2], [8]]) + + +def test_opportunistic_simplification(): + # this test relates to issue #10718, #9480, #11434 + + # issue #9480 + m = Matrix([[-5 + 5*sqrt(2), -5], [-5*sqrt(2)/2 + 5, -5*sqrt(2)/2]]) + assert m.rank() == 1 + + # issue #10781 + m = Matrix([[3+3*sqrt(3)*I, -9],[4,-3+3*sqrt(3)*I]]) + assert simplify(m.rref()[0] - Matrix([[1, -9/(3 + 3*sqrt(3)*I)], [0, 0]])) == zeros(2, 2) + + # issue #11434 + ax,ay,bx,by,cx,cy,dx,dy,ex,ey,t0,t1 = symbols('a_x a_y b_x b_y c_x c_y d_x d_y e_x e_y t_0 t_1') + m = Matrix([[ax,ay,ax*t0,ay*t0,0],[bx,by,bx*t0,by*t0,0],[cx,cy,cx*t0,cy*t0,1],[dx,dy,dx*t0,dy*t0,1],[ex,ey,2*ex*t1-ex*t0,2*ey*t1-ey*t0,0]]) + assert m.rank() == 4 + + +def test_partial_pivoting(): + # example from https://en.wikipedia.org/wiki/Pivot_element + # partial pivoting with back substitution gives a perfect result + # naive pivoting give an error ~1e-13, so anything better than + # 1e-15 is good + mm=Matrix([[0.003, 59.14, 59.17], [5.291, -6.13, 46.78]]) + assert (mm.rref()[0] - Matrix([[1.0, 0, 10.0], + [ 0, 1.0, 1.0]])).norm() < 1e-15 + + # issue #11549 + m_mixed = Matrix([[6e-17, 1.0, 4], + [ -1.0, 0, 8], + [ 0, 0, 1]]) + m_float = Matrix([[6e-17, 1.0, 4.], + [ -1.0, 0., 8.], + [ 0., 0., 1.]]) + m_inv = Matrix([[ 0, -1.0, 8.0], + [1.0, 6.0e-17, -4.0], + [ 0, 0, 1]]) + # this example is numerically unstable and involves a matrix with a norm >= 8, + # this comparing the difference of the results with 1e-15 is numerically sound. + assert (m_mixed.inv() - m_inv).norm() < 1e-15 + assert (m_float.inv() - m_inv).norm() < 1e-15 + + +def test_iszero_substitution(): + """ When doing numerical computations, all elements that pass + the iszerofunc test should be set to numerically zero if they + aren't already. """ + + # Matrix from issue #9060 + m = Matrix([[0.9, -0.1, -0.2, 0],[-0.8, 0.9, -0.4, 0],[-0.1, -0.8, 0.6, 0]]) + m_rref = m.rref(iszerofunc=lambda x: abs(x)<6e-15)[0] + m_correct = Matrix([[1.0, 0, -0.301369863013699, 0],[ 0, 1.0, -0.712328767123288, 0],[ 0, 0, 0, 0]]) + m_diff = m_rref - m_correct + assert m_diff.norm() < 1e-15 + # if a zero-substitution wasn't made, this entry will be -1.11022302462516e-16 + assert m_rref[2,2] == 0 + + +def test_issue_11238(): + from sympy.geometry.point import Point + xx = 8*tan(pi*Rational(13, 45))/(tan(pi*Rational(13, 45)) + sqrt(3)) + yy = (-8*sqrt(3)*tan(pi*Rational(13, 45))**2 + 24*tan(pi*Rational(13, 45)))/(-3 + tan(pi*Rational(13, 45))**2) + p1 = Point(0, 0) + p2 = Point(1, -sqrt(3)) + p0 = Point(xx,yy) + m1 = Matrix([p1 - simplify(p0), p2 - simplify(p0)]) + m2 = Matrix([p1 - p0, p2 - p0]) + m3 = Matrix([simplify(p1 - p0), simplify(p2 - p0)]) + + # This system has expressions which are zero and + # cannot be easily proved to be such, so without + # numerical testing, these assertions will fail. + Z = lambda x: abs(x.n()) < 1e-20 + assert m1.rank(simplify=True, iszerofunc=Z) == 1 + assert m2.rank(simplify=True, iszerofunc=Z) == 1 + assert m3.rank(simplify=True, iszerofunc=Z) == 1 + + +def test_as_real_imag(): + m1 = Matrix(2,2,[1,2,3,4]) + m2 = m1*S.ImaginaryUnit + m3 = m1 + m2 + + for kls in all_classes: + a,b = kls(m3).as_real_imag() + assert list(a) == list(m1) + assert list(b) == list(m1) + + +def test_deprecated(): + # Maintain tests for deprecated functions. We must capture + # the deprecation warnings. When the deprecated functionality is + # removed, the corresponding tests should be removed. + + m = Matrix(3, 3, [0, 1, 0, -4, 4, 0, -2, 1, 2]) + P, Jcells = m.jordan_cells() + assert Jcells[1] == Matrix(1, 1, [2]) + assert Jcells[0] == Matrix(2, 2, [2, 1, 0, 2]) + + +def test_issue_14489(): + from sympy.core.mod import Mod + A = Matrix([-1, 1, 2]) + B = Matrix([10, 20, -15]) + + assert Mod(A, 3) == Matrix([2, 1, 2]) + assert Mod(B, 4) == Matrix([2, 0, 1]) + + +def test_issue_14943(): + # Test that __array__ accepts the optional dtype argument + try: + from numpy import array + except ImportError: + skip('NumPy must be available to test creating matrices from ndarrays') + + M = Matrix([[1,2], [3,4]]) + assert array(M, dtype=float).dtype.name == 'float64' + + +def test_case_6913(): + m = MatrixSymbol('m', 1, 1) + a = Symbol("a") + a = m[0, 0]>0 + assert str(a) == 'm[0, 0] > 0' + + +def test_issue_11948(): + A = MatrixSymbol('A', 3, 3) + a = Wild('a') + assert A.match(a) == {a: A} + + +def test_gramschmidt_conjugate_dot(): + vecs = [Matrix([1, I]), Matrix([1, -I])] + assert Matrix.orthogonalize(*vecs) == \ + [Matrix([[1], [I]]), Matrix([[1], [-I]])] + + vecs = [Matrix([1, I, 0]), Matrix([I, 0, -I])] + assert Matrix.orthogonalize(*vecs) == \ + [Matrix([[1], [I], [0]]), Matrix([[I/2], [S(1)/2], [-I]])] + + mat = Matrix([[1, I], [1, -I]]) + Q, R = mat.QRdecomposition() + assert Q * Q.H == Matrix.eye(2) + + +def test_issue_8207(): + a = Matrix(MatrixSymbol('a', 3, 1)) + b = Matrix(MatrixSymbol('b', 3, 1)) + c = a.dot(b) + d = diff(c, a[0, 0]) + e = diff(d, a[0, 0]) + assert d == b[0, 0] + assert e == 0 + + +def test_func(): + from sympy.simplify.simplify import nthroot + + A = Matrix([[1, 2],[0, 3]]) + assert A.analytic_func(sin(x*t), x) == Matrix([[sin(t), sin(3*t) - sin(t)], [0, sin(3*t)]]) + + A = Matrix([[2, 1],[1, 2]]) + assert (pi * A / 6).analytic_func(cos(x), x) == Matrix([[sqrt(3)/4, -sqrt(3)/4], [-sqrt(3)/4, sqrt(3)/4]]) + + + raises(ValueError, lambda : zeros(5).analytic_func(log(x), x)) + raises(ValueError, lambda : (A*x).analytic_func(log(x), x)) + + A = Matrix([[0, -1, -2, 3], [0, -1, -2, 3], [0, 1, 0, -1], [0, 0, -1, 1]]) + assert A.analytic_func(exp(x), x) == A.exp() + raises(ValueError, lambda : A.analytic_func(sqrt(x), x)) + + A = Matrix([[41, 12],[12, 34]]) + assert simplify(A.analytic_func(sqrt(x), x)**2) == A + + A = Matrix([[3, -12, 4], [-1, 0, -2], [-1, 5, -1]]) + assert simplify(A.analytic_func(nthroot(x, 3), x)**3) == A + + A = Matrix([[2, 0, 0, 0], [1, 2, 0, 0], [0, 1, 3, 0], [0, 0, 1, 3]]) + assert A.analytic_func(exp(x), x) == A.exp() + + A = Matrix([[0, 2, 1, 6], [0, 0, 1, 2], [0, 0, 0, 3], [0, 0, 0, 0]]) + assert A.analytic_func(exp(x*t), x) == expand(simplify((A*t).exp())) + + +@skip_under_pyodide("Cannot create threads under pyodide.") +def test_issue_19809(): + + def f(): + assert _dotprodsimp_state.state == None + m = Matrix([[1]]) + m = m * m + return True + + with dotprodsimp(True): + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(f) + assert future.result() + + +def test_issue_23276(): + M = Matrix([x, y]) + assert integrate(M, (x, 0, 1), (y, 0, 1)) == Matrix([ + [S.Half], + [S.Half]]) + + +def test_issue_27225(): + # https://github.com/sympy/sympy/issues/27225 + raises(TypeError, lambda : floor(Matrix([1, 1, 0]))) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_normalforms.py b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_normalforms.py new file mode 100644 index 0000000000000000000000000000000000000000..47ee52d73539f7fb79295443e1cf7e0a49e30a5e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_normalforms.py @@ -0,0 +1,111 @@ +from sympy.testing.pytest import warns_deprecated_sympy + +from sympy.core.symbol import Symbol +from sympy.polys.polytools import Poly +from sympy.matrices import Matrix, randMatrix +from sympy.matrices.normalforms import ( + invariant_factors, + smith_normal_form, + smith_normal_decomp, + hermite_normal_form, + is_smith_normal_form, +) +from sympy.polys.domains import ZZ, QQ +from sympy.core.numbers import Integer + +import random + + +def test_smith_normal(): + m = Matrix([[12,6,4,8],[3,9,6,12],[2,16,14,28],[20,10,10,20]]) + smf = Matrix([[1, 0, 0, 0], [0, 10, 0, 0], [0, 0, 30, 0], [0, 0, 0, 0]]) + assert smith_normal_form(m) == smf + + a, s, t = smith_normal_decomp(m) + assert a == s * m * t + + x = Symbol('x') + with warns_deprecated_sympy(): + m = Matrix([[Poly(x-1), Poly(1, x),Poly(-1,x)], + [0, Poly(x), Poly(-1,x)], + [Poly(0,x),Poly(-1,x),Poly(x)]]) + invs = 1, x - 1, x**2 - 1 + assert invariant_factors(m, domain=QQ[x]) == invs + + m = Matrix([[2, 4]]) + smf = Matrix([[2, 0]]) + assert smith_normal_form(m) == smf + + prng = random.Random(0) + for i in range(6): + for j in range(6): + for _ in range(10 if i*j else 1): + m = randMatrix(i, j, max=5, percent=50, prng=prng) + a, s, t = smith_normal_decomp(m) + assert a == s * m * t + assert is_smith_normal_form(a) + s.inv().to_DM(ZZ) + t.inv().to_DM(ZZ) + + a, s, t = smith_normal_decomp(m, QQ) + assert a == s * m * t + assert is_smith_normal_form(a) + s.inv() + t.inv() + + +def test_smith_normal_deprecated(): + from sympy.polys.solvers import RawMatrix as Matrix + + with warns_deprecated_sympy(): + m = Matrix([[12, 6, 4,8],[3,9,6,12],[2,16,14,28],[20,10,10,20]]) + setattr(m, 'ring', ZZ) + with warns_deprecated_sympy(): + smf = Matrix([[1, 0, 0, 0], [0, 10, 0, 0], [0, 0, 30, 0], [0, 0, 0, 0]]) + assert smith_normal_form(m) == smf + + x = Symbol('x') + with warns_deprecated_sympy(): + m = Matrix([[Poly(x-1), Poly(1, x),Poly(-1,x)], + [0, Poly(x), Poly(-1,x)], + [Poly(0,x),Poly(-1,x),Poly(x)]]) + setattr(m, 'ring', QQ[x]) + invs = (Poly(1, x, domain='QQ'), Poly(x - 1, domain='QQ'), Poly(x**2 - 1, domain='QQ')) + assert invariant_factors(m) == invs + + with warns_deprecated_sympy(): + m = Matrix([[2, 4]]) + setattr(m, 'ring', ZZ) + with warns_deprecated_sympy(): + smf = Matrix([[2, 0]]) + assert smith_normal_form(m) == smf + + +def test_hermite_normal(): + m = Matrix([[2, 7, 17, 29, 41], [3, 11, 19, 31, 43], [5, 13, 23, 37, 47]]) + hnf = Matrix([[1, 0, 0], [0, 2, 1], [0, 0, 1]]) + assert hermite_normal_form(m) == hnf + + tr_hnf = Matrix([[37, 0, 19], [222, -6, 113], [48, 0, 25], [0, 2, 1], [0, 0, 1]]) + assert hermite_normal_form(m.transpose()) == tr_hnf + + m = Matrix([[8, 28, 68, 116, 164], [3, 11, 19, 31, 43], [5, 13, 23, 37, 47]]) + hnf = Matrix([[4, 0, 0], [0, 2, 1], [0, 0, 1]]) + assert hermite_normal_form(m) == hnf + assert hermite_normal_form(m, D=8) == hnf + assert hermite_normal_form(m, D=ZZ(8)) == hnf + assert hermite_normal_form(m, D=Integer(8)) == hnf + + m = Matrix([[10, 8, 6, 30, 2], [45, 36, 27, 18, 9], [5, 4, 3, 2, 1]]) + hnf = Matrix([[26, 2], [0, 9], [0, 1]]) + assert hermite_normal_form(m) == hnf + + m = Matrix([[2, 7], [0, 0], [0, 0]]) + hnf = Matrix([[1], [0], [0]]) + assert hermite_normal_form(m) == hnf + + +def test_issue_23410(): + A = Matrix([[1, 12], [0, 8], [0, 5]]) + H = Matrix([[1, 0], [0, 8], [0, 5]]) + assert hermite_normal_form(A) == H diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_reductions.py b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_reductions.py new file mode 100644 index 0000000000000000000000000000000000000000..32c98c6f249b1afafc8193f4248dc9493bb803e0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_reductions.py @@ -0,0 +1,351 @@ +from sympy.core.numbers import I +from sympy.core.symbol import symbols +from sympy.testing.pytest import raises +from sympy.matrices import Matrix, zeros, eye +from sympy.core.symbol import Symbol +from sympy.core.numbers import Rational +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.simplify.simplify import simplify +from sympy.abc import x + + +# Matrix tests +def test_row_op(): + e = eye(3) + + raises(ValueError, lambda: e.elementary_row_op("abc")) + raises(ValueError, lambda: e.elementary_row_op()) + raises(ValueError, lambda: e.elementary_row_op('n->kn', row=5, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n->kn', row=-5, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n<->m', row1=1, row2=5)) + raises(ValueError, lambda: e.elementary_row_op('n<->m', row1=5, row2=1)) + raises(ValueError, lambda: e.elementary_row_op('n<->m', row1=-5, row2=1)) + raises(ValueError, lambda: e.elementary_row_op('n<->m', row1=1, row2=-5)) + raises(ValueError, lambda: e.elementary_row_op('n->n+km', row1=1, row2=5, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n->n+km', row1=5, row2=1, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n->n+km', row1=-5, row2=1, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n->n+km', row1=1, row2=-5, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n->n+km', row1=1, row2=1, k=5)) + + # test various ways to set arguments + assert e.elementary_row_op("n->kn", 0, 5) == Matrix([[5, 0, 0], [0, 1, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->kn", 1, 5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->kn", row=1, k=5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->kn", row1=1, k=5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_row_op("n<->m", 0, 1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_row_op("n<->m", row1=0, row2=1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_row_op("n<->m", row=0, row2=1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->n+km", 0, 5, 1) == Matrix([[1, 5, 0], [0, 1, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->n+km", row=0, k=5, row2=1) == Matrix([[1, 5, 0], [0, 1, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->n+km", row1=0, k=5, row2=1) == Matrix([[1, 5, 0], [0, 1, 0], [0, 0, 1]]) + + # make sure the matrix doesn't change size + a = Matrix(2, 3, [0]*6) + assert a.elementary_row_op("n->kn", 1, 5) == Matrix(2, 3, [0]*6) + assert a.elementary_row_op("n<->m", 0, 1) == Matrix(2, 3, [0]*6) + assert a.elementary_row_op("n->n+km", 0, 5, 1) == Matrix(2, 3, [0]*6) + + +def test_col_op(): + e = eye(3) + + raises(ValueError, lambda: e.elementary_col_op("abc")) + raises(ValueError, lambda: e.elementary_col_op()) + raises(ValueError, lambda: e.elementary_col_op('n->kn', col=5, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n->kn', col=-5, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n<->m', col1=1, col2=5)) + raises(ValueError, lambda: e.elementary_col_op('n<->m', col1=5, col2=1)) + raises(ValueError, lambda: e.elementary_col_op('n<->m', col1=-5, col2=1)) + raises(ValueError, lambda: e.elementary_col_op('n<->m', col1=1, col2=-5)) + raises(ValueError, lambda: e.elementary_col_op('n->n+km', col1=1, col2=5, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n->n+km', col1=5, col2=1, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n->n+km', col1=-5, col2=1, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n->n+km', col1=1, col2=-5, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n->n+km', col1=1, col2=1, k=5)) + + # test various ways to set arguments + assert e.elementary_col_op("n->kn", 0, 5) == Matrix([[5, 0, 0], [0, 1, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->kn", 1, 5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->kn", col=1, k=5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->kn", col1=1, k=5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_col_op("n<->m", 0, 1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_col_op("n<->m", col1=0, col2=1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_col_op("n<->m", col=0, col2=1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->n+km", 0, 5, 1) == Matrix([[1, 0, 0], [5, 1, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->n+km", col=0, k=5, col2=1) == Matrix([[1, 0, 0], [5, 1, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->n+km", col1=0, k=5, col2=1) == Matrix([[1, 0, 0], [5, 1, 0], [0, 0, 1]]) + + # make sure the matrix doesn't change size + a = Matrix(2, 3, [0]*6) + assert a.elementary_col_op("n->kn", 1, 5) == Matrix(2, 3, [0]*6) + assert a.elementary_col_op("n<->m", 0, 1) == Matrix(2, 3, [0]*6) + assert a.elementary_col_op("n->n+km", 0, 5, 1) == Matrix(2, 3, [0]*6) + + +def test_is_echelon(): + zro = zeros(3) + ident = eye(3) + + assert zro.is_echelon + assert ident.is_echelon + + a = Matrix(0, 0, []) + assert a.is_echelon + + a = Matrix(2, 3, [3, 2, 1, 0, 0, 6]) + assert a.is_echelon + + a = Matrix(2, 3, [0, 0, 6, 3, 2, 1]) + assert not a.is_echelon + + x = Symbol('x') + a = Matrix(3, 1, [x, 0, 0]) + assert a.is_echelon + + a = Matrix(3, 1, [x, x, 0]) + assert not a.is_echelon + + a = Matrix(3, 3, [0, 0, 0, 1, 2, 3, 0, 0, 0]) + assert not a.is_echelon + + +def test_echelon_form(): + # echelon form is not unique, but the result + # must be row-equivalent to the original matrix + # and it must be in echelon form. + + a = zeros(3) + e = eye(3) + + # we can assume the zero matrix and the identity matrix shouldn't change + assert a.echelon_form() == a + assert e.echelon_form() == e + + a = Matrix(0, 0, []) + assert a.echelon_form() == a + + a = Matrix(1, 1, [5]) + assert a.echelon_form() == a + + # now we get to the real tests + + def verify_row_null_space(mat, rows, nulls): + for v in nulls: + assert all(t.is_zero for t in a_echelon*v) + for v in rows: + if not all(t.is_zero for t in v): + assert not all(t.is_zero for t in a_echelon*v.transpose()) + + a = Matrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9]) + nulls = [Matrix([ + [ 1], + [-2], + [ 1]])] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + + a = Matrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 8]) + nulls = [] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + a = Matrix(3, 3, [2, 1, 3, 0, 0, 0, 2, 1, 3]) + nulls = [Matrix([ + [Rational(-1, 2)], + [ 1], + [ 0]]), + Matrix([ + [Rational(-3, 2)], + [ 0], + [ 1]])] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + # this one requires a row swap + a = Matrix(3, 3, [2, 1, 3, 0, 0, 0, 1, 1, 3]) + nulls = [Matrix([ + [ 0], + [ -3], + [ 1]])] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + a = Matrix(3, 3, [0, 3, 3, 0, 2, 2, 0, 1, 1]) + nulls = [Matrix([ + [1], + [0], + [0]]), + Matrix([ + [ 0], + [-1], + [ 1]])] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + a = Matrix(2, 3, [2, 2, 3, 3, 3, 0]) + nulls = [Matrix([ + [-1], + [1], + [0]])] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + +def test_rref(): + e = Matrix(0, 0, []) + assert e.rref(pivots=False) == e + + e = Matrix(1, 1, [1]) + a = Matrix(1, 1, [5]) + assert e.rref(pivots=False) == a.rref(pivots=False) == e + + a = Matrix(3, 1, [1, 2, 3]) + assert a.rref(pivots=False) == Matrix([[1], [0], [0]]) + + a = Matrix(1, 3, [1, 2, 3]) + assert a.rref(pivots=False) == Matrix([[1, 2, 3]]) + + a = Matrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9]) + assert a.rref(pivots=False) == Matrix([ + [1, 0, -1], + [0, 1, 2], + [0, 0, 0]]) + + a = Matrix(3, 3, [1, 2, 3, 1, 2, 3, 1, 2, 3]) + b = Matrix(3, 3, [1, 2, 3, 0, 0, 0, 0, 0, 0]) + c = Matrix(3, 3, [0, 0, 0, 1, 2, 3, 0, 0, 0]) + d = Matrix(3, 3, [0, 0, 0, 0, 0, 0, 1, 2, 3]) + assert a.rref(pivots=False) == \ + b.rref(pivots=False) == \ + c.rref(pivots=False) == \ + d.rref(pivots=False) == b + + e = eye(3) + z = zeros(3) + assert e.rref(pivots=False) == e + assert z.rref(pivots=False) == z + + a = Matrix([ + [ 0, 0, 1, 2, 2, -5, 3], + [-1, 5, 2, 2, 1, -7, 5], + [ 0, 0, -2, -3, -3, 8, -5], + [-1, 5, 0, -1, -2, 1, 0]]) + mat, pivot_offsets = a.rref() + assert mat == Matrix([ + [1, -5, 0, 0, 1, 1, -1], + [0, 0, 1, 0, 0, -1, 1], + [0, 0, 0, 1, 1, -2, 1], + [0, 0, 0, 0, 0, 0, 0]]) + assert pivot_offsets == (0, 2, 3) + + a = Matrix([[Rational(1, 19), Rational(1, 5), 2, 3], + [ 4, 5, 6, 7], + [ 8, 9, 10, 11], + [ 12, 13, 14, 15]]) + assert a.rref(pivots=False) == Matrix([ + [1, 0, 0, Rational(-76, 157)], + [0, 1, 0, Rational(-5, 157)], + [0, 0, 1, Rational(238, 157)], + [0, 0, 0, 0]]) + + x = Symbol('x') + a = Matrix(2, 3, [x, 1, 1, sqrt(x), x, 1]) + for i, j in zip(a.rref(pivots=False), + [1, 0, sqrt(x)*(-x + 1)/(-x**Rational(5, 2) + x), + 0, 1, 1/(sqrt(x) + x + 1)]): + assert simplify(i - j).is_zero + + +def test_rref_rhs(): + a, b, c, d = symbols('a b c d') + A = Matrix([[0, 0], [0, 0], [1, 2], [3, 4]]) + B = Matrix([a, b, c, d]) + assert A.rref_rhs(B) == (Matrix([ + [1, 0], + [0, 1], + [0, 0], + [0, 0]]), Matrix([ + [ -2*c + d], + [3*c/2 - d/2], + [ a], + [ b]])) + + +def test_issue_17827(): + C = Matrix([ + [3, 4, -1, 1], + [9, 12, -3, 3], + [0, 2, 1, 3], + [2, 3, 0, -2], + [0, 3, 3, -5], + [8, 15, 0, 6] + ]) + # Tests for row/col within valid range + D = C.elementary_row_op('n<->m', row1=2, row2=5) + E = C.elementary_row_op('n->n+km', row1=5, row2=3, k=-4) + F = C.elementary_row_op('n->kn', row=5, k=2) + assert(D[5, :] == Matrix([[0, 2, 1, 3]])) + assert(E[5, :] == Matrix([[0, 3, 0, 14]])) + assert(F[5, :] == Matrix([[16, 30, 0, 12]])) + # Tests for row/col out of range + raises(ValueError, lambda: C.elementary_row_op('n<->m', row1=2, row2=6)) + raises(ValueError, lambda: C.elementary_row_op('n->kn', row=7, k=2)) + raises(ValueError, lambda: C.elementary_row_op('n->n+km', row1=-1, row2=5, k=2)) + +def test_rank(): + m = Matrix([[1, 2], [x, 1 - 1/x]]) + assert m.rank() == 2 + n = Matrix(3, 3, range(1, 10)) + assert n.rank() == 2 + p = zeros(3) + assert p.rank() == 0 + +def test_issue_11434(): + ax, ay, bx, by, cx, cy, dx, dy, ex, ey, t0, t1 = \ + symbols('a_x a_y b_x b_y c_x c_y d_x d_y e_x e_y t_0 t_1') + M = Matrix([[ax, ay, ax*t0, ay*t0, 0], + [bx, by, bx*t0, by*t0, 0], + [cx, cy, cx*t0, cy*t0, 1], + [dx, dy, dx*t0, dy*t0, 1], + [ex, ey, 2*ex*t1 - ex*t0, 2*ey*t1 - ey*t0, 0]]) + assert M.rank() == 4 + +def test_rank_regression_from_so(): + # see: + # https://stackoverflow.com/questions/19072700/why-does-sympy-give-me-the-wrong-answer-when-i-row-reduce-a-symbolic-matrix + + nu, lamb = symbols('nu, lambda') + A = Matrix([[-3*nu, 1, 0, 0], + [ 3*nu, -2*nu - 1, 2, 0], + [ 0, 2*nu, (-1*nu) - lamb - 2, 3], + [ 0, 0, nu + lamb, -3]]) + expected_reduced = Matrix([[1, 0, 0, 1/(nu**2*(-lamb - nu))], + [0, 1, 0, 3/(nu*(-lamb - nu))], + [0, 0, 1, 3/(-lamb - nu)], + [0, 0, 0, 0]]) + expected_pivots = (0, 1, 2) + + reduced, pivots = A.rref() + + assert simplify(expected_reduced - reduced) == zeros(*A.shape) + assert pivots == expected_pivots + +def test_issue_15872(): + A = Matrix([[1, 1, 1, 0], [-2, -1, 0, -1], [0, 0, -1, -1], [0, 0, 2, 1]]) + B = A - Matrix.eye(4) * I + assert B.rank() == 3 + assert (B**2).rank() == 2 + assert (B**3).rank() == 2 diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_repmatrix.py b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_repmatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..ee36de004705f29eaa49ea8e06fd65a8a2baa718 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_repmatrix.py @@ -0,0 +1,62 @@ +from sympy.testing.pytest import raises +from sympy.matrices.exceptions import NonSquareMatrixError, NonInvertibleMatrixError + +from sympy import Matrix, Rational + + +def test_lll(): + A = Matrix([[1, 0, 0, 0, -20160], + [0, 1, 0, 0, 33768], + [0, 0, 1, 0, 39578], + [0, 0, 0, 1, 47757]]) + L = Matrix([[ 10, -3, -2, 8, -4], + [ 3, -9, 8, 1, -11], + [ -3, 13, -9, -3, -9], + [-12, -7, -11, 9, -1]]) + T = Matrix([[ 10, -3, -2, 8], + [ 3, -9, 8, 1], + [ -3, 13, -9, -3], + [-12, -7, -11, 9]]) + assert A.lll() == L + assert A.lll_transform() == (L, T) + assert T * A == L + + +def test_matrix_inv_mod(): + A = Matrix(2, 1, [1, 0]) + raises(NonSquareMatrixError, lambda: A.inv_mod(2)) + A = Matrix(2, 2, [1, 0, 0, 0]) + raises(NonInvertibleMatrixError, lambda: A.inv_mod(2)) + A = Matrix(2, 2, [1, 2, 3, 4]) + Ai = Matrix(2, 2, [1, 1, 0, 1]) + assert A.inv_mod(3) == Ai + A = Matrix(2, 2, [1, 0, 0, 1]) + assert A.inv_mod(2) == A + A = Matrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9]) + raises(NonInvertibleMatrixError, lambda: A.inv_mod(5)) + A = Matrix(3, 3, [5, 1, 3, 2, 6, 0, 2, 1, 1]) + Ai = Matrix(3, 3, [6, 8, 0, 1, 5, 6, 5, 6, 4]) + assert A.inv_mod(9) == Ai + A = Matrix(3, 3, [1, 6, -3, 4, 1, -5, 3, -5, 5]) + Ai = Matrix(3, 3, [4, 3, 3, 1, 2, 5, 1, 5, 1]) + assert A.inv_mod(6) == Ai + A = Matrix(3, 3, [1, 6, 1, 4, 1, 5, 3, 2, 5]) + Ai = Matrix(3, 3, [6, 0, 3, 6, 6, 4, 1, 6, 1]) + assert A.inv_mod(7) == Ai + A = Matrix([[1, 2], [3, Rational(3,4)]]) + raises(ValueError, lambda: A.inv_mod(2)) + A = Matrix([[1, 2], [3, 4]]) + raises(TypeError, lambda: A.inv_mod(Rational(1, 2))) + # https://github.com/sympy/sympy/issues/27663 + M = Matrix([ + [2, 3, 1, 4], + [1, 5, 3, 2], + [3, 2, 4, 1], + [4, 1, 2, 5], + ]) + assert M.inv_mod(26) == Matrix([ + [7, 21, 10, 10], + [1, 7, 19, 3], + [14, 1, 15, 1], + [25, 23, 3, 12], + ]) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_solvers.py b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_solvers.py new file mode 100644 index 0000000000000000000000000000000000000000..c1347062c0482336affbeb4bb9a95aedfcc0ae53 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_solvers.py @@ -0,0 +1,615 @@ +import pytest +from sympy.core.function import expand_mul +from sympy.core.numbers import (I, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.simplify.simplify import simplify +from sympy.matrices.exceptions import (ShapeError, NonSquareMatrixError) +from sympy.matrices import ( + ImmutableMatrix, Matrix, eye, ones, ImmutableDenseMatrix, dotprodsimp) +from sympy.matrices.determinant import _det_laplace +from sympy.testing.pytest import raises +from sympy.matrices.exceptions import NonInvertibleMatrixError +from sympy.polys.matrices.exceptions import DMShapeError +from sympy.solvers.solveset import linsolve +from sympy.abc import x, y + +def test_issue_17247_expression_blowup_29(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.gauss_jordan_solve(ones(4, 1)) == (Matrix(S('''[ + [ -32549314808672/3306971225785 - 17397006745216*I/3306971225785], + [ 67439348256/3306971225785 - 9167503335872*I/3306971225785], + [-15091965363354518272/21217636514687010905 + 16890163109293858304*I/21217636514687010905], + [ -11328/952745 + 87616*I/952745]]''')), Matrix(0, 1, [])) + +def test_issue_17247_expression_blowup_30(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.cholesky_solve(ones(4, 1)) == Matrix(S('''[ + [ -32549314808672/3306971225785 - 17397006745216*I/3306971225785], + [ 67439348256/3306971225785 - 9167503335872*I/3306971225785], + [-15091965363354518272/21217636514687010905 + 16890163109293858304*I/21217636514687010905], + [ -11328/952745 + 87616*I/952745]]''')) + +# @XFAIL # This calculation hangs with dotprodsimp. +# def test_issue_17247_expression_blowup_31(): +# M = Matrix([ +# [x + 1, 1 - x, 0, 0], +# [1 - x, x + 1, 0, x + 1], +# [ 0, 1 - x, x + 1, 0], +# [ 0, 0, 0, x + 1]]) +# with dotprodsimp(True): +# assert M.LDLsolve(ones(4, 1)) == Matrix([ +# [(x + 1)/(4*x)], +# [(x - 1)/(4*x)], +# [(x + 1)/(4*x)], +# [ 1/(x + 1)]]) + + +def test_LUsolve_iszerofunc(): + # taken from https://github.com/sympy/sympy/issues/24679 + + M = Matrix([[(x + 1)**2 - (x**2 + 2*x + 1), x], [x, 0]]) + b = Matrix([1, 1]) + is_zero_func = lambda e: False if e._random() else True + + x_exp = Matrix([1/x, (1-(-x**2 - 2*x + (x+1)**2 - 1)/x)/x]) + + assert (x_exp - M.LUsolve(b, iszerofunc=is_zero_func)) == Matrix([0, 0]) + + +def test_issue_17247_expression_blowup_32(): + M = Matrix([ + [x + 1, 1 - x, 0, 0], + [1 - x, x + 1, 0, x + 1], + [ 0, 1 - x, x + 1, 0], + [ 0, 0, 0, x + 1]]) + with dotprodsimp(True): + assert M.LUsolve(ones(4, 1)) == Matrix([ + [(x + 1)/(4*x)], + [(x - 1)/(4*x)], + [(x + 1)/(4*x)], + [ 1/(x + 1)]]) + +def test_LUsolve(): + A = Matrix([[2, 3, 5], + [3, 6, 2], + [8, 3, 6]]) + x = Matrix(3, 1, [3, 7, 5]) + b = A*x + soln = A.LUsolve(b) + assert soln == x + A = Matrix([[0, -1, 2], + [5, 10, 7], + [8, 3, 4]]) + x = Matrix(3, 1, [-1, 2, 5]) + b = A*x + soln = A.LUsolve(b) + assert soln == x + A = Matrix([[2, 1], [1, 0], [1, 0]]) # issue 14548 + b = Matrix([3, 1, 1]) + assert A.LUsolve(b) == Matrix([1, 1]) + b = Matrix([3, 1, 2]) # inconsistent + raises(ValueError, lambda: A.LUsolve(b)) + A = Matrix([[0, -1, 2], + [5, 10, 7], + [8, 3, 4], + [2, 3, 5], + [3, 6, 2], + [8, 3, 6]]) + x = Matrix([2, 1, -4]) + b = A*x + soln = A.LUsolve(b) + assert soln == x + A = Matrix([[0, -1, 2], [5, 10, 7]]) # underdetermined + x = Matrix([-1, 2, 0]) + b = A*x + raises(NotImplementedError, lambda: A.LUsolve(b)) + + A = Matrix(4, 4, lambda i, j: 1/(i+j+1) if i != 3 else 0) + b = Matrix.zeros(4, 1) + raises(NonInvertibleMatrixError, lambda: A.LUsolve(b)) + + +def test_LUsolve_noncommutative(): + a0, a1, a2, a3 = symbols("a:4", commutative=False) + b0, b1 = symbols("b:2", commutative=False) + A = Matrix([[a0, a1], [a2, a3]]) + check = A * A.LUsolve(Matrix([b0, b1])) + assert check[0, 0].expand() == b0 + # Because sympy simplification is very limited with noncommutative expressions, + # perform an explicit check with the second element + assert check[1, 0] == ( + a2*a0**(-1)*(-a1*(-a2*a0**(-1)*a1 + a3)**(-1)*(-a2*a0**(-1)*b0 + b1) + b0) + + a3*(-a2*a0**(-1)*a1 + a3)**(-1)*(-a2*a0**(-1)*b0 + b1) + ) + + +def test_QRsolve(): + A = Matrix([[2, 3, 5], + [3, 6, 2], + [8, 3, 6]]) + x = Matrix(3, 1, [3, 7, 5]) + b = A*x + soln = A.QRsolve(b) + assert soln == x + x = Matrix([[1, 2], [3, 4], [5, 6]]) + b = A*x + soln = A.QRsolve(b) + assert soln == x + + A = Matrix([[0, -1, 2], + [5, 10, 7], + [8, 3, 4]]) + x = Matrix(3, 1, [-1, 2, 5]) + b = A*x + soln = A.QRsolve(b) + assert soln == x + x = Matrix([[7, 8], [9, 10], [11, 12]]) + b = A*x + soln = A.QRsolve(b) + assert soln == x + +def test_errors(): + raises(ShapeError, lambda: Matrix([1]).LUsolve(Matrix([[1, 2], [3, 4]]))) + +def test_cholesky_solve(): + A = Matrix([[2, 3, 5], + [3, 6, 2], + [8, 3, 6]]) + x = Matrix(3, 1, [3, 7, 5]) + b = A*x + soln = A.cholesky_solve(b) + assert soln == x + A = Matrix([[0, -1, 2], + [5, 10, 7], + [8, 3, 4]]) + x = Matrix(3, 1, [-1, 2, 5]) + b = A*x + soln = A.cholesky_solve(b) + assert soln == x + A = Matrix(((1, 5), (5, 1))) + x = Matrix((4, -3)) + b = A*x + soln = A.cholesky_solve(b) + assert soln == x + A = Matrix(((9, 3*I), (-3*I, 5))) + x = Matrix((-2, 1)) + b = A*x + soln = A.cholesky_solve(b) + assert expand_mul(soln) == x + A = Matrix(((9*I, 3), (-3 + I, 5))) + x = Matrix((2 + 3*I, -1)) + b = A*x + soln = A.cholesky_solve(b) + assert expand_mul(soln) == x + a00, a01, a11, b0, b1 = symbols('a00, a01, a11, b0, b1') + A = Matrix(((a00, a01), (a01, a11))) + b = Matrix((b0, b1)) + x = A.cholesky_solve(b) + assert simplify(A*x) == b + + +def test_LDLsolve(): + A = Matrix([[2, 3, 5], + [3, 6, 2], + [8, 3, 6]]) + x = Matrix(3, 1, [3, 7, 5]) + b = A*x + soln = A.LDLsolve(b) + assert soln == x + + A = Matrix([[0, -1, 2], + [5, 10, 7], + [8, 3, 4]]) + x = Matrix(3, 1, [-1, 2, 5]) + b = A*x + soln = A.LDLsolve(b) + assert soln == x + + A = Matrix(((9, 3*I), (-3*I, 5))) + x = Matrix((-2, 1)) + b = A*x + soln = A.LDLsolve(b) + assert expand_mul(soln) == x + + A = Matrix(((9*I, 3), (-3 + I, 5))) + x = Matrix((2 + 3*I, -1)) + b = A*x + soln = A.LDLsolve(b) + assert expand_mul(soln) == x + + A = Matrix(((9, 3), (3, 9))) + x = Matrix((1, 1)) + b = A * x + soln = A.LDLsolve(b) + assert expand_mul(soln) == x + + A = Matrix([[-5, -3, -4], [-3, -7, 7]]) + x = Matrix([[8], [7], [-2]]) + b = A * x + raises(NotImplementedError, lambda: A.LDLsolve(b)) + + +def test_lower_triangular_solve(): + + raises(NonSquareMatrixError, + lambda: Matrix([1, 0]).lower_triangular_solve(Matrix([0, 1]))) + raises(ShapeError, + lambda: Matrix([[1, 0], [0, 1]]).lower_triangular_solve(Matrix([1]))) + raises(ValueError, + lambda: Matrix([[2, 1], [1, 2]]).lower_triangular_solve( + Matrix([[1, 0], [0, 1]]))) + + A = Matrix([[1, 0], [0, 1]]) + B = Matrix([[x, y], [y, x]]) + C = Matrix([[4, 8], [2, 9]]) + + assert A.lower_triangular_solve(B) == B + assert A.lower_triangular_solve(C) == C + + +def test_upper_triangular_solve(): + + raises(NonSquareMatrixError, + lambda: Matrix([1, 0]).upper_triangular_solve(Matrix([0, 1]))) + raises(ShapeError, + lambda: Matrix([[1, 0], [0, 1]]).upper_triangular_solve(Matrix([1]))) + raises(TypeError, + lambda: Matrix([[2, 1], [1, 2]]).upper_triangular_solve( + Matrix([[1, 0], [0, 1]]))) + + A = Matrix([[1, 0], [0, 1]]) + B = Matrix([[x, y], [y, x]]) + C = Matrix([[2, 4], [3, 8]]) + + assert A.upper_triangular_solve(B) == B + assert A.upper_triangular_solve(C) == C + + +def test_diagonal_solve(): + raises(TypeError, lambda: Matrix([1, 1]).diagonal_solve(Matrix([1]))) + A = Matrix([[1, 0], [0, 1]])*2 + B = Matrix([[x, y], [y, x]]) + assert A.diagonal_solve(B) == B/2 + + A = Matrix([[1, 0], [1, 2]]) + raises(TypeError, lambda: A.diagonal_solve(B)) + +def test_pinv_solve(): + # Fully determined system (unique result, identical to other solvers). + A = Matrix([[1, 5], [7, 9]]) + B = Matrix([12, 13]) + assert A.pinv_solve(B) == A.cholesky_solve(B) + assert A.pinv_solve(B) == A.LDLsolve(B) + assert A.pinv_solve(B) == Matrix([sympify('-43/26'), sympify('71/26')]) + assert A * A.pinv() * B == B + # Fully determined, with two-dimensional B matrix. + B = Matrix([[12, 13, 14], [15, 16, 17]]) + assert A.pinv_solve(B) == A.cholesky_solve(B) + assert A.pinv_solve(B) == A.LDLsolve(B) + assert A.pinv_solve(B) == Matrix([[-33, -37, -41], [69, 75, 81]]) / 26 + assert A * A.pinv() * B == B + # Underdetermined system (infinite results). + A = Matrix([[1, 0, 1], [0, 1, 1]]) + B = Matrix([5, 7]) + solution = A.pinv_solve(B) + w = {} + for s in solution.atoms(Symbol): + # Extract dummy symbols used in the solution. + w[s.name] = s + assert solution == Matrix([[w['w0_0']/3 + w['w1_0']/3 - w['w2_0']/3 + 1], + [w['w0_0']/3 + w['w1_0']/3 - w['w2_0']/3 + 3], + [-w['w0_0']/3 - w['w1_0']/3 + w['w2_0']/3 + 4]]) + assert A * A.pinv() * B == B + # Overdetermined system (least squares results). + A = Matrix([[1, 0], [0, 0], [0, 1]]) + B = Matrix([3, 2, 1]) + assert A.pinv_solve(B) == Matrix([3, 1]) + # Proof the solution is not exact. + assert A * A.pinv() * B != B + +def test_pinv_rank_deficient(): + # Test the four properties of the pseudoinverse for various matrices. + As = [Matrix([[1, 1, 1], [2, 2, 2]]), + Matrix([[1, 0], [0, 0]]), + Matrix([[1, 2], [2, 4], [3, 6]])] + + for A in As: + A_pinv = A.pinv(method="RD") + AAp = A * A_pinv + ApA = A_pinv * A + assert simplify(AAp * A) == A + assert simplify(ApA * A_pinv) == A_pinv + assert AAp.H == AAp + assert ApA.H == ApA + + for A in As: + A_pinv = A.pinv(method="ED") + AAp = A * A_pinv + ApA = A_pinv * A + assert simplify(AAp * A) == A + assert simplify(ApA * A_pinv) == A_pinv + assert AAp.H == AAp + assert ApA.H == ApA + + # Test solving with rank-deficient matrices. + A = Matrix([[1, 0], [0, 0]]) + # Exact, non-unique solution. + B = Matrix([3, 0]) + solution = A.pinv_solve(B) + w1 = solution.atoms(Symbol).pop() + assert w1.name == 'w1_0' + assert solution == Matrix([3, w1]) + assert A * A.pinv() * B == B + # Least squares, non-unique solution. + B = Matrix([3, 1]) + solution = A.pinv_solve(B) + w1 = solution.atoms(Symbol).pop() + assert w1.name == 'w1_0' + assert solution == Matrix([3, w1]) + assert A * A.pinv() * B != B + +def test_gauss_jordan_solve(): + + # Square, full rank, unique solution + A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 10]]) + b = Matrix([3, 6, 9]) + sol, params = A.gauss_jordan_solve(b) + assert sol == Matrix([[-1], [2], [0]]) + assert params == Matrix(0, 1, []) + + # Square, full rank, unique solution, B has more columns than rows + A = eye(3) + B = Matrix([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]) + sol, params = A.gauss_jordan_solve(B) + assert sol == B + assert params == Matrix(0, 4, []) + + # Square, reduced rank, parametrized solution + A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + b = Matrix([3, 6, 9]) + sol, params, freevar = A.gauss_jordan_solve(b, freevar=True) + w = {} + for s in sol.atoms(Symbol): + # Extract dummy symbols used in the solution. + w[s.name] = s + assert sol == Matrix([[w['tau0'] - 1], [-2*w['tau0'] + 2], [w['tau0']]]) + assert params == Matrix([[w['tau0']]]) + assert freevar == [2] + + # Square, reduced rank, parametrized solution, B has two columns + A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + B = Matrix([[3, 4], [6, 8], [9, 12]]) + sol, params, freevar = A.gauss_jordan_solve(B, freevar=True) + w = {} + for s in sol.atoms(Symbol): + # Extract dummy symbols used in the solution. + w[s.name] = s + assert sol == Matrix([[w['tau0'] - 1, w['tau1'] - Rational(4, 3)], + [-2*w['tau0'] + 2, -2*w['tau1'] + Rational(8, 3)], + [w['tau0'], w['tau1']],]) + assert params == Matrix([[w['tau0'], w['tau1']]]) + assert freevar == [2] + + # Square, reduced rank, parametrized solution + A = Matrix([[1, 2, 3], [2, 4, 6], [3, 6, 9]]) + b = Matrix([0, 0, 0]) + sol, params = A.gauss_jordan_solve(b) + w = {} + for s in sol.atoms(Symbol): + w[s.name] = s + assert sol == Matrix([[-2*w['tau0'] - 3*w['tau1']], + [w['tau0']], [w['tau1']]]) + assert params == Matrix([[w['tau0']], [w['tau1']]]) + + # Square, reduced rank, parametrized solution + A = Matrix([[0, 0, 0], [0, 0, 0], [0, 0, 0]]) + b = Matrix([0, 0, 0]) + sol, params = A.gauss_jordan_solve(b) + w = {} + for s in sol.atoms(Symbol): + w[s.name] = s + assert sol == Matrix([[w['tau0']], [w['tau1']], [w['tau2']]]) + assert params == Matrix([[w['tau0']], [w['tau1']], [w['tau2']]]) + + # Square, reduced rank, no solution + A = Matrix([[1, 2, 3], [2, 4, 6], [3, 6, 9]]) + b = Matrix([0, 0, 1]) + raises(ValueError, lambda: A.gauss_jordan_solve(b)) + + # Rectangular, tall, full rank, unique solution + A = Matrix([[1, 5, 3], [2, 1, 6], [1, 7, 9], [1, 4, 3]]) + b = Matrix([0, 0, 1, 0]) + sol, params = A.gauss_jordan_solve(b) + assert sol == Matrix([[Rational(-1, 2)], [0], [Rational(1, 6)]]) + assert params == Matrix(0, 1, []) + + # Rectangular, tall, full rank, unique solution, B has less columns than rows + A = Matrix([[1, 5, 3], [2, 1, 6], [1, 7, 9], [1, 4, 3]]) + B = Matrix([[0,0], [0, 0], [1, 2], [0, 0]]) + sol, params = A.gauss_jordan_solve(B) + assert sol == Matrix([[Rational(-1, 2), Rational(-2, 2)], [0, 0], [Rational(1, 6), Rational(2, 6)]]) + assert params == Matrix(0, 2, []) + + # Rectangular, tall, full rank, no solution + A = Matrix([[1, 5, 3], [2, 1, 6], [1, 7, 9], [1, 4, 3]]) + b = Matrix([0, 0, 0, 1]) + raises(ValueError, lambda: A.gauss_jordan_solve(b)) + + # Rectangular, tall, full rank, no solution, B has two columns (2nd has no solution) + A = Matrix([[1, 5, 3], [2, 1, 6], [1, 7, 9], [1, 4, 3]]) + B = Matrix([[0,0], [0, 0], [1, 0], [0, 1]]) + raises(ValueError, lambda: A.gauss_jordan_solve(B)) + + # Rectangular, tall, full rank, no solution, B has two columns (1st has no solution) + A = Matrix([[1, 5, 3], [2, 1, 6], [1, 7, 9], [1, 4, 3]]) + B = Matrix([[0,0], [0, 0], [0, 1], [1, 0]]) + raises(ValueError, lambda: A.gauss_jordan_solve(B)) + + # Rectangular, tall, reduced rank, parametrized solution + A = Matrix([[1, 5, 3], [2, 10, 6], [3, 15, 9], [1, 4, 3]]) + b = Matrix([0, 0, 0, 1]) + sol, params = A.gauss_jordan_solve(b) + w = {} + for s in sol.atoms(Symbol): + w[s.name] = s + assert sol == Matrix([[-3*w['tau0'] + 5], [-1], [w['tau0']]]) + assert params == Matrix([[w['tau0']]]) + + # Rectangular, tall, reduced rank, no solution + A = Matrix([[1, 5, 3], [2, 10, 6], [3, 15, 9], [1, 4, 3]]) + b = Matrix([0, 0, 1, 1]) + raises(ValueError, lambda: A.gauss_jordan_solve(b)) + + # Rectangular, wide, full rank, parametrized solution + A = Matrix([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 1, 12]]) + b = Matrix([1, 1, 1]) + sol, params = A.gauss_jordan_solve(b) + w = {} + for s in sol.atoms(Symbol): + w[s.name] = s + assert sol == Matrix([[2*w['tau0'] - 1], [-3*w['tau0'] + 1], [0], + [w['tau0']]]) + assert params == Matrix([[w['tau0']]]) + + # Rectangular, wide, reduced rank, parametrized solution + A = Matrix([[1, 2, 3, 4], [5, 6, 7, 8], [2, 4, 6, 8]]) + b = Matrix([0, 1, 0]) + sol, params = A.gauss_jordan_solve(b) + w = {} + for s in sol.atoms(Symbol): + w[s.name] = s + assert sol == Matrix([[w['tau0'] + 2*w['tau1'] + S.Half], + [-2*w['tau0'] - 3*w['tau1'] - Rational(1, 4)], + [w['tau0']], [w['tau1']]]) + assert params == Matrix([[w['tau0']], [w['tau1']]]) + # watch out for clashing symbols + x0, x1, x2, _x0 = symbols('_tau0 _tau1 _tau2 tau1') + M = Matrix([[0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, _x0]]) + A = M[:, :-1] + b = M[:, -1:] + sol, params = A.gauss_jordan_solve(b) + assert params == Matrix(3, 1, [x0, x1, x2]) + assert sol == Matrix(5, 1, [x0, 0, x1, _x0, x2]) + + # Rectangular, wide, reduced rank, no solution + A = Matrix([[1, 2, 3, 4], [5, 6, 7, 8], [2, 4, 6, 8]]) + b = Matrix([1, 1, 1]) + raises(ValueError, lambda: A.gauss_jordan_solve(b)) + + # Test for immutable matrix + A = ImmutableMatrix([[1, 0], [0, 1]]) + B = ImmutableMatrix([1, 2]) + sol, params = A.gauss_jordan_solve(B) + assert sol == ImmutableMatrix([1, 2]) + assert params == ImmutableMatrix(0, 1, []) + assert sol.__class__ == ImmutableDenseMatrix + assert params.__class__ == ImmutableDenseMatrix + + # Test placement of free variables + A = Matrix([[1, 0, 0, 0], [0, 0, 0, 1]]) + b = Matrix([1, 1]) + sol, params = A.gauss_jordan_solve(b) + w = {} + for s in sol.atoms(Symbol): + w[s.name] = s + assert sol == Matrix([[1], [w['tau0']], [w['tau1']], [1]]) + assert params == Matrix([[w['tau0']], [w['tau1']]]) + + +def test_linsolve_underdetermined_AND_gauss_jordan_solve(): + #Test placement of free variables as per issue 19815 + A = Matrix([[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]]) + B = Matrix([1, 2, 1, 1, 1, 1, 1, 2]) + sol, params = A.gauss_jordan_solve(B) + w = {} + for s in sol.atoms(Symbol): + w[s.name] = s + assert params == Matrix([[w['tau0']], [w['tau1']], [w['tau2']], + [w['tau3']], [w['tau4']], [w['tau5']]]) + assert sol == Matrix([[1 - 1*w['tau2']], + [w['tau2']], + [1 - 1*w['tau0'] + w['tau1']], + [w['tau0']], + [w['tau3'] + w['tau4']], + [-1*w['tau3'] - 1*w['tau4'] - 1*w['tau1']], + [1 - 1*w['tau2']], + [w['tau1']], + [w['tau2']], + [w['tau3']], + [w['tau4']], + [1 - 1*w['tau5']], + [w['tau5']], + [1]]) + + from sympy.abc import j,f + # https://github.com/sympy/sympy/issues/20046 + A = Matrix([ + [1, 1, 1, 1, 1, 1, 1, 1, 1], + [0, -1, 0, -1, 0, -1, 0, -1, -j], + [0, 0, 0, 0, 1, 1, 1, 1, f] + ]) + + sol_1=Matrix(list(linsolve(A))[0]) + + tau0, tau1, tau2, tau3, tau4 = symbols('tau:5') + + assert sol_1 == Matrix([[-f - j - tau0 + tau2 + tau4 + 1], + [j - tau1 - tau2 - tau4], + [tau0], + [tau1], + [f - tau2 - tau3 - tau4], + [tau2], + [tau3], + [tau4]]) + + # https://github.com/sympy/sympy/issues/19815 + sol_2 = A[:, : -1 ] * sol_1 - A[:, -1 ] + assert sol_2 == Matrix([[0], [0], [0]]) + + +@pytest.mark.parametrize("det_method", ["bird", "laplace"]) +@pytest.mark.parametrize("M, rhs", [ + (Matrix([[2, 3, 5], [3, 6, 2], [8, 3, 6]]), Matrix(3, 1, [3, 7, 5])), + (Matrix([[2, 3, 5], [3, 6, 2], [8, 3, 6]]), + Matrix([[1, 2], [3, 4], [5, 6]])), + (Matrix(2, 2, symbols("a:4")), Matrix(2, 1, symbols("b:2"))), +]) +def test_cramer_solve(det_method, M, rhs): + assert simplify(M.cramer_solve(rhs, det_method=det_method) - M.LUsolve(rhs) + ) == Matrix.zeros(M.rows, rhs.cols) + + +@pytest.mark.parametrize("det_method, error", [ + ("bird", DMShapeError), (_det_laplace, NonSquareMatrixError)]) +def test_cramer_solve_errors(det_method, error): + # Non-square matrix + A = Matrix([[0, -1, 2], [5, 10, 7]]) + b = Matrix([-2, 15]) + raises(error, lambda: A.cramer_solve(b, det_method=det_method)) + + +def test_solve(): + A = Matrix([[1,2], [2,4]]) + b = Matrix([[3], [4]]) + raises(ValueError, lambda: A.solve(b)) #no solution + b = Matrix([[ 4], [8]]) + raises(ValueError, lambda: A.solve(b)) #infinite solution diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_sparse.py b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..4d257c8062f220cc06bc0dabdc7ac40ce9dc4adc --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_sparse.py @@ -0,0 +1,745 @@ +from sympy.core.numbers import (Float, I, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import Abs +from sympy.polys.polytools import PurePoly +from sympy.matrices import \ + Matrix, MutableSparseMatrix, ImmutableSparseMatrix, SparseMatrix, eye, \ + ones, zeros, ShapeError, NonSquareMatrixError +from sympy.testing.pytest import raises + + +def test_sparse_creation(): + a = SparseMatrix(2, 2, {(0, 0): [[1, 2], [3, 4]]}) + assert a == SparseMatrix([[1, 2], [3, 4]]) + a = SparseMatrix(2, 2, {(0, 0): [[1, 2]]}) + assert a == SparseMatrix([[1, 2], [0, 0]]) + a = SparseMatrix(2, 2, {(0, 0): [1, 2]}) + assert a == SparseMatrix([[1, 0], [2, 0]]) + + +def test_sparse_matrix(): + def sparse_eye(n): + return SparseMatrix.eye(n) + + def sparse_zeros(n): + return SparseMatrix.zeros(n) + + # creation args + raises(TypeError, lambda: SparseMatrix(1, 2)) + + a = SparseMatrix(( + (1, 0), + (0, 1) + )) + assert SparseMatrix(a) == a + + from sympy.matrices import MutableDenseMatrix + a = MutableSparseMatrix([]) + b = MutableDenseMatrix([1, 2]) + assert a.row_join(b) == b + assert a.col_join(b) == b + assert type(a.row_join(b)) == type(a) + assert type(a.col_join(b)) == type(a) + + # make sure 0 x n matrices get stacked correctly + sparse_matrices = [SparseMatrix.zeros(0, n) for n in range(4)] + assert SparseMatrix.hstack(*sparse_matrices) == Matrix(0, 6, []) + sparse_matrices = [SparseMatrix.zeros(n, 0) for n in range(4)] + assert SparseMatrix.vstack(*sparse_matrices) == Matrix(6, 0, []) + + # test element assignment + a = SparseMatrix(( + (1, 0), + (0, 1) + )) + + a[3] = 4 + assert a[1, 1] == 4 + a[3] = 1 + + a[0, 0] = 2 + assert a == SparseMatrix(( + (2, 0), + (0, 1) + )) + a[1, 0] = 5 + assert a == SparseMatrix(( + (2, 0), + (5, 1) + )) + a[1, 1] = 0 + assert a == SparseMatrix(( + (2, 0), + (5, 0) + )) + assert a.todok() == {(0, 0): 2, (1, 0): 5} + + # test_multiplication + a = SparseMatrix(( + (1, 2), + (3, 1), + (0, 6), + )) + + b = SparseMatrix(( + (1, 2), + (3, 0), + )) + + c = a*b + assert c[0, 0] == 7 + assert c[0, 1] == 2 + assert c[1, 0] == 6 + assert c[1, 1] == 6 + assert c[2, 0] == 18 + assert c[2, 1] == 0 + + try: + eval('c = a @ b') + except SyntaxError: + pass + else: + assert c[0, 0] == 7 + assert c[0, 1] == 2 + assert c[1, 0] == 6 + assert c[1, 1] == 6 + assert c[2, 0] == 18 + assert c[2, 1] == 0 + + x = Symbol("x") + + c = b * Symbol("x") + assert isinstance(c, SparseMatrix) + assert c[0, 0] == x + assert c[0, 1] == 2*x + assert c[1, 0] == 3*x + assert c[1, 1] == 0 + + c = 5 * b + assert isinstance(c, SparseMatrix) + assert c[0, 0] == 5 + assert c[0, 1] == 2*5 + assert c[1, 0] == 3*5 + assert c[1, 1] == 0 + + #test_power + A = SparseMatrix([[2, 3], [4, 5]]) + assert (A**5)[:] == [6140, 8097, 10796, 14237] + A = SparseMatrix([[2, 1, 3], [4, 2, 4], [6, 12, 1]]) + assert (A**3)[:] == [290, 262, 251, 448, 440, 368, 702, 954, 433] + + # test_creation + x = Symbol("x") + a = SparseMatrix([[x, 0], [0, 0]]) + m = a + assert m.cols == m.rows + assert m.cols == 2 + assert m[:] == [x, 0, 0, 0] + b = SparseMatrix(2, 2, [x, 0, 0, 0]) + m = b + assert m.cols == m.rows + assert m.cols == 2 + assert m[:] == [x, 0, 0, 0] + + assert a == b + S = sparse_eye(3) + S.row_del(1) + assert S == SparseMatrix([ + [1, 0, 0], + [0, 0, 1]]) + S = sparse_eye(3) + S.col_del(1) + assert S == SparseMatrix([ + [1, 0], + [0, 0], + [0, 1]]) + S = SparseMatrix.eye(3) + S[2, 1] = 2 + S.col_swap(1, 0) + assert S == SparseMatrix([ + [0, 1, 0], + [1, 0, 0], + [2, 0, 1]]) + S.row_swap(0, 1) + assert S == SparseMatrix([ + [1, 0, 0], + [0, 1, 0], + [2, 0, 1]]) + + a = SparseMatrix(1, 2, [1, 2]) + b = a.copy() + c = a.copy() + assert a[0] == 1 + a.row_del(0) + assert a == SparseMatrix(0, 2, []) + b.col_del(1) + assert b == SparseMatrix(1, 1, [1]) + + assert SparseMatrix([[1, 2, 3], [1, 2], [1]]) == Matrix([ + [1, 2, 3], + [1, 2, 0], + [1, 0, 0]]) + assert SparseMatrix(4, 4, {(1, 1): sparse_eye(2)}) == Matrix([ + [0, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 0]]) + raises(ValueError, lambda: SparseMatrix(1, 1, {(1, 1): 1})) + assert SparseMatrix(1, 2, [1, 2]).tolist() == [[1, 2]] + assert SparseMatrix(2, 2, [1, [2, 3]]).tolist() == [[1, 0], [2, 3]] + raises(ValueError, lambda: SparseMatrix(2, 2, [1])) + raises(ValueError, lambda: SparseMatrix(1, 1, [[1, 2]])) + assert SparseMatrix([.1]).has(Float) + # autosizing + assert SparseMatrix(None, {(0, 1): 0}).shape == (0, 0) + assert SparseMatrix(None, {(0, 1): 1}).shape == (1, 2) + assert SparseMatrix(None, None, {(0, 1): 1}).shape == (1, 2) + raises(ValueError, lambda: SparseMatrix(None, 1, [[1, 2]])) + raises(ValueError, lambda: SparseMatrix(1, None, [[1, 2]])) + raises(ValueError, lambda: SparseMatrix(3, 3, {(0, 0): ones(2), (1, 1): 2})) + + # test_determinant + x, y = Symbol('x'), Symbol('y') + + assert SparseMatrix(1, 1, [0]).det() == 0 + + assert SparseMatrix([[1]]).det() == 1 + + assert SparseMatrix(((-3, 2), (8, -5))).det() == -1 + + assert SparseMatrix(((x, 1), (y, 2*y))).det() == 2*x*y - y + + assert SparseMatrix(( (1, 1, 1), + (1, 2, 3), + (1, 3, 6) )).det() == 1 + + assert SparseMatrix(( ( 3, -2, 0, 5), + (-2, 1, -2, 2), + ( 0, -2, 5, 0), + ( 5, 0, 3, 4) )).det() == -289 + + assert SparseMatrix(( ( 1, 2, 3, 4), + ( 5, 6, 7, 8), + ( 9, 10, 11, 12), + (13, 14, 15, 16) )).det() == 0 + + assert SparseMatrix(( (3, 2, 0, 0, 0), + (0, 3, 2, 0, 0), + (0, 0, 3, 2, 0), + (0, 0, 0, 3, 2), + (2, 0, 0, 0, 3) )).det() == 275 + + assert SparseMatrix(( (1, 0, 1, 2, 12), + (2, 0, 1, 1, 4), + (2, 1, 1, -1, 3), + (3, 2, -1, 1, 8), + (1, 1, 1, 0, 6) )).det() == -55 + + assert SparseMatrix(( (-5, 2, 3, 4, 5), + ( 1, -4, 3, 4, 5), + ( 1, 2, -3, 4, 5), + ( 1, 2, 3, -2, 5), + ( 1, 2, 3, 4, -1) )).det() == 11664 + + assert SparseMatrix(( ( 3, 0, 0, 0), + (-2, 1, 0, 0), + ( 0, -2, 5, 0), + ( 5, 0, 3, 4) )).det() == 60 + + assert SparseMatrix(( ( 1, 0, 0, 0), + ( 5, 0, 0, 0), + ( 9, 10, 11, 0), + (13, 14, 15, 16) )).det() == 0 + + assert SparseMatrix(( (3, 2, 0, 0, 0), + (0, 3, 2, 0, 0), + (0, 0, 3, 2, 0), + (0, 0, 0, 3, 2), + (0, 0, 0, 0, 3) )).det() == 243 + + assert SparseMatrix(( ( 2, 7, -1, 3, 2), + ( 0, 0, 1, 0, 1), + (-2, 0, 7, 0, 2), + (-3, -2, 4, 5, 3), + ( 1, 0, 0, 0, 1) )).det() == 123 + + # test_slicing + m0 = sparse_eye(4) + assert m0[:3, :3] == sparse_eye(3) + assert m0[2:4, 0:2] == sparse_zeros(2) + + m1 = SparseMatrix(3, 3, lambda i, j: i + j) + assert m1[0, :] == SparseMatrix(1, 3, (0, 1, 2)) + assert m1[1:3, 1] == SparseMatrix(2, 1, (2, 3)) + + m2 = SparseMatrix( + [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]) + assert m2[:, -1] == SparseMatrix(4, 1, [3, 7, 11, 15]) + assert m2[-2:, :] == SparseMatrix([[8, 9, 10, 11], [12, 13, 14, 15]]) + + assert SparseMatrix([[1, 2], [3, 4]])[[1], [1]] == Matrix([[4]]) + + # test_submatrix_assignment + m = sparse_zeros(4) + m[2:4, 2:4] = sparse_eye(2) + assert m == SparseMatrix([(0, 0, 0, 0), + (0, 0, 0, 0), + (0, 0, 1, 0), + (0, 0, 0, 1)]) + assert len(m.todok()) == 2 + m[:2, :2] = sparse_eye(2) + assert m == sparse_eye(4) + m[:, 0] = SparseMatrix(4, 1, (1, 2, 3, 4)) + assert m == SparseMatrix([(1, 0, 0, 0), + (2, 1, 0, 0), + (3, 0, 1, 0), + (4, 0, 0, 1)]) + m[:, :] = sparse_zeros(4) + assert m == sparse_zeros(4) + m[:, :] = ((1, 2, 3, 4), (5, 6, 7, 8), (9, 10, 11, 12), (13, 14, 15, 16)) + assert m == SparseMatrix((( 1, 2, 3, 4), + ( 5, 6, 7, 8), + ( 9, 10, 11, 12), + (13, 14, 15, 16))) + m[:2, 0] = [0, 0] + assert m == SparseMatrix((( 0, 2, 3, 4), + ( 0, 6, 7, 8), + ( 9, 10, 11, 12), + (13, 14, 15, 16))) + + # test_reshape + m0 = sparse_eye(3) + assert m0.reshape(1, 9) == SparseMatrix(1, 9, (1, 0, 0, 0, 1, 0, 0, 0, 1)) + m1 = SparseMatrix(3, 4, lambda i, j: i + j) + assert m1.reshape(4, 3) == \ + SparseMatrix([(0, 1, 2), (3, 1, 2), (3, 4, 2), (3, 4, 5)]) + assert m1.reshape(2, 6) == \ + SparseMatrix([(0, 1, 2, 3, 1, 2), (3, 4, 2, 3, 4, 5)]) + + # test_applyfunc + m0 = sparse_eye(3) + assert m0.applyfunc(lambda x: 2*x) == sparse_eye(3)*2 + assert m0.applyfunc(lambda x: 0 ) == sparse_zeros(3) + + # test__eval_Abs + assert abs(SparseMatrix(((x, 1), (y, 2*y)))) == SparseMatrix(((Abs(x), 1), (Abs(y), 2*Abs(y)))) + + # test_LUdecomp + testmat = SparseMatrix([[ 0, 2, 5, 3], + [ 3, 3, 7, 4], + [ 8, 4, 0, 2], + [-2, 6, 3, 4]]) + L, U, p = testmat.LUdecomposition() + assert L.is_lower + assert U.is_upper + assert (L*U).permute_rows(p, 'backward') - testmat == sparse_zeros(4) + + testmat = SparseMatrix([[ 6, -2, 7, 4], + [ 0, 3, 6, 7], + [ 1, -2, 7, 4], + [-9, 2, 6, 3]]) + L, U, p = testmat.LUdecomposition() + assert L.is_lower + assert U.is_upper + assert (L*U).permute_rows(p, 'backward') - testmat == sparse_zeros(4) + + x, y, z = Symbol('x'), Symbol('y'), Symbol('z') + M = Matrix(((1, x, 1), (2, y, 0), (y, 0, z))) + L, U, p = M.LUdecomposition() + assert L.is_lower + assert U.is_upper + assert (L*U).permute_rows(p, 'backward') - M == sparse_zeros(3) + + # test_LUsolve + A = SparseMatrix([[2, 3, 5], + [3, 6, 2], + [8, 3, 6]]) + x = SparseMatrix(3, 1, [3, 7, 5]) + b = A*x + soln = A.LUsolve(b) + assert soln == x + A = SparseMatrix([[0, -1, 2], + [5, 10, 7], + [8, 3, 4]]) + x = SparseMatrix(3, 1, [-1, 2, 5]) + b = A*x + soln = A.LUsolve(b) + assert soln == x + + # test_inverse + A = sparse_eye(4) + assert A.inv() == sparse_eye(4) + assert A.inv(method="CH") == sparse_eye(4) + assert A.inv(method="LDL") == sparse_eye(4) + + A = SparseMatrix([[2, 3, 5], + [3, 6, 2], + [7, 2, 6]]) + Ainv = SparseMatrix(Matrix(A).inv()) + assert A*Ainv == sparse_eye(3) + assert A.inv(method="CH") == Ainv + assert A.inv(method="LDL") == Ainv + + A = SparseMatrix([[2, 3, 5], + [3, 6, 2], + [5, 2, 6]]) + Ainv = SparseMatrix(Matrix(A).inv()) + assert A*Ainv == sparse_eye(3) + assert A.inv(method="CH") == Ainv + assert A.inv(method="LDL") == Ainv + + # test_cross + v1 = Matrix(1, 3, [1, 2, 3]) + v2 = Matrix(1, 3, [3, 4, 5]) + assert v1.cross(v2) == Matrix(1, 3, [-2, 4, -2]) + assert v1.norm(2)**2 == 14 + + # conjugate + a = SparseMatrix(((1, 2 + I), (3, 4))) + assert a.C == SparseMatrix([ + [1, 2 - I], + [3, 4] + ]) + + # mul + assert a*Matrix(2, 2, [1, 0, 0, 1]) == a + assert a + Matrix(2, 2, [1, 1, 1, 1]) == SparseMatrix([ + [2, 3 + I], + [4, 5] + ]) + + # col join + assert a.col_join(sparse_eye(2)) == SparseMatrix([ + [1, 2 + I], + [3, 4], + [1, 0], + [0, 1] + ]) + + # row insert + assert a.row_insert(2, sparse_eye(2)) == SparseMatrix([ + [1, 2 + I], + [3, 4], + [1, 0], + [0, 1] + ]) + + # col insert + assert a.col_insert(2, SparseMatrix.zeros(2, 1)) == SparseMatrix([ + [1, 2 + I, 0], + [3, 4, 0], + ]) + + # symmetric + assert not a.is_symmetric(simplify=False) + + # col op + M = SparseMatrix.eye(3)*2 + M[1, 0] = -1 + M.col_op(1, lambda v, i: v + 2*M[i, 0]) + assert M == SparseMatrix([ + [ 2, 4, 0], + [-1, 0, 0], + [ 0, 0, 2] + ]) + + # fill + M = SparseMatrix.eye(3) + M.fill(2) + assert M == SparseMatrix([ + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + ]) + + # test_cofactor + assert sparse_eye(3) == sparse_eye(3).cofactor_matrix() + test = SparseMatrix([[1, 3, 2], [2, 6, 3], [2, 3, 6]]) + assert test.cofactor_matrix() == \ + SparseMatrix([[27, -6, -6], [-12, 2, 3], [-3, 1, 0]]) + test = SparseMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + assert test.cofactor_matrix() == \ + SparseMatrix([[-3, 6, -3], [6, -12, 6], [-3, 6, -3]]) + + # test_jacobian + x = Symbol('x') + y = Symbol('y') + L = SparseMatrix(1, 2, [x**2*y, 2*y**2 + x*y]) + syms = [x, y] + assert L.jacobian(syms) == Matrix([[2*x*y, x**2], [y, 4*y + x]]) + + L = SparseMatrix(1, 2, [x, x**2*y**3]) + assert L.jacobian(syms) == SparseMatrix([[1, 0], [2*x*y**3, x**2*3*y**2]]) + + # test_QR + A = Matrix([[1, 2], [2, 3]]) + Q, S = A.QRdecomposition() + R = Rational + assert Q == Matrix([ + [ 5**R(-1, 2), (R(2)/5)*(R(1)/5)**R(-1, 2)], + [2*5**R(-1, 2), (-R(1)/5)*(R(1)/5)**R(-1, 2)]]) + assert S == Matrix([ + [5**R(1, 2), 8*5**R(-1, 2)], + [ 0, (R(1)/5)**R(1, 2)]]) + assert Q*S == A + assert Q.T * Q == sparse_eye(2) + + R = Rational + # test nullspace + # first test reduced row-ech form + + M = SparseMatrix([[5, 7, 2, 1], + [1, 6, 2, -1]]) + out, tmp = M.rref() + assert out == Matrix([[1, 0, -R(2)/23, R(13)/23], + [0, 1, R(8)/23, R(-6)/23]]) + + M = SparseMatrix([[ 1, 3, 0, 2, 6, 3, 1], + [-2, -6, 0, -2, -8, 3, 1], + [ 3, 9, 0, 0, 6, 6, 2], + [-1, -3, 0, 1, 0, 9, 3]]) + + out, tmp = M.rref() + assert out == Matrix([[1, 3, 0, 0, 2, 0, 0], + [0, 0, 0, 1, 2, 0, 0], + [0, 0, 0, 0, 0, 1, R(1)/3], + [0, 0, 0, 0, 0, 0, 0]]) + # now check the vectors + basis = M.nullspace() + assert basis[0] == Matrix([-3, 1, 0, 0, 0, 0, 0]) + assert basis[1] == Matrix([0, 0, 1, 0, 0, 0, 0]) + assert basis[2] == Matrix([-2, 0, 0, -2, 1, 0, 0]) + assert basis[3] == Matrix([0, 0, 0, 0, 0, R(-1)/3, 1]) + + # test eigen + x = Symbol('x') + y = Symbol('y') + sparse_eye3 = sparse_eye(3) + assert sparse_eye3.charpoly(x) == PurePoly((x - 1)**3) + assert sparse_eye3.charpoly(y) == PurePoly((y - 1)**3) + + # test values + M = Matrix([( 0, 1, -1), + ( 1, 1, 0), + (-1, 0, 1)]) + vals = M.eigenvals() + assert sorted(vals.keys()) == [-1, 1, 2] + + R = Rational + M = Matrix([[1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + assert M.eigenvects() == [(1, 3, [ + Matrix([1, 0, 0]), + Matrix([0, 1, 0]), + Matrix([0, 0, 1])])] + M = Matrix([[5, 0, 2], + [3, 2, 0], + [0, 0, 1]]) + assert M.eigenvects() == [(1, 1, [Matrix([R(-1)/2, R(3)/2, 1])]), + (2, 1, [Matrix([0, 1, 0])]), + (5, 1, [Matrix([1, 1, 0])])] + + assert M.zeros(3, 5) == SparseMatrix(3, 5, {}) + A = SparseMatrix(10, 10, {(0, 0): 18, (0, 9): 12, (1, 4): 18, (2, 7): 16, (3, 9): 12, (4, 2): 19, (5, 7): 16, (6, 2): 12, (9, 7): 18}) + assert A.row_list() == [(0, 0, 18), (0, 9, 12), (1, 4, 18), (2, 7, 16), (3, 9, 12), (4, 2, 19), (5, 7, 16), (6, 2, 12), (9, 7, 18)] + assert A.col_list() == [(0, 0, 18), (4, 2, 19), (6, 2, 12), (1, 4, 18), (2, 7, 16), (5, 7, 16), (9, 7, 18), (0, 9, 12), (3, 9, 12)] + assert SparseMatrix.eye(2).nnz() == 2 + + +def test_scalar_multiply(): + assert SparseMatrix([[1, 2]]).scalar_multiply(3) == SparseMatrix([[3, 6]]) + + +def test_transpose(): + assert SparseMatrix(((1, 2), (3, 4))).transpose() == \ + SparseMatrix(((1, 3), (2, 4))) + + +def test_trace(): + assert SparseMatrix(((1, 2), (3, 4))).trace() == 5 + assert SparseMatrix(((0, 0), (0, 4))).trace() == 4 + + +def test_CL_RL(): + assert SparseMatrix(((1, 2), (3, 4))).row_list() == \ + [(0, 0, 1), (0, 1, 2), (1, 0, 3), (1, 1, 4)] + assert SparseMatrix(((1, 2), (3, 4))).col_list() == \ + [(0, 0, 1), (1, 0, 3), (0, 1, 2), (1, 1, 4)] + + +def test_add(): + assert SparseMatrix(((1, 0), (0, 1))) + SparseMatrix(((0, 1), (1, 0))) == \ + SparseMatrix(((1, 1), (1, 1))) + a = SparseMatrix(100, 100, lambda i, j: int(j != 0 and i % j == 0)) + b = SparseMatrix(100, 100, lambda i, j: int(i != 0 and j % i == 0)) + assert (len(a.todok()) + len(b.todok()) - len((a + b).todok()) > 0) + + +def test_errors(): + raises(ValueError, lambda: SparseMatrix(1.4, 2, lambda i, j: 0)) + raises(TypeError, lambda: SparseMatrix([1, 2, 3], [1, 2])) + raises(ValueError, lambda: SparseMatrix([[1, 2], [3, 4]])[(1, 2, 3)]) + raises(IndexError, lambda: SparseMatrix([[1, 2], [3, 4]])[5]) + raises(ValueError, lambda: SparseMatrix([[1, 2], [3, 4]])[1, 2, 3]) + raises(TypeError, + lambda: SparseMatrix([[1, 2], [3, 4]]).copyin_list([0, 1], set())) + raises( + IndexError, lambda: SparseMatrix([[1, 2], [3, 4]])[1, 2]) + raises(TypeError, lambda: SparseMatrix([1, 2, 3]).cross(1)) + raises(IndexError, lambda: SparseMatrix(1, 2, [1, 2])[3]) + raises(ShapeError, + lambda: SparseMatrix(1, 2, [1, 2]) + SparseMatrix(2, 1, [2, 1])) + + +def test_len(): + assert not SparseMatrix() + assert SparseMatrix() == SparseMatrix([]) + assert SparseMatrix() == SparseMatrix([[]]) + + +def test_sparse_zeros_sparse_eye(): + assert SparseMatrix.eye(3) == eye(3, cls=SparseMatrix) + assert len(SparseMatrix.eye(3).todok()) == 3 + assert SparseMatrix.zeros(3) == zeros(3, cls=SparseMatrix) + assert len(SparseMatrix.zeros(3).todok()) == 0 + + +def test_copyin(): + s = SparseMatrix(3, 3, {}) + s[1, 0] = 1 + assert s[:, 0] == SparseMatrix(Matrix([0, 1, 0])) + assert s[3] == 1 + assert s[3: 4] == [1] + s[1, 1] = 42 + assert s[1, 1] == 42 + assert s[1, 1:] == SparseMatrix([[42, 0]]) + s[1, 1:] = Matrix([[5, 6]]) + assert s[1, :] == SparseMatrix([[1, 5, 6]]) + s[1, 1:] = [[42, 43]] + assert s[1, :] == SparseMatrix([[1, 42, 43]]) + s[0, 0] = 17 + assert s[:, :1] == SparseMatrix([17, 1, 0]) + s[0, 0] = [1, 1, 1] + assert s[:, 0] == SparseMatrix([1, 1, 1]) + s[0, 0] = Matrix([1, 1, 1]) + assert s[:, 0] == SparseMatrix([1, 1, 1]) + s[0, 0] = SparseMatrix([1, 1, 1]) + assert s[:, 0] == SparseMatrix([1, 1, 1]) + + +def test_sparse_solve(): + A = SparseMatrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11))) + assert A.cholesky() == Matrix([ + [ 5, 0, 0], + [ 3, 3, 0], + [-1, 1, 3]]) + assert A.cholesky() * A.cholesky().T == Matrix([ + [25, 15, -5], + [15, 18, 0], + [-5, 0, 11]]) + + A = SparseMatrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11))) + L, D = A.LDLdecomposition() + assert 15*L == Matrix([ + [15, 0, 0], + [ 9, 15, 0], + [-3, 5, 15]]) + assert D == Matrix([ + [25, 0, 0], + [ 0, 9, 0], + [ 0, 0, 9]]) + assert L * D * L.T == A + + A = SparseMatrix(((3, 0, 2), (0, 0, 1), (1, 2, 0))) + assert A.inv() * A == SparseMatrix(eye(3)) + + A = SparseMatrix([ + [ 2, -1, 0], + [-1, 2, -1], + [ 0, 0, 2]]) + ans = SparseMatrix([ + [Rational(2, 3), Rational(1, 3), Rational(1, 6)], + [Rational(1, 3), Rational(2, 3), Rational(1, 3)], + [ 0, 0, S.Half]]) + assert A.inv(method='CH') == ans + assert A.inv(method='LDL') == ans + assert A * ans == SparseMatrix(eye(3)) + + s = A.solve(A[:, 0], 'LDL') + assert A*s == A[:, 0] + s = A.solve(A[:, 0], 'CH') + assert A*s == A[:, 0] + A = A.col_join(A) + s = A.solve_least_squares(A[:, 0], 'CH') + assert A*s == A[:, 0] + s = A.solve_least_squares(A[:, 0], 'LDL') + assert A*s == A[:, 0] + + +def test_lower_triangular_solve(): + raises(NonSquareMatrixError, lambda: + SparseMatrix([[1, 2]]).lower_triangular_solve(Matrix([[1, 2]]))) + raises(ShapeError, lambda: + SparseMatrix([[1, 2], [0, 4]]).lower_triangular_solve(Matrix([1]))) + raises(ValueError, lambda: + SparseMatrix([[1, 2], [3, 4]]).lower_triangular_solve(Matrix([[1, 2], [3, 4]]))) + + a, b, c, d = symbols('a:d') + u, v, w, x = symbols('u:x') + + A = SparseMatrix([[a, 0], [c, d]]) + B = MutableSparseMatrix([[u, v], [w, x]]) + C = ImmutableSparseMatrix([[u, v], [w, x]]) + + sol = Matrix([[u/a, v/a], [(w - c*u/a)/d, (x - c*v/a)/d]]) + assert A.lower_triangular_solve(B) == sol + assert A.lower_triangular_solve(C) == sol + + +def test_upper_triangular_solve(): + raises(NonSquareMatrixError, lambda: + SparseMatrix([[1, 2]]).upper_triangular_solve(Matrix([[1, 2]]))) + raises(ShapeError, lambda: + SparseMatrix([[1, 2], [0, 4]]).upper_triangular_solve(Matrix([1]))) + raises(TypeError, lambda: + SparseMatrix([[1, 2], [3, 4]]).upper_triangular_solve(Matrix([[1, 2], [3, 4]]))) + + a, b, c, d = symbols('a:d') + u, v, w, x = symbols('u:x') + + A = SparseMatrix([[a, b], [0, d]]) + B = MutableSparseMatrix([[u, v], [w, x]]) + C = ImmutableSparseMatrix([[u, v], [w, x]]) + + sol = Matrix([[(u - b*w/d)/a, (v - b*x/d)/a], [w/d, x/d]]) + assert A.upper_triangular_solve(B) == sol + assert A.upper_triangular_solve(C) == sol + + +def test_diagonal_solve(): + a, d = symbols('a d') + u, v, w, x = symbols('u:x') + + A = SparseMatrix([[a, 0], [0, d]]) + B = MutableSparseMatrix([[u, v], [w, x]]) + C = ImmutableSparseMatrix([[u, v], [w, x]]) + + sol = Matrix([[u/a, v/a], [w/d, x/d]]) + assert A.diagonal_solve(B) == sol + assert A.diagonal_solve(C) == sol + + +def test_hermitian(): + x = Symbol('x') + a = SparseMatrix([[0, I], [-I, 0]]) + assert a.is_hermitian + a = SparseMatrix([[1, I], [-I, 1]]) + assert a.is_hermitian + a[0, 0] = 2*I + assert a.is_hermitian is False + a[0, 0] = x + assert a.is_hermitian is None + a[0, 1] = a[1, 0]*I + assert a.is_hermitian is False diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_sparsetools.py b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_sparsetools.py new file mode 100644 index 0000000000000000000000000000000000000000..244944c31da06460d4bc7beff8bce0f91fea9f14 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_sparsetools.py @@ -0,0 +1,132 @@ +from sympy.matrices.sparsetools import _doktocsr, _csrtodok, banded +from sympy.matrices.dense import (Matrix, eye, ones, zeros) +from sympy.matrices import SparseMatrix +from sympy.testing.pytest import raises + + +def test_doktocsr(): + a = SparseMatrix([[1, 2, 0, 0], [0, 3, 9, 0], [0, 1, 4, 0]]) + b = SparseMatrix(4, 6, [10, 20, 0, 0, 0, 0, 0, 30, 0, 40, 0, 0, 0, 0, 50, + 60, 70, 0, 0, 0, 0, 0, 0, 80]) + c = SparseMatrix(4, 4, [0, 0, 0, 0, 0, 12, 0, 2, 15, 0, 12, 0, 0, 0, 0, 4]) + d = SparseMatrix(10, 10, {(1, 1): 12, (3, 5): 7, (7, 8): 12}) + e = SparseMatrix([[0, 0, 0], [1, 0, 2], [3, 0, 0]]) + f = SparseMatrix(7, 8, {(2, 3): 5, (4, 5):12}) + assert _doktocsr(a) == [[1, 2, 3, 9, 1, 4], [0, 1, 1, 2, 1, 2], + [0, 2, 4, 6], [3, 4]] + assert _doktocsr(b) == [[10, 20, 30, 40, 50, 60, 70, 80], + [0, 1, 1, 3, 2, 3, 4, 5], [0, 2, 4, 7, 8], [4, 6]] + assert _doktocsr(c) == [[12, 2, 15, 12, 4], [1, 3, 0, 2, 3], + [0, 0, 2, 4, 5], [4, 4]] + assert _doktocsr(d) == [[12, 7, 12], [1, 5, 8], + [0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3], [10, 10]] + assert _doktocsr(e) == [[1, 2, 3], [0, 2, 0], [0, 0, 2, 3], [3, 3]] + assert _doktocsr(f) == [[5, 12], [3, 5], [0, 0, 0, 1, 1, 2, 2, 2], [7, 8]] + + +def test_csrtodok(): + h = [[5, 7, 5], [2, 1, 3], [0, 1, 1, 3], [3, 4]] + g = [[12, 5, 4], [2, 4, 2], [0, 1, 2, 3], [3, 7]] + i = [[1, 3, 12], [0, 2, 4], [0, 2, 3], [2, 5]] + j = [[11, 15, 12, 15], [2, 4, 1, 2], [0, 1, 1, 2, 3, 4], [5, 8]] + k = [[1, 3], [2, 1], [0, 1, 1, 2], [3, 3]] + m = _csrtodok(h) + assert isinstance(m, SparseMatrix) + assert m == SparseMatrix(3, 4, + {(0, 2): 5, (2, 1): 7, (2, 3): 5}) + assert _csrtodok(g) == SparseMatrix(3, 7, + {(0, 2): 12, (1, 4): 5, (2, 2): 4}) + assert _csrtodok(i) == SparseMatrix([[1, 0, 3, 0, 0], [0, 0, 0, 0, 12]]) + assert _csrtodok(j) == SparseMatrix(5, 8, + {(0, 2): 11, (2, 4): 15, (3, 1): 12, (4, 2): 15}) + assert _csrtodok(k) == SparseMatrix(3, 3, {(0, 2): 1, (2, 1): 3}) + + +def test_banded(): + raises(TypeError, lambda: banded()) + raises(TypeError, lambda: banded(1)) + raises(TypeError, lambda: banded(1, 2)) + raises(TypeError, lambda: banded(1, 2, 3)) + raises(TypeError, lambda: banded(1, 2, 3, 4)) + raises(ValueError, lambda: banded({0: (1, 2)}, rows=1)) + raises(ValueError, lambda: banded({0: (1, 2)}, cols=1)) + raises(ValueError, lambda: banded(1, {0: (1, 2)})) + raises(ValueError, lambda: banded(2, 1, {0: (1, 2)})) + raises(ValueError, lambda: banded(1, 2, {0: (1, 2)})) + + assert isinstance(banded(2, 4, {}), SparseMatrix) + assert banded(2, 4, {}) == zeros(2, 4) + assert banded({0: 0, 1: 0}) == zeros(0) + assert banded({0: Matrix([1, 2])}) == Matrix([1, 2]) + assert banded({1: [1, 2, 3, 0], -1: [4, 5, 6]}) == \ + banded({1: (1, 2, 3), -1: (4, 5, 6)}) == \ + Matrix([ + [0, 1, 0, 0], + [4, 0, 2, 0], + [0, 5, 0, 3], + [0, 0, 6, 0]]) + assert banded(3, 4, {-1: 1, 0: 2, 1: 3}) == \ + Matrix([ + [2, 3, 0, 0], + [1, 2, 3, 0], + [0, 1, 2, 3]]) + s = lambda d: (1 + d)**2 + assert banded(5, {0: s, 2: s}) == \ + Matrix([ + [1, 0, 1, 0, 0], + [0, 4, 0, 4, 0], + [0, 0, 9, 0, 9], + [0, 0, 0, 16, 0], + [0, 0, 0, 0, 25]]) + assert banded(2, {0: 1}) == \ + Matrix([ + [1, 0], + [0, 1]]) + assert banded(2, 3, {0: 1}) == \ + Matrix([ + [1, 0, 0], + [0, 1, 0]]) + vert = Matrix([1, 2, 3]) + assert banded({0: vert}, cols=3) == \ + Matrix([ + [1, 0, 0], + [2, 1, 0], + [3, 2, 1], + [0, 3, 2], + [0, 0, 3]]) + assert banded(4, {0: ones(2)}) == \ + Matrix([ + [1, 1, 0, 0], + [1, 1, 0, 0], + [0, 0, 1, 1], + [0, 0, 1, 1]]) + raises(ValueError, lambda: banded({0: 2, 1: ones(2)}, rows=5)) + assert banded({0: 2, 2: (ones(2),)*3}) == \ + Matrix([ + [2, 0, 1, 1, 0, 0, 0, 0], + [0, 2, 1, 1, 0, 0, 0, 0], + [0, 0, 2, 0, 1, 1, 0, 0], + [0, 0, 0, 2, 1, 1, 0, 0], + [0, 0, 0, 0, 2, 0, 1, 1], + [0, 0, 0, 0, 0, 2, 1, 1]]) + raises(ValueError, lambda: banded({0: (2,)*5, 1: (ones(2),)*3})) + u2 = Matrix([[1, 1], [0, 1]]) + assert banded({0: (2,)*5, 1: (u2,)*3}) == \ + Matrix([ + [2, 1, 1, 0, 0, 0, 0], + [0, 2, 1, 0, 0, 0, 0], + [0, 0, 2, 1, 1, 0, 0], + [0, 0, 0, 2, 1, 0, 0], + [0, 0, 0, 0, 2, 1, 1], + [0, 0, 0, 0, 0, 0, 1]]) + assert banded({0:(0, ones(2)), 2: 2}) == \ + Matrix([ + [0, 0, 2], + [0, 1, 1], + [0, 1, 1]]) + raises(ValueError, lambda: banded({0: (0, ones(2)), 1: 2})) + assert banded({0: 1}, cols=3) == banded({0: 1}, rows=3) == eye(3) + assert banded({1: 1}, rows=3) == Matrix([ + [0, 1, 0], + [0, 0, 1], + [0, 0, 0]]) diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_subspaces.py b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_subspaces.py new file mode 100644 index 0000000000000000000000000000000000000000..0bd853e321eb06f754c17e7bd0c11deb870506f5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/tests/test_subspaces.py @@ -0,0 +1,109 @@ +from sympy.matrices import Matrix +from sympy.core.numbers import Rational +from sympy.core.symbol import symbols +from sympy.solvers import solve + + +def test_columnspace_one(): + m = Matrix([[ 1, 2, 0, 2, 5], + [-2, -5, 1, -1, -8], + [ 0, -3, 3, 4, 1], + [ 3, 6, 0, -7, 2]]) + + basis = m.columnspace() + assert basis[0] == Matrix([1, -2, 0, 3]) + assert basis[1] == Matrix([2, -5, -3, 6]) + assert basis[2] == Matrix([2, -1, 4, -7]) + + assert len(basis) == 3 + assert Matrix.hstack(m, *basis).columnspace() == basis + + +def test_rowspace(): + m = Matrix([[ 1, 2, 0, 2, 5], + [-2, -5, 1, -1, -8], + [ 0, -3, 3, 4, 1], + [ 3, 6, 0, -7, 2]]) + + basis = m.rowspace() + assert basis[0] == Matrix([[1, 2, 0, 2, 5]]) + assert basis[1] == Matrix([[0, -1, 1, 3, 2]]) + assert basis[2] == Matrix([[0, 0, 0, 5, 5]]) + + assert len(basis) == 3 + + +def test_nullspace_one(): + m = Matrix([[ 1, 2, 0, 2, 5], + [-2, -5, 1, -1, -8], + [ 0, -3, 3, 4, 1], + [ 3, 6, 0, -7, 2]]) + + basis = m.nullspace() + assert basis[0] == Matrix([-2, 1, 1, 0, 0]) + assert basis[1] == Matrix([-1, -1, 0, -1, 1]) + # make sure the null space is really gets zeroed + assert all(e.is_zero for e in m*basis[0]) + assert all(e.is_zero for e in m*basis[1]) + +def test_nullspace_second(): + # first test reduced row-ech form + R = Rational + + M = Matrix([[5, 7, 2, 1], + [1, 6, 2, -1]]) + out, tmp = M.rref() + assert out == Matrix([[1, 0, -R(2)/23, R(13)/23], + [0, 1, R(8)/23, R(-6)/23]]) + + M = Matrix([[-5, -1, 4, -3, -1], + [ 1, -1, -1, 1, 0], + [-1, 0, 0, 0, 0], + [ 4, 1, -4, 3, 1], + [-2, 0, 2, -2, -1]]) + assert M*M.nullspace()[0] == Matrix(5, 1, [0]*5) + + M = Matrix([[ 1, 3, 0, 2, 6, 3, 1], + [-2, -6, 0, -2, -8, 3, 1], + [ 3, 9, 0, 0, 6, 6, 2], + [-1, -3, 0, 1, 0, 9, 3]]) + out, tmp = M.rref() + assert out == Matrix([[1, 3, 0, 0, 2, 0, 0], + [0, 0, 0, 1, 2, 0, 0], + [0, 0, 0, 0, 0, 1, R(1)/3], + [0, 0, 0, 0, 0, 0, 0]]) + + # now check the vectors + basis = M.nullspace() + assert basis[0] == Matrix([-3, 1, 0, 0, 0, 0, 0]) + assert basis[1] == Matrix([0, 0, 1, 0, 0, 0, 0]) + assert basis[2] == Matrix([-2, 0, 0, -2, 1, 0, 0]) + assert basis[3] == Matrix([0, 0, 0, 0, 0, R(-1)/3, 1]) + + # issue 4797; just see that we can do it when rows > cols + M = Matrix([[1, 2], [2, 4], [3, 6]]) + assert M.nullspace() + + +def test_columnspace_second(): + M = Matrix([[ 1, 2, 0, 2, 5], + [-2, -5, 1, -1, -8], + [ 0, -3, 3, 4, 1], + [ 3, 6, 0, -7, 2]]) + + # now check the vectors + basis = M.columnspace() + assert basis[0] == Matrix([1, -2, 0, 3]) + assert basis[1] == Matrix([2, -5, -3, 6]) + assert basis[2] == Matrix([2, -1, 4, -7]) + + #check by columnspace definition + a, b, c, d, e = symbols('a b c d e') + X = Matrix([a, b, c, d, e]) + for i in range(len(basis)): + eq=M*X-basis[i] + assert len(solve(eq, X)) != 0 + + #check if rank-nullity theorem holds + assert M.rank() == len(basis) + assert len(M.nullspace()) + len(M.columnspace()) == M.cols diff --git a/.venv/lib/python3.13/site-packages/sympy/matrices/utilities.py b/.venv/lib/python3.13/site-packages/sympy/matrices/utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..b8a680b47e63615e210e561639a192ba47c642d3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/matrices/utilities.py @@ -0,0 +1,72 @@ +from contextlib import contextmanager +from threading import local + +from sympy.core.function import expand_mul + + +class DotProdSimpState(local): + def __init__(self): + self.state = None + +_dotprodsimp_state = DotProdSimpState() + +@contextmanager +def dotprodsimp(x): + old = _dotprodsimp_state.state + + try: + _dotprodsimp_state.state = x + yield + finally: + _dotprodsimp_state.state = old + + +def _dotprodsimp(expr, withsimp=False): + """Wrapper for simplify.dotprodsimp to avoid circular imports.""" + from sympy.simplify.simplify import dotprodsimp as dps + return dps(expr, withsimp=withsimp) + + +def _get_intermediate_simp(deffunc=lambda x: x, offfunc=lambda x: x, + onfunc=_dotprodsimp, dotprodsimp=None): + """Support function for controlling intermediate simplification. Returns a + simplification function according to the global setting of dotprodsimp + operation. + + ``deffunc`` - Function to be used by default. + ``offfunc`` - Function to be used if dotprodsimp has been turned off. + ``onfunc`` - Function to be used if dotprodsimp has been turned on. + ``dotprodsimp`` - True, False or None. Will be overridden by global + _dotprodsimp_state.state if that is not None. + """ + + if dotprodsimp is False or _dotprodsimp_state.state is False: + return offfunc + if dotprodsimp is True or _dotprodsimp_state.state is True: + return onfunc + + return deffunc # None, None + + +def _get_intermediate_simp_bool(default=False, dotprodsimp=None): + """Same as ``_get_intermediate_simp`` but returns bools instead of functions + by default.""" + + return _get_intermediate_simp(default, False, True, dotprodsimp) + + +def _iszero(x): + """Returns True if x is zero.""" + return getattr(x, 'is_zero', None) + + +def _is_zero_after_expand_mul(x): + """Tests by expand_mul only, suitable for polynomials and rational + functions.""" + return expand_mul(x) == 0 + + +def _simplify(expr): + """ Wrapper to avoid circular imports. """ + from sympy.simplify.simplify import simplify + return simplify(expr) diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/__init__.py b/.venv/lib/python3.13/site-packages/sympy/parsing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b39d031bca26bc599eb9eb0e12dfe48f7e6db174 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/__init__.py @@ -0,0 +1,4 @@ +"""Used for translating a string into a SymPy expression. """ +__all__ = ['parse_expr'] + +from .sympy_parser import parse_expr diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/ast_parser.py b/.venv/lib/python3.13/site-packages/sympy/parsing/ast_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..95a773d5bec6e130810b7b7925fdff57270aec17 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/ast_parser.py @@ -0,0 +1,79 @@ +""" +This module implements the functionality to take any Python expression as a +string and fix all numbers and other things before evaluating it, +thus + +1/2 + +returns + +Integer(1)/Integer(2) + +We use the ast module for this. It is well documented at docs.python.org. + +Some tips to understand how this works: use dump() to get a nice +representation of any node. Then write a string of what you want to get, +e.g. "Integer(1)", parse it, dump it and you'll see that you need to do +"Call(Name('Integer', Load()), [node], [], None, None)". You do not need +to bother with lineno and col_offset, just call fix_missing_locations() +before returning the node. +""" + +from sympy.core.basic import Basic +from sympy.core.sympify import SympifyError + +from ast import parse, NodeTransformer, Call, Name, Load, \ + fix_missing_locations, Constant, Tuple + +class Transform(NodeTransformer): + + def __init__(self, local_dict, global_dict): + NodeTransformer.__init__(self) + self.local_dict = local_dict + self.global_dict = global_dict + + def visit_Constant(self, node): + if isinstance(node.value, int): + return fix_missing_locations(Call(func=Name('Integer', Load()), + args=[node], keywords=[])) + elif isinstance(node.value, float): + return fix_missing_locations(Call(func=Name('Float', Load()), + args=[node], keywords=[])) + return node + + def visit_Name(self, node): + if node.id in self.local_dict: + return node + elif node.id in self.global_dict: + name_obj = self.global_dict[node.id] + + if isinstance(name_obj, (Basic, type)) or callable(name_obj): + return node + elif node.id in ['True', 'False']: + return node + return fix_missing_locations(Call(func=Name('Symbol', Load()), + args=[Constant(node.id)], keywords=[])) + + def visit_Lambda(self, node): + args = [self.visit(arg) for arg in node.args.args] + body = self.visit(node.body) + n = Call(func=Name('Lambda', Load()), + args=[Tuple(args, Load()), body], keywords=[]) + return fix_missing_locations(n) + +def parse_expr(s, local_dict): + """ + Converts the string "s" to a SymPy expression, in local_dict. + + It converts all numbers to Integers before feeding it to Python and + automatically creates Symbols. + """ + global_dict = {} + exec('from sympy import *', global_dict) + try: + a = parse(s.strip(), mode="eval") + except SyntaxError: + raise SympifyError("Cannot parse %s." % repr(s)) + a = Transform(local_dict, global_dict).visit(a) + e = compile(a, "", "eval") + return eval(e, global_dict, local_dict) diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/Autolev.g4 b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/Autolev.g4 new file mode 100644 index 0000000000000000000000000000000000000000..94feea5fa4f49e9d1054eca2cd60c996aebff7c2 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/Autolev.g4 @@ -0,0 +1,118 @@ +grammar Autolev; + +options { + language = Python3; +} + +prog: stat+; + +stat: varDecl + | functionCall + | codeCommands + | massDecl + | inertiaDecl + | assignment + | settings + ; + +assignment: vec equals expr #vecAssign + | ID '[' index ']' equals expr #indexAssign + | ID diff? equals expr #regularAssign; + +equals: ('='|'+='|'-='|':='|'*='|'/='|'^='); + +index: expr (',' expr)* ; + +diff: ('\'')+; + +functionCall: ID '(' (expr (',' expr)*)? ')' + | (Mass|Inertia) '(' (ID (',' ID)*)? ')'; + +varDecl: varType varDecl2 (',' varDecl2)*; + +varType: Newtonian|Frames|Bodies|Particles|Points|Constants + | Specifieds|Imaginary|Variables ('\'')*|MotionVariables ('\'')*; + +varDecl2: ID ('{' INT ',' INT '}')? (('{' INT ':' INT (',' INT ':' INT)* '}'))? ('{' INT '}')? ('+'|'-')? ('\'')* ('=' expr)?; + +ranges: ('{' INT ':' INT (',' INT ':' INT)* '}'); + +massDecl: Mass massDecl2 (',' massDecl2)*; + +massDecl2: ID '=' expr; + +inertiaDecl: Inertia ID ('(' ID ')')? (',' expr)+; + +matrix: '[' expr ((','|';') expr)* ']'; +matrixInOutput: (ID (ID '=' (FLOAT|INT)?))|FLOAT|INT; + +codeCommands: units + | inputs + | outputs + | codegen + | commands; + +settings: ID (EXP|ID|FLOAT|INT)?; + +units: UnitSystem ID (',' ID)*; +inputs: Input inputs2 (',' inputs2)*; +id_diff: ID diff?; +inputs2: id_diff '=' expr expr?; +outputs: Output outputs2 (',' outputs2)*; +outputs2: expr expr?; +codegen: ID functionCall ('['matrixInOutput (',' matrixInOutput)*']')? ID'.'ID; + +commands: Save ID'.'ID + | Encode ID (',' ID)*; + +vec: ID ('>')+ + | '0>' + | '1>>'; + +expr: expr '^' expr # Exponent + | expr ('*'|'/') expr # MulDiv + | expr ('+'|'-') expr # AddSub + | EXP # exp + | '-' expr # negativeOne + | FLOAT # float + | INT # int + | ID('\'')* # id + | vec # VectorOrDyadic + | ID '['expr (',' expr)* ']' # Indexing + | functionCall # function + | matrix # matrices + | '(' expr ')' # parens + | expr '=' expr # idEqualsExpr + | expr ':' expr # colon + | ID? ranges ('\'')* # rangess + ; + +// These are to take care of the case insensitivity of Autolev. +Mass: ('M'|'m')('A'|'a')('S'|'s')('S'|'s'); +Inertia: ('I'|'i')('N'|'n')('E'|'e')('R'|'r')('T'|'t')('I'|'i')('A'|'a'); +Input: ('I'|'i')('N'|'n')('P'|'p')('U'|'u')('T'|'t')('S'|'s')?; +Output: ('O'|'o')('U'|'u')('T'|'t')('P'|'p')('U'|'u')('T'|'t'); +Save: ('S'|'s')('A'|'a')('V'|'v')('E'|'e'); +UnitSystem: ('U'|'u')('N'|'n')('I'|'i')('T'|'t')('S'|'s')('Y'|'y')('S'|'s')('T'|'t')('E'|'e')('M'|'m'); +Encode: ('E'|'e')('N'|'n')('C'|'c')('O'|'o')('D'|'d')('E'|'e'); +Newtonian: ('N'|'n')('E'|'e')('W'|'w')('T'|'t')('O'|'o')('N'|'n')('I'|'i')('A'|'a')('N'|'n'); +Frames: ('F'|'f')('R'|'r')('A'|'a')('M'|'m')('E'|'e')('S'|'s')?; +Bodies: ('B'|'b')('O'|'o')('D'|'d')('I'|'i')('E'|'e')('S'|'s')?; +Particles: ('P'|'p')('A'|'a')('R'|'r')('T'|'t')('I'|'i')('C'|'c')('L'|'l')('E'|'e')('S'|'s')?; +Points: ('P'|'p')('O'|'o')('I'|'i')('N'|'n')('T'|'t')('S'|'s')?; +Constants: ('C'|'c')('O'|'o')('N'|'n')('S'|'s')('T'|'t')('A'|'a')('N'|'n')('T'|'t')('S'|'s')?; +Specifieds: ('S'|'s')('P'|'p')('E'|'e')('C'|'c')('I'|'i')('F'|'f')('I'|'i')('E'|'e')('D'|'d')('S'|'s')?; +Imaginary: ('I'|'i')('M'|'m')('A'|'a')('G'|'g')('I'|'i')('N'|'n')('A'|'a')('R'|'r')('Y'|'y'); +Variables: ('V'|'v')('A'|'a')('R'|'r')('I'|'i')('A'|'a')('B'|'b')('L'|'l')('E'|'e')('S'|'s')?; +MotionVariables: ('M'|'m')('O'|'o')('T'|'t')('I'|'i')('O'|'o')('N'|'n')('V'|'v')('A'|'a')('R'|'r')('I'|'i')('A'|'a')('B'|'b')('L'|'l')('E'|'e')('S'|'s')?; + +fragment DIFF: ('\'')*; +fragment DIGIT: [0-9]; +INT: [0-9]+ ; // match integers +FLOAT: DIGIT+ '.' DIGIT* + | '.' DIGIT+; +EXP: FLOAT 'E' INT +| FLOAT 'E' '-' INT; +LINE_COMMENT : '%' .*? '\r'? '\n' -> skip ; +ID: [a-zA-Z][a-zA-Z0-9_]*; +WS: [ \t\r\n&]+ -> skip ; // toss out whitespace diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/__init__.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec81bb83325d68e1c11b43a1df5ec56846367e9f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/__init__.py @@ -0,0 +1,97 @@ +from sympy.external import import_module +from sympy.utilities.decorator import doctest_depends_on + +@doctest_depends_on(modules=('antlr4',)) +def parse_autolev(autolev_code, include_numeric=False): + """Parses Autolev code (version 4.1) to SymPy code. + + Parameters + ========= + autolev_code : Can be an str or any object with a readlines() method (such as a file handle or StringIO). + include_numeric : boolean, optional + If True NumPy, PyDy, or other numeric code is included for numeric evaluation lines in the Autolev code. + + Returns + ======= + sympy_code : str + Equivalent SymPy and/or numpy/pydy code as the input code. + + + Example (Double Pendulum) + ========================= + >>> my_al_text = ("MOTIONVARIABLES' Q{2}', U{2}'", + ... "CONSTANTS L,M,G", + ... "NEWTONIAN N", + ... "FRAMES A,B", + ... "SIMPROT(N, A, 3, Q1)", + ... "SIMPROT(N, B, 3, Q2)", + ... "W_A_N>=U1*N3>", + ... "W_B_N>=U2*N3>", + ... "POINT O", + ... "PARTICLES P,R", + ... "P_O_P> = L*A1>", + ... "P_P_R> = L*B1>", + ... "V_O_N> = 0>", + ... "V2PTS(N, A, O, P)", + ... "V2PTS(N, B, P, R)", + ... "MASS P=M, R=M", + ... "Q1' = U1", + ... "Q2' = U2", + ... "GRAVITY(G*N1>)", + ... "ZERO = FR() + FRSTAR()", + ... "KANE()", + ... "INPUT M=1,G=9.81,L=1", + ... "INPUT Q1=.1,Q2=.2,U1=0,U2=0", + ... "INPUT TFINAL=10, INTEGSTP=.01", + ... "CODE DYNAMICS() some_filename.c") + >>> my_al_text = '\\n'.join(my_al_text) + >>> from sympy.parsing.autolev import parse_autolev + >>> print(parse_autolev(my_al_text, include_numeric=True)) + import sympy.physics.mechanics as _me + import sympy as _sm + import math as m + import numpy as _np + + q1, q2, u1, u2 = _me.dynamicsymbols('q1 q2 u1 u2') + q1_d, q2_d, u1_d, u2_d = _me.dynamicsymbols('q1_ q2_ u1_ u2_', 1) + l, m, g = _sm.symbols('l m g', real=True) + frame_n = _me.ReferenceFrame('n') + frame_a = _me.ReferenceFrame('a') + frame_b = _me.ReferenceFrame('b') + frame_a.orient(frame_n, 'Axis', [q1, frame_n.z]) + frame_b.orient(frame_n, 'Axis', [q2, frame_n.z]) + frame_a.set_ang_vel(frame_n, u1*frame_n.z) + frame_b.set_ang_vel(frame_n, u2*frame_n.z) + point_o = _me.Point('o') + particle_p = _me.Particle('p', _me.Point('p_pt'), _sm.Symbol('m')) + particle_r = _me.Particle('r', _me.Point('r_pt'), _sm.Symbol('m')) + particle_p.point.set_pos(point_o, l*frame_a.x) + particle_r.point.set_pos(particle_p.point, l*frame_b.x) + point_o.set_vel(frame_n, 0) + particle_p.point.v2pt_theory(point_o,frame_n,frame_a) + particle_r.point.v2pt_theory(particle_p.point,frame_n,frame_b) + particle_p.mass = m + particle_r.mass = m + force_p = particle_p.mass*(g*frame_n.x) + force_r = particle_r.mass*(g*frame_n.x) + kd_eqs = [q1_d - u1, q2_d - u2] + forceList = [(particle_p.point,particle_p.mass*(g*frame_n.x)), (particle_r.point,particle_r.mass*(g*frame_n.x))] + kane = _me.KanesMethod(frame_n, q_ind=[q1,q2], u_ind=[u1, u2], kd_eqs = kd_eqs) + fr, frstar = kane.kanes_equations([particle_p, particle_r], forceList) + zero = fr+frstar + from pydy.system import System + sys = System(kane, constants = {l:1, m:1, g:9.81}, + specifieds={}, + initial_conditions={q1:.1, q2:.2, u1:0, u2:0}, + times = _np.linspace(0.0, 10, 10/.01)) + + y=sys.integrate() + + """ + + _autolev = import_module( + 'sympy.parsing.autolev._parse_autolev_antlr', + import_kwargs={'fromlist': ['X']}) + + if _autolev is not None: + return _autolev.parse_autolev(autolev_code, include_numeric) diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/_antlr/__init__.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/_antlr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9b71e9f51fd455558a9eb42dc840604c6c96e4b3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/_antlr/__init__.py @@ -0,0 +1,5 @@ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/_antlr/autolevlexer.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/_antlr/autolevlexer.py new file mode 100644 index 0000000000000000000000000000000000000000..f3b3b1d27ade809a63d9fd328a1572c17625443e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/_antlr/autolevlexer.py @@ -0,0 +1,253 @@ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt +from antlr4 import * +from io import StringIO +import sys +if sys.version_info[1] > 5: + from typing import TextIO +else: + from typing.io import TextIO + + +def serializedATN(): + return [ + 4,0,49,393,6,-1,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5, + 2,6,7,6,2,7,7,7,2,8,7,8,2,9,7,9,2,10,7,10,2,11,7,11,2,12,7,12,2, + 13,7,13,2,14,7,14,2,15,7,15,2,16,7,16,2,17,7,17,2,18,7,18,2,19,7, + 19,2,20,7,20,2,21,7,21,2,22,7,22,2,23,7,23,2,24,7,24,2,25,7,25,2, + 26,7,26,2,27,7,27,2,28,7,28,2,29,7,29,2,30,7,30,2,31,7,31,2,32,7, + 32,2,33,7,33,2,34,7,34,2,35,7,35,2,36,7,36,2,37,7,37,2,38,7,38,2, + 39,7,39,2,40,7,40,2,41,7,41,2,42,7,42,2,43,7,43,2,44,7,44,2,45,7, + 45,2,46,7,46,2,47,7,47,2,48,7,48,2,49,7,49,2,50,7,50,1,0,1,0,1,1, + 1,1,1,2,1,2,1,3,1,3,1,3,1,4,1,4,1,4,1,5,1,5,1,5,1,6,1,6,1,6,1,7, + 1,7,1,7,1,8,1,8,1,8,1,9,1,9,1,10,1,10,1,11,1,11,1,12,1,12,1,13,1, + 13,1,14,1,14,1,15,1,15,1,16,1,16,1,17,1,17,1,18,1,18,1,19,1,19,1, + 20,1,20,1,21,1,21,1,21,1,22,1,22,1,22,1,22,1,23,1,23,1,24,1,24,1, + 25,1,25,1,26,1,26,1,26,1,26,1,26,1,27,1,27,1,27,1,27,1,27,1,27,1, + 27,1,27,1,28,1,28,1,28,1,28,1,28,1,28,3,28,184,8,28,1,29,1,29,1, + 29,1,29,1,29,1,29,1,29,1,30,1,30,1,30,1,30,1,30,1,31,1,31,1,31,1, + 31,1,31,1,31,1,31,1,31,1,31,1,31,1,31,1,32,1,32,1,32,1,32,1,32,1, + 32,1,32,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,34,1, + 34,1,34,1,34,1,34,1,34,3,34,232,8,34,1,35,1,35,1,35,1,35,1,35,1, + 35,3,35,240,8,35,1,36,1,36,1,36,1,36,1,36,1,36,1,36,1,36,1,36,3, + 36,251,8,36,1,37,1,37,1,37,1,37,1,37,1,37,3,37,259,8,37,1,38,1,38, + 1,38,1,38,1,38,1,38,1,38,1,38,1,38,3,38,270,8,38,1,39,1,39,1,39, + 1,39,1,39,1,39,1,39,1,39,1,39,1,39,3,39,282,8,39,1,40,1,40,1,40, + 1,40,1,40,1,40,1,40,1,40,1,40,1,40,1,41,1,41,1,41,1,41,1,41,1,41, + 1,41,1,41,1,41,3,41,303,8,41,1,42,1,42,1,42,1,42,1,42,1,42,1,42, + 1,42,1,42,1,42,1,42,1,42,1,42,1,42,1,42,3,42,320,8,42,1,43,5,43, + 323,8,43,10,43,12,43,326,9,43,1,44,1,44,1,45,4,45,331,8,45,11,45, + 12,45,332,1,46,4,46,336,8,46,11,46,12,46,337,1,46,1,46,5,46,342, + 8,46,10,46,12,46,345,9,46,1,46,1,46,4,46,349,8,46,11,46,12,46,350, + 3,46,353,8,46,1,47,1,47,1,47,1,47,1,47,1,47,1,47,1,47,1,47,3,47, + 364,8,47,1,48,1,48,5,48,368,8,48,10,48,12,48,371,9,48,1,48,3,48, + 374,8,48,1,48,1,48,1,48,1,48,1,49,1,49,5,49,382,8,49,10,49,12,49, + 385,9,49,1,50,4,50,388,8,50,11,50,12,50,389,1,50,1,50,1,369,0,51, + 1,1,3,2,5,3,7,4,9,5,11,6,13,7,15,8,17,9,19,10,21,11,23,12,25,13, + 27,14,29,15,31,16,33,17,35,18,37,19,39,20,41,21,43,22,45,23,47,24, + 49,25,51,26,53,27,55,28,57,29,59,30,61,31,63,32,65,33,67,34,69,35, + 71,36,73,37,75,38,77,39,79,40,81,41,83,42,85,43,87,0,89,0,91,44, + 93,45,95,46,97,47,99,48,101,49,1,0,24,2,0,77,77,109,109,2,0,65,65, + 97,97,2,0,83,83,115,115,2,0,73,73,105,105,2,0,78,78,110,110,2,0, + 69,69,101,101,2,0,82,82,114,114,2,0,84,84,116,116,2,0,80,80,112, + 112,2,0,85,85,117,117,2,0,79,79,111,111,2,0,86,86,118,118,2,0,89, + 89,121,121,2,0,67,67,99,99,2,0,68,68,100,100,2,0,87,87,119,119,2, + 0,70,70,102,102,2,0,66,66,98,98,2,0,76,76,108,108,2,0,71,71,103, + 103,1,0,48,57,2,0,65,90,97,122,4,0,48,57,65,90,95,95,97,122,4,0, + 9,10,13,13,32,32,38,38,410,0,1,1,0,0,0,0,3,1,0,0,0,0,5,1,0,0,0,0, + 7,1,0,0,0,0,9,1,0,0,0,0,11,1,0,0,0,0,13,1,0,0,0,0,15,1,0,0,0,0,17, + 1,0,0,0,0,19,1,0,0,0,0,21,1,0,0,0,0,23,1,0,0,0,0,25,1,0,0,0,0,27, + 1,0,0,0,0,29,1,0,0,0,0,31,1,0,0,0,0,33,1,0,0,0,0,35,1,0,0,0,0,37, + 1,0,0,0,0,39,1,0,0,0,0,41,1,0,0,0,0,43,1,0,0,0,0,45,1,0,0,0,0,47, + 1,0,0,0,0,49,1,0,0,0,0,51,1,0,0,0,0,53,1,0,0,0,0,55,1,0,0,0,0,57, + 1,0,0,0,0,59,1,0,0,0,0,61,1,0,0,0,0,63,1,0,0,0,0,65,1,0,0,0,0,67, + 1,0,0,0,0,69,1,0,0,0,0,71,1,0,0,0,0,73,1,0,0,0,0,75,1,0,0,0,0,77, + 1,0,0,0,0,79,1,0,0,0,0,81,1,0,0,0,0,83,1,0,0,0,0,85,1,0,0,0,0,91, + 1,0,0,0,0,93,1,0,0,0,0,95,1,0,0,0,0,97,1,0,0,0,0,99,1,0,0,0,0,101, + 1,0,0,0,1,103,1,0,0,0,3,105,1,0,0,0,5,107,1,0,0,0,7,109,1,0,0,0, + 9,112,1,0,0,0,11,115,1,0,0,0,13,118,1,0,0,0,15,121,1,0,0,0,17,124, + 1,0,0,0,19,127,1,0,0,0,21,129,1,0,0,0,23,131,1,0,0,0,25,133,1,0, + 0,0,27,135,1,0,0,0,29,137,1,0,0,0,31,139,1,0,0,0,33,141,1,0,0,0, + 35,143,1,0,0,0,37,145,1,0,0,0,39,147,1,0,0,0,41,149,1,0,0,0,43,151, + 1,0,0,0,45,154,1,0,0,0,47,158,1,0,0,0,49,160,1,0,0,0,51,162,1,0, + 0,0,53,164,1,0,0,0,55,169,1,0,0,0,57,177,1,0,0,0,59,185,1,0,0,0, + 61,192,1,0,0,0,63,197,1,0,0,0,65,208,1,0,0,0,67,215,1,0,0,0,69,225, + 1,0,0,0,71,233,1,0,0,0,73,241,1,0,0,0,75,252,1,0,0,0,77,260,1,0, + 0,0,79,271,1,0,0,0,81,283,1,0,0,0,83,293,1,0,0,0,85,304,1,0,0,0, + 87,324,1,0,0,0,89,327,1,0,0,0,91,330,1,0,0,0,93,352,1,0,0,0,95,363, + 1,0,0,0,97,365,1,0,0,0,99,379,1,0,0,0,101,387,1,0,0,0,103,104,5, + 91,0,0,104,2,1,0,0,0,105,106,5,93,0,0,106,4,1,0,0,0,107,108,5,61, + 0,0,108,6,1,0,0,0,109,110,5,43,0,0,110,111,5,61,0,0,111,8,1,0,0, + 0,112,113,5,45,0,0,113,114,5,61,0,0,114,10,1,0,0,0,115,116,5,58, + 0,0,116,117,5,61,0,0,117,12,1,0,0,0,118,119,5,42,0,0,119,120,5,61, + 0,0,120,14,1,0,0,0,121,122,5,47,0,0,122,123,5,61,0,0,123,16,1,0, + 0,0,124,125,5,94,0,0,125,126,5,61,0,0,126,18,1,0,0,0,127,128,5,44, + 0,0,128,20,1,0,0,0,129,130,5,39,0,0,130,22,1,0,0,0,131,132,5,40, + 0,0,132,24,1,0,0,0,133,134,5,41,0,0,134,26,1,0,0,0,135,136,5,123, + 0,0,136,28,1,0,0,0,137,138,5,125,0,0,138,30,1,0,0,0,139,140,5,58, + 0,0,140,32,1,0,0,0,141,142,5,43,0,0,142,34,1,0,0,0,143,144,5,45, + 0,0,144,36,1,0,0,0,145,146,5,59,0,0,146,38,1,0,0,0,147,148,5,46, + 0,0,148,40,1,0,0,0,149,150,5,62,0,0,150,42,1,0,0,0,151,152,5,48, + 0,0,152,153,5,62,0,0,153,44,1,0,0,0,154,155,5,49,0,0,155,156,5,62, + 0,0,156,157,5,62,0,0,157,46,1,0,0,0,158,159,5,94,0,0,159,48,1,0, + 0,0,160,161,5,42,0,0,161,50,1,0,0,0,162,163,5,47,0,0,163,52,1,0, + 0,0,164,165,7,0,0,0,165,166,7,1,0,0,166,167,7,2,0,0,167,168,7,2, + 0,0,168,54,1,0,0,0,169,170,7,3,0,0,170,171,7,4,0,0,171,172,7,5,0, + 0,172,173,7,6,0,0,173,174,7,7,0,0,174,175,7,3,0,0,175,176,7,1,0, + 0,176,56,1,0,0,0,177,178,7,3,0,0,178,179,7,4,0,0,179,180,7,8,0,0, + 180,181,7,9,0,0,181,183,7,7,0,0,182,184,7,2,0,0,183,182,1,0,0,0, + 183,184,1,0,0,0,184,58,1,0,0,0,185,186,7,10,0,0,186,187,7,9,0,0, + 187,188,7,7,0,0,188,189,7,8,0,0,189,190,7,9,0,0,190,191,7,7,0,0, + 191,60,1,0,0,0,192,193,7,2,0,0,193,194,7,1,0,0,194,195,7,11,0,0, + 195,196,7,5,0,0,196,62,1,0,0,0,197,198,7,9,0,0,198,199,7,4,0,0,199, + 200,7,3,0,0,200,201,7,7,0,0,201,202,7,2,0,0,202,203,7,12,0,0,203, + 204,7,2,0,0,204,205,7,7,0,0,205,206,7,5,0,0,206,207,7,0,0,0,207, + 64,1,0,0,0,208,209,7,5,0,0,209,210,7,4,0,0,210,211,7,13,0,0,211, + 212,7,10,0,0,212,213,7,14,0,0,213,214,7,5,0,0,214,66,1,0,0,0,215, + 216,7,4,0,0,216,217,7,5,0,0,217,218,7,15,0,0,218,219,7,7,0,0,219, + 220,7,10,0,0,220,221,7,4,0,0,221,222,7,3,0,0,222,223,7,1,0,0,223, + 224,7,4,0,0,224,68,1,0,0,0,225,226,7,16,0,0,226,227,7,6,0,0,227, + 228,7,1,0,0,228,229,7,0,0,0,229,231,7,5,0,0,230,232,7,2,0,0,231, + 230,1,0,0,0,231,232,1,0,0,0,232,70,1,0,0,0,233,234,7,17,0,0,234, + 235,7,10,0,0,235,236,7,14,0,0,236,237,7,3,0,0,237,239,7,5,0,0,238, + 240,7,2,0,0,239,238,1,0,0,0,239,240,1,0,0,0,240,72,1,0,0,0,241,242, + 7,8,0,0,242,243,7,1,0,0,243,244,7,6,0,0,244,245,7,7,0,0,245,246, + 7,3,0,0,246,247,7,13,0,0,247,248,7,18,0,0,248,250,7,5,0,0,249,251, + 7,2,0,0,250,249,1,0,0,0,250,251,1,0,0,0,251,74,1,0,0,0,252,253,7, + 8,0,0,253,254,7,10,0,0,254,255,7,3,0,0,255,256,7,4,0,0,256,258,7, + 7,0,0,257,259,7,2,0,0,258,257,1,0,0,0,258,259,1,0,0,0,259,76,1,0, + 0,0,260,261,7,13,0,0,261,262,7,10,0,0,262,263,7,4,0,0,263,264,7, + 2,0,0,264,265,7,7,0,0,265,266,7,1,0,0,266,267,7,4,0,0,267,269,7, + 7,0,0,268,270,7,2,0,0,269,268,1,0,0,0,269,270,1,0,0,0,270,78,1,0, + 0,0,271,272,7,2,0,0,272,273,7,8,0,0,273,274,7,5,0,0,274,275,7,13, + 0,0,275,276,7,3,0,0,276,277,7,16,0,0,277,278,7,3,0,0,278,279,7,5, + 0,0,279,281,7,14,0,0,280,282,7,2,0,0,281,280,1,0,0,0,281,282,1,0, + 0,0,282,80,1,0,0,0,283,284,7,3,0,0,284,285,7,0,0,0,285,286,7,1,0, + 0,286,287,7,19,0,0,287,288,7,3,0,0,288,289,7,4,0,0,289,290,7,1,0, + 0,290,291,7,6,0,0,291,292,7,12,0,0,292,82,1,0,0,0,293,294,7,11,0, + 0,294,295,7,1,0,0,295,296,7,6,0,0,296,297,7,3,0,0,297,298,7,1,0, + 0,298,299,7,17,0,0,299,300,7,18,0,0,300,302,7,5,0,0,301,303,7,2, + 0,0,302,301,1,0,0,0,302,303,1,0,0,0,303,84,1,0,0,0,304,305,7,0,0, + 0,305,306,7,10,0,0,306,307,7,7,0,0,307,308,7,3,0,0,308,309,7,10, + 0,0,309,310,7,4,0,0,310,311,7,11,0,0,311,312,7,1,0,0,312,313,7,6, + 0,0,313,314,7,3,0,0,314,315,7,1,0,0,315,316,7,17,0,0,316,317,7,18, + 0,0,317,319,7,5,0,0,318,320,7,2,0,0,319,318,1,0,0,0,319,320,1,0, + 0,0,320,86,1,0,0,0,321,323,5,39,0,0,322,321,1,0,0,0,323,326,1,0, + 0,0,324,322,1,0,0,0,324,325,1,0,0,0,325,88,1,0,0,0,326,324,1,0,0, + 0,327,328,7,20,0,0,328,90,1,0,0,0,329,331,7,20,0,0,330,329,1,0,0, + 0,331,332,1,0,0,0,332,330,1,0,0,0,332,333,1,0,0,0,333,92,1,0,0,0, + 334,336,3,89,44,0,335,334,1,0,0,0,336,337,1,0,0,0,337,335,1,0,0, + 0,337,338,1,0,0,0,338,339,1,0,0,0,339,343,5,46,0,0,340,342,3,89, + 44,0,341,340,1,0,0,0,342,345,1,0,0,0,343,341,1,0,0,0,343,344,1,0, + 0,0,344,353,1,0,0,0,345,343,1,0,0,0,346,348,5,46,0,0,347,349,3,89, + 44,0,348,347,1,0,0,0,349,350,1,0,0,0,350,348,1,0,0,0,350,351,1,0, + 0,0,351,353,1,0,0,0,352,335,1,0,0,0,352,346,1,0,0,0,353,94,1,0,0, + 0,354,355,3,93,46,0,355,356,5,69,0,0,356,357,3,91,45,0,357,364,1, + 0,0,0,358,359,3,93,46,0,359,360,5,69,0,0,360,361,5,45,0,0,361,362, + 3,91,45,0,362,364,1,0,0,0,363,354,1,0,0,0,363,358,1,0,0,0,364,96, + 1,0,0,0,365,369,5,37,0,0,366,368,9,0,0,0,367,366,1,0,0,0,368,371, + 1,0,0,0,369,370,1,0,0,0,369,367,1,0,0,0,370,373,1,0,0,0,371,369, + 1,0,0,0,372,374,5,13,0,0,373,372,1,0,0,0,373,374,1,0,0,0,374,375, + 1,0,0,0,375,376,5,10,0,0,376,377,1,0,0,0,377,378,6,48,0,0,378,98, + 1,0,0,0,379,383,7,21,0,0,380,382,7,22,0,0,381,380,1,0,0,0,382,385, + 1,0,0,0,383,381,1,0,0,0,383,384,1,0,0,0,384,100,1,0,0,0,385,383, + 1,0,0,0,386,388,7,23,0,0,387,386,1,0,0,0,388,389,1,0,0,0,389,387, + 1,0,0,0,389,390,1,0,0,0,390,391,1,0,0,0,391,392,6,50,0,0,392,102, + 1,0,0,0,21,0,183,231,239,250,258,269,281,302,319,324,332,337,343, + 350,352,363,369,373,383,389,1,6,0,0 + ] + +class AutolevLexer(Lexer): + + atn = ATNDeserializer().deserialize(serializedATN()) + + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] + + T__0 = 1 + T__1 = 2 + T__2 = 3 + T__3 = 4 + T__4 = 5 + T__5 = 6 + T__6 = 7 + T__7 = 8 + T__8 = 9 + T__9 = 10 + T__10 = 11 + T__11 = 12 + T__12 = 13 + T__13 = 14 + T__14 = 15 + T__15 = 16 + T__16 = 17 + T__17 = 18 + T__18 = 19 + T__19 = 20 + T__20 = 21 + T__21 = 22 + T__22 = 23 + T__23 = 24 + T__24 = 25 + T__25 = 26 + Mass = 27 + Inertia = 28 + Input = 29 + Output = 30 + Save = 31 + UnitSystem = 32 + Encode = 33 + Newtonian = 34 + Frames = 35 + Bodies = 36 + Particles = 37 + Points = 38 + Constants = 39 + Specifieds = 40 + Imaginary = 41 + Variables = 42 + MotionVariables = 43 + INT = 44 + FLOAT = 45 + EXP = 46 + LINE_COMMENT = 47 + ID = 48 + WS = 49 + + channelNames = [ u"DEFAULT_TOKEN_CHANNEL", u"HIDDEN" ] + + modeNames = [ "DEFAULT_MODE" ] + + literalNames = [ "", + "'['", "']'", "'='", "'+='", "'-='", "':='", "'*='", "'/='", + "'^='", "','", "'''", "'('", "')'", "'{'", "'}'", "':'", "'+'", + "'-'", "';'", "'.'", "'>'", "'0>'", "'1>>'", "'^'", "'*'", "'/'" ] + + symbolicNames = [ "", + "Mass", "Inertia", "Input", "Output", "Save", "UnitSystem", + "Encode", "Newtonian", "Frames", "Bodies", "Particles", "Points", + "Constants", "Specifieds", "Imaginary", "Variables", "MotionVariables", + "INT", "FLOAT", "EXP", "LINE_COMMENT", "ID", "WS" ] + + ruleNames = [ "T__0", "T__1", "T__2", "T__3", "T__4", "T__5", "T__6", + "T__7", "T__8", "T__9", "T__10", "T__11", "T__12", "T__13", + "T__14", "T__15", "T__16", "T__17", "T__18", "T__19", + "T__20", "T__21", "T__22", "T__23", "T__24", "T__25", + "Mass", "Inertia", "Input", "Output", "Save", "UnitSystem", + "Encode", "Newtonian", "Frames", "Bodies", "Particles", + "Points", "Constants", "Specifieds", "Imaginary", "Variables", + "MotionVariables", "DIFF", "DIGIT", "INT", "FLOAT", "EXP", + "LINE_COMMENT", "ID", "WS" ] + + grammarFileName = "Autolev.g4" + + def __init__(self, input=None, output:TextIO = sys.stdout): + super().__init__(input, output) + self.checkVersion("4.11.1") + self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache()) + self._actions = None + self._predicates = None + + diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/_antlr/autolevlistener.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/_antlr/autolevlistener.py new file mode 100644 index 0000000000000000000000000000000000000000..6f391a298a71ecf2d04cf921a919cbb68b181fab --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/_antlr/autolevlistener.py @@ -0,0 +1,421 @@ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt +from antlr4 import * +if __name__ is not None and "." in __name__: + from .autolevparser import AutolevParser +else: + from autolevparser import AutolevParser + +# This class defines a complete listener for a parse tree produced by AutolevParser. +class AutolevListener(ParseTreeListener): + + # Enter a parse tree produced by AutolevParser#prog. + def enterProg(self, ctx:AutolevParser.ProgContext): + pass + + # Exit a parse tree produced by AutolevParser#prog. + def exitProg(self, ctx:AutolevParser.ProgContext): + pass + + + # Enter a parse tree produced by AutolevParser#stat. + def enterStat(self, ctx:AutolevParser.StatContext): + pass + + # Exit a parse tree produced by AutolevParser#stat. + def exitStat(self, ctx:AutolevParser.StatContext): + pass + + + # Enter a parse tree produced by AutolevParser#vecAssign. + def enterVecAssign(self, ctx:AutolevParser.VecAssignContext): + pass + + # Exit a parse tree produced by AutolevParser#vecAssign. + def exitVecAssign(self, ctx:AutolevParser.VecAssignContext): + pass + + + # Enter a parse tree produced by AutolevParser#indexAssign. + def enterIndexAssign(self, ctx:AutolevParser.IndexAssignContext): + pass + + # Exit a parse tree produced by AutolevParser#indexAssign. + def exitIndexAssign(self, ctx:AutolevParser.IndexAssignContext): + pass + + + # Enter a parse tree produced by AutolevParser#regularAssign. + def enterRegularAssign(self, ctx:AutolevParser.RegularAssignContext): + pass + + # Exit a parse tree produced by AutolevParser#regularAssign. + def exitRegularAssign(self, ctx:AutolevParser.RegularAssignContext): + pass + + + # Enter a parse tree produced by AutolevParser#equals. + def enterEquals(self, ctx:AutolevParser.EqualsContext): + pass + + # Exit a parse tree produced by AutolevParser#equals. + def exitEquals(self, ctx:AutolevParser.EqualsContext): + pass + + + # Enter a parse tree produced by AutolevParser#index. + def enterIndex(self, ctx:AutolevParser.IndexContext): + pass + + # Exit a parse tree produced by AutolevParser#index. + def exitIndex(self, ctx:AutolevParser.IndexContext): + pass + + + # Enter a parse tree produced by AutolevParser#diff. + def enterDiff(self, ctx:AutolevParser.DiffContext): + pass + + # Exit a parse tree produced by AutolevParser#diff. + def exitDiff(self, ctx:AutolevParser.DiffContext): + pass + + + # Enter a parse tree produced by AutolevParser#functionCall. + def enterFunctionCall(self, ctx:AutolevParser.FunctionCallContext): + pass + + # Exit a parse tree produced by AutolevParser#functionCall. + def exitFunctionCall(self, ctx:AutolevParser.FunctionCallContext): + pass + + + # Enter a parse tree produced by AutolevParser#varDecl. + def enterVarDecl(self, ctx:AutolevParser.VarDeclContext): + pass + + # Exit a parse tree produced by AutolevParser#varDecl. + def exitVarDecl(self, ctx:AutolevParser.VarDeclContext): + pass + + + # Enter a parse tree produced by AutolevParser#varType. + def enterVarType(self, ctx:AutolevParser.VarTypeContext): + pass + + # Exit a parse tree produced by AutolevParser#varType. + def exitVarType(self, ctx:AutolevParser.VarTypeContext): + pass + + + # Enter a parse tree produced by AutolevParser#varDecl2. + def enterVarDecl2(self, ctx:AutolevParser.VarDecl2Context): + pass + + # Exit a parse tree produced by AutolevParser#varDecl2. + def exitVarDecl2(self, ctx:AutolevParser.VarDecl2Context): + pass + + + # Enter a parse tree produced by AutolevParser#ranges. + def enterRanges(self, ctx:AutolevParser.RangesContext): + pass + + # Exit a parse tree produced by AutolevParser#ranges. + def exitRanges(self, ctx:AutolevParser.RangesContext): + pass + + + # Enter a parse tree produced by AutolevParser#massDecl. + def enterMassDecl(self, ctx:AutolevParser.MassDeclContext): + pass + + # Exit a parse tree produced by AutolevParser#massDecl. + def exitMassDecl(self, ctx:AutolevParser.MassDeclContext): + pass + + + # Enter a parse tree produced by AutolevParser#massDecl2. + def enterMassDecl2(self, ctx:AutolevParser.MassDecl2Context): + pass + + # Exit a parse tree produced by AutolevParser#massDecl2. + def exitMassDecl2(self, ctx:AutolevParser.MassDecl2Context): + pass + + + # Enter a parse tree produced by AutolevParser#inertiaDecl. + def enterInertiaDecl(self, ctx:AutolevParser.InertiaDeclContext): + pass + + # Exit a parse tree produced by AutolevParser#inertiaDecl. + def exitInertiaDecl(self, ctx:AutolevParser.InertiaDeclContext): + pass + + + # Enter a parse tree produced by AutolevParser#matrix. + def enterMatrix(self, ctx:AutolevParser.MatrixContext): + pass + + # Exit a parse tree produced by AutolevParser#matrix. + def exitMatrix(self, ctx:AutolevParser.MatrixContext): + pass + + + # Enter a parse tree produced by AutolevParser#matrixInOutput. + def enterMatrixInOutput(self, ctx:AutolevParser.MatrixInOutputContext): + pass + + # Exit a parse tree produced by AutolevParser#matrixInOutput. + def exitMatrixInOutput(self, ctx:AutolevParser.MatrixInOutputContext): + pass + + + # Enter a parse tree produced by AutolevParser#codeCommands. + def enterCodeCommands(self, ctx:AutolevParser.CodeCommandsContext): + pass + + # Exit a parse tree produced by AutolevParser#codeCommands. + def exitCodeCommands(self, ctx:AutolevParser.CodeCommandsContext): + pass + + + # Enter a parse tree produced by AutolevParser#settings. + def enterSettings(self, ctx:AutolevParser.SettingsContext): + pass + + # Exit a parse tree produced by AutolevParser#settings. + def exitSettings(self, ctx:AutolevParser.SettingsContext): + pass + + + # Enter a parse tree produced by AutolevParser#units. + def enterUnits(self, ctx:AutolevParser.UnitsContext): + pass + + # Exit a parse tree produced by AutolevParser#units. + def exitUnits(self, ctx:AutolevParser.UnitsContext): + pass + + + # Enter a parse tree produced by AutolevParser#inputs. + def enterInputs(self, ctx:AutolevParser.InputsContext): + pass + + # Exit a parse tree produced by AutolevParser#inputs. + def exitInputs(self, ctx:AutolevParser.InputsContext): + pass + + + # Enter a parse tree produced by AutolevParser#id_diff. + def enterId_diff(self, ctx:AutolevParser.Id_diffContext): + pass + + # Exit a parse tree produced by AutolevParser#id_diff. + def exitId_diff(self, ctx:AutolevParser.Id_diffContext): + pass + + + # Enter a parse tree produced by AutolevParser#inputs2. + def enterInputs2(self, ctx:AutolevParser.Inputs2Context): + pass + + # Exit a parse tree produced by AutolevParser#inputs2. + def exitInputs2(self, ctx:AutolevParser.Inputs2Context): + pass + + + # Enter a parse tree produced by AutolevParser#outputs. + def enterOutputs(self, ctx:AutolevParser.OutputsContext): + pass + + # Exit a parse tree produced by AutolevParser#outputs. + def exitOutputs(self, ctx:AutolevParser.OutputsContext): + pass + + + # Enter a parse tree produced by AutolevParser#outputs2. + def enterOutputs2(self, ctx:AutolevParser.Outputs2Context): + pass + + # Exit a parse tree produced by AutolevParser#outputs2. + def exitOutputs2(self, ctx:AutolevParser.Outputs2Context): + pass + + + # Enter a parse tree produced by AutolevParser#codegen. + def enterCodegen(self, ctx:AutolevParser.CodegenContext): + pass + + # Exit a parse tree produced by AutolevParser#codegen. + def exitCodegen(self, ctx:AutolevParser.CodegenContext): + pass + + + # Enter a parse tree produced by AutolevParser#commands. + def enterCommands(self, ctx:AutolevParser.CommandsContext): + pass + + # Exit a parse tree produced by AutolevParser#commands. + def exitCommands(self, ctx:AutolevParser.CommandsContext): + pass + + + # Enter a parse tree produced by AutolevParser#vec. + def enterVec(self, ctx:AutolevParser.VecContext): + pass + + # Exit a parse tree produced by AutolevParser#vec. + def exitVec(self, ctx:AutolevParser.VecContext): + pass + + + # Enter a parse tree produced by AutolevParser#parens. + def enterParens(self, ctx:AutolevParser.ParensContext): + pass + + # Exit a parse tree produced by AutolevParser#parens. + def exitParens(self, ctx:AutolevParser.ParensContext): + pass + + + # Enter a parse tree produced by AutolevParser#VectorOrDyadic. + def enterVectorOrDyadic(self, ctx:AutolevParser.VectorOrDyadicContext): + pass + + # Exit a parse tree produced by AutolevParser#VectorOrDyadic. + def exitVectorOrDyadic(self, ctx:AutolevParser.VectorOrDyadicContext): + pass + + + # Enter a parse tree produced by AutolevParser#Exponent. + def enterExponent(self, ctx:AutolevParser.ExponentContext): + pass + + # Exit a parse tree produced by AutolevParser#Exponent. + def exitExponent(self, ctx:AutolevParser.ExponentContext): + pass + + + # Enter a parse tree produced by AutolevParser#MulDiv. + def enterMulDiv(self, ctx:AutolevParser.MulDivContext): + pass + + # Exit a parse tree produced by AutolevParser#MulDiv. + def exitMulDiv(self, ctx:AutolevParser.MulDivContext): + pass + + + # Enter a parse tree produced by AutolevParser#AddSub. + def enterAddSub(self, ctx:AutolevParser.AddSubContext): + pass + + # Exit a parse tree produced by AutolevParser#AddSub. + def exitAddSub(self, ctx:AutolevParser.AddSubContext): + pass + + + # Enter a parse tree produced by AutolevParser#float. + def enterFloat(self, ctx:AutolevParser.FloatContext): + pass + + # Exit a parse tree produced by AutolevParser#float. + def exitFloat(self, ctx:AutolevParser.FloatContext): + pass + + + # Enter a parse tree produced by AutolevParser#int. + def enterInt(self, ctx:AutolevParser.IntContext): + pass + + # Exit a parse tree produced by AutolevParser#int. + def exitInt(self, ctx:AutolevParser.IntContext): + pass + + + # Enter a parse tree produced by AutolevParser#idEqualsExpr. + def enterIdEqualsExpr(self, ctx:AutolevParser.IdEqualsExprContext): + pass + + # Exit a parse tree produced by AutolevParser#idEqualsExpr. + def exitIdEqualsExpr(self, ctx:AutolevParser.IdEqualsExprContext): + pass + + + # Enter a parse tree produced by AutolevParser#negativeOne. + def enterNegativeOne(self, ctx:AutolevParser.NegativeOneContext): + pass + + # Exit a parse tree produced by AutolevParser#negativeOne. + def exitNegativeOne(self, ctx:AutolevParser.NegativeOneContext): + pass + + + # Enter a parse tree produced by AutolevParser#function. + def enterFunction(self, ctx:AutolevParser.FunctionContext): + pass + + # Exit a parse tree produced by AutolevParser#function. + def exitFunction(self, ctx:AutolevParser.FunctionContext): + pass + + + # Enter a parse tree produced by AutolevParser#rangess. + def enterRangess(self, ctx:AutolevParser.RangessContext): + pass + + # Exit a parse tree produced by AutolevParser#rangess. + def exitRangess(self, ctx:AutolevParser.RangessContext): + pass + + + # Enter a parse tree produced by AutolevParser#colon. + def enterColon(self, ctx:AutolevParser.ColonContext): + pass + + # Exit a parse tree produced by AutolevParser#colon. + def exitColon(self, ctx:AutolevParser.ColonContext): + pass + + + # Enter a parse tree produced by AutolevParser#id. + def enterId(self, ctx:AutolevParser.IdContext): + pass + + # Exit a parse tree produced by AutolevParser#id. + def exitId(self, ctx:AutolevParser.IdContext): + pass + + + # Enter a parse tree produced by AutolevParser#exp. + def enterExp(self, ctx:AutolevParser.ExpContext): + pass + + # Exit a parse tree produced by AutolevParser#exp. + def exitExp(self, ctx:AutolevParser.ExpContext): + pass + + + # Enter a parse tree produced by AutolevParser#matrices. + def enterMatrices(self, ctx:AutolevParser.MatricesContext): + pass + + # Exit a parse tree produced by AutolevParser#matrices. + def exitMatrices(self, ctx:AutolevParser.MatricesContext): + pass + + + # Enter a parse tree produced by AutolevParser#Indexing. + def enterIndexing(self, ctx:AutolevParser.IndexingContext): + pass + + # Exit a parse tree produced by AutolevParser#Indexing. + def exitIndexing(self, ctx:AutolevParser.IndexingContext): + pass + + + +del AutolevParser diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/_antlr/autolevparser.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/_antlr/autolevparser.py new file mode 100644 index 0000000000000000000000000000000000000000..e63ef1c110812580d06291ee7c7ec40b6a076cea --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/_antlr/autolevparser.py @@ -0,0 +1,3063 @@ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt +from antlr4 import * +from io import StringIO +import sys +if sys.version_info[1] > 5: + from typing import TextIO +else: + from typing.io import TextIO + +def serializedATN(): + return [ + 4,1,49,431,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5,2,6,7, + 6,2,7,7,7,2,8,7,8,2,9,7,9,2,10,7,10,2,11,7,11,2,12,7,12,2,13,7,13, + 2,14,7,14,2,15,7,15,2,16,7,16,2,17,7,17,2,18,7,18,2,19,7,19,2,20, + 7,20,2,21,7,21,2,22,7,22,2,23,7,23,2,24,7,24,2,25,7,25,2,26,7,26, + 2,27,7,27,1,0,4,0,58,8,0,11,0,12,0,59,1,1,1,1,1,1,1,1,1,1,1,1,1, + 1,3,1,69,8,1,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2, + 3,2,84,8,2,1,2,1,2,1,2,3,2,89,8,2,1,3,1,3,1,4,1,4,1,4,5,4,96,8,4, + 10,4,12,4,99,9,4,1,5,4,5,102,8,5,11,5,12,5,103,1,6,1,6,1,6,1,6,1, + 6,5,6,111,8,6,10,6,12,6,114,9,6,3,6,116,8,6,1,6,1,6,1,6,1,6,1,6, + 1,6,5,6,124,8,6,10,6,12,6,127,9,6,3,6,129,8,6,1,6,3,6,132,8,6,1, + 7,1,7,1,7,1,7,5,7,138,8,7,10,7,12,7,141,9,7,1,8,1,8,1,8,1,8,1,8, + 1,8,1,8,1,8,1,8,1,8,5,8,153,8,8,10,8,12,8,156,9,8,1,8,1,8,5,8,160, + 8,8,10,8,12,8,163,9,8,3,8,165,8,8,1,9,1,9,1,9,1,9,1,9,1,9,3,9,173, + 8,9,1,9,1,9,1,9,1,9,1,9,1,9,1,9,1,9,5,9,183,8,9,10,9,12,9,186,9, + 9,1,9,3,9,189,8,9,1,9,1,9,1,9,3,9,194,8,9,1,9,3,9,197,8,9,1,9,5, + 9,200,8,9,10,9,12,9,203,9,9,1,9,1,9,3,9,207,8,9,1,10,1,10,1,10,1, + 10,1,10,1,10,1,10,1,10,5,10,217,8,10,10,10,12,10,220,9,10,1,10,1, + 10,1,11,1,11,1,11,1,11,5,11,228,8,11,10,11,12,11,231,9,11,1,12,1, + 12,1,12,1,12,1,13,1,13,1,13,1,13,1,13,3,13,242,8,13,1,13,1,13,4, + 13,246,8,13,11,13,12,13,247,1,14,1,14,1,14,1,14,5,14,254,8,14,10, + 14,12,14,257,9,14,1,14,1,14,1,15,1,15,1,15,1,15,3,15,265,8,15,1, + 15,1,15,3,15,269,8,15,1,16,1,16,1,16,1,16,1,16,3,16,276,8,16,1,17, + 1,17,3,17,280,8,17,1,18,1,18,1,18,1,18,5,18,286,8,18,10,18,12,18, + 289,9,18,1,19,1,19,1,19,1,19,5,19,295,8,19,10,19,12,19,298,9,19, + 1,20,1,20,3,20,302,8,20,1,21,1,21,1,21,1,21,3,21,308,8,21,1,22,1, + 22,1,22,1,22,5,22,314,8,22,10,22,12,22,317,9,22,1,23,1,23,3,23,321, + 8,23,1,24,1,24,1,24,1,24,1,24,1,24,5,24,329,8,24,10,24,12,24,332, + 9,24,1,24,1,24,3,24,336,8,24,1,24,1,24,1,24,1,24,1,25,1,25,1,25, + 1,25,1,25,1,25,1,25,1,25,5,25,350,8,25,10,25,12,25,353,9,25,3,25, + 355,8,25,1,26,1,26,4,26,359,8,26,11,26,12,26,360,1,26,1,26,3,26, + 365,8,26,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,5,27,375,8,27,10, + 27,12,27,378,9,27,1,27,1,27,1,27,1,27,1,27,1,27,5,27,386,8,27,10, + 27,12,27,389,9,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,3, + 27,400,8,27,1,27,1,27,5,27,404,8,27,10,27,12,27,407,9,27,3,27,409, + 8,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27, + 1,27,1,27,1,27,5,27,426,8,27,10,27,12,27,429,9,27,1,27,0,1,54,28, + 0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36,38,40,42,44, + 46,48,50,52,54,0,7,1,0,3,9,1,0,27,28,1,0,17,18,2,0,10,10,19,19,1, + 0,44,45,2,0,44,46,48,48,1,0,25,26,483,0,57,1,0,0,0,2,68,1,0,0,0, + 4,88,1,0,0,0,6,90,1,0,0,0,8,92,1,0,0,0,10,101,1,0,0,0,12,131,1,0, + 0,0,14,133,1,0,0,0,16,164,1,0,0,0,18,166,1,0,0,0,20,208,1,0,0,0, + 22,223,1,0,0,0,24,232,1,0,0,0,26,236,1,0,0,0,28,249,1,0,0,0,30,268, + 1,0,0,0,32,275,1,0,0,0,34,277,1,0,0,0,36,281,1,0,0,0,38,290,1,0, + 0,0,40,299,1,0,0,0,42,303,1,0,0,0,44,309,1,0,0,0,46,318,1,0,0,0, + 48,322,1,0,0,0,50,354,1,0,0,0,52,364,1,0,0,0,54,408,1,0,0,0,56,58, + 3,2,1,0,57,56,1,0,0,0,58,59,1,0,0,0,59,57,1,0,0,0,59,60,1,0,0,0, + 60,1,1,0,0,0,61,69,3,14,7,0,62,69,3,12,6,0,63,69,3,32,16,0,64,69, + 3,22,11,0,65,69,3,26,13,0,66,69,3,4,2,0,67,69,3,34,17,0,68,61,1, + 0,0,0,68,62,1,0,0,0,68,63,1,0,0,0,68,64,1,0,0,0,68,65,1,0,0,0,68, + 66,1,0,0,0,68,67,1,0,0,0,69,3,1,0,0,0,70,71,3,52,26,0,71,72,3,6, + 3,0,72,73,3,54,27,0,73,89,1,0,0,0,74,75,5,48,0,0,75,76,5,1,0,0,76, + 77,3,8,4,0,77,78,5,2,0,0,78,79,3,6,3,0,79,80,3,54,27,0,80,89,1,0, + 0,0,81,83,5,48,0,0,82,84,3,10,5,0,83,82,1,0,0,0,83,84,1,0,0,0,84, + 85,1,0,0,0,85,86,3,6,3,0,86,87,3,54,27,0,87,89,1,0,0,0,88,70,1,0, + 0,0,88,74,1,0,0,0,88,81,1,0,0,0,89,5,1,0,0,0,90,91,7,0,0,0,91,7, + 1,0,0,0,92,97,3,54,27,0,93,94,5,10,0,0,94,96,3,54,27,0,95,93,1,0, + 0,0,96,99,1,0,0,0,97,95,1,0,0,0,97,98,1,0,0,0,98,9,1,0,0,0,99,97, + 1,0,0,0,100,102,5,11,0,0,101,100,1,0,0,0,102,103,1,0,0,0,103,101, + 1,0,0,0,103,104,1,0,0,0,104,11,1,0,0,0,105,106,5,48,0,0,106,115, + 5,12,0,0,107,112,3,54,27,0,108,109,5,10,0,0,109,111,3,54,27,0,110, + 108,1,0,0,0,111,114,1,0,0,0,112,110,1,0,0,0,112,113,1,0,0,0,113, + 116,1,0,0,0,114,112,1,0,0,0,115,107,1,0,0,0,115,116,1,0,0,0,116, + 117,1,0,0,0,117,132,5,13,0,0,118,119,7,1,0,0,119,128,5,12,0,0,120, + 125,5,48,0,0,121,122,5,10,0,0,122,124,5,48,0,0,123,121,1,0,0,0,124, + 127,1,0,0,0,125,123,1,0,0,0,125,126,1,0,0,0,126,129,1,0,0,0,127, + 125,1,0,0,0,128,120,1,0,0,0,128,129,1,0,0,0,129,130,1,0,0,0,130, + 132,5,13,0,0,131,105,1,0,0,0,131,118,1,0,0,0,132,13,1,0,0,0,133, + 134,3,16,8,0,134,139,3,18,9,0,135,136,5,10,0,0,136,138,3,18,9,0, + 137,135,1,0,0,0,138,141,1,0,0,0,139,137,1,0,0,0,139,140,1,0,0,0, + 140,15,1,0,0,0,141,139,1,0,0,0,142,165,5,34,0,0,143,165,5,35,0,0, + 144,165,5,36,0,0,145,165,5,37,0,0,146,165,5,38,0,0,147,165,5,39, + 0,0,148,165,5,40,0,0,149,165,5,41,0,0,150,154,5,42,0,0,151,153,5, + 11,0,0,152,151,1,0,0,0,153,156,1,0,0,0,154,152,1,0,0,0,154,155,1, + 0,0,0,155,165,1,0,0,0,156,154,1,0,0,0,157,161,5,43,0,0,158,160,5, + 11,0,0,159,158,1,0,0,0,160,163,1,0,0,0,161,159,1,0,0,0,161,162,1, + 0,0,0,162,165,1,0,0,0,163,161,1,0,0,0,164,142,1,0,0,0,164,143,1, + 0,0,0,164,144,1,0,0,0,164,145,1,0,0,0,164,146,1,0,0,0,164,147,1, + 0,0,0,164,148,1,0,0,0,164,149,1,0,0,0,164,150,1,0,0,0,164,157,1, + 0,0,0,165,17,1,0,0,0,166,172,5,48,0,0,167,168,5,14,0,0,168,169,5, + 44,0,0,169,170,5,10,0,0,170,171,5,44,0,0,171,173,5,15,0,0,172,167, + 1,0,0,0,172,173,1,0,0,0,173,188,1,0,0,0,174,175,5,14,0,0,175,176, + 5,44,0,0,176,177,5,16,0,0,177,184,5,44,0,0,178,179,5,10,0,0,179, + 180,5,44,0,0,180,181,5,16,0,0,181,183,5,44,0,0,182,178,1,0,0,0,183, + 186,1,0,0,0,184,182,1,0,0,0,184,185,1,0,0,0,185,187,1,0,0,0,186, + 184,1,0,0,0,187,189,5,15,0,0,188,174,1,0,0,0,188,189,1,0,0,0,189, + 193,1,0,0,0,190,191,5,14,0,0,191,192,5,44,0,0,192,194,5,15,0,0,193, + 190,1,0,0,0,193,194,1,0,0,0,194,196,1,0,0,0,195,197,7,2,0,0,196, + 195,1,0,0,0,196,197,1,0,0,0,197,201,1,0,0,0,198,200,5,11,0,0,199, + 198,1,0,0,0,200,203,1,0,0,0,201,199,1,0,0,0,201,202,1,0,0,0,202, + 206,1,0,0,0,203,201,1,0,0,0,204,205,5,3,0,0,205,207,3,54,27,0,206, + 204,1,0,0,0,206,207,1,0,0,0,207,19,1,0,0,0,208,209,5,14,0,0,209, + 210,5,44,0,0,210,211,5,16,0,0,211,218,5,44,0,0,212,213,5,10,0,0, + 213,214,5,44,0,0,214,215,5,16,0,0,215,217,5,44,0,0,216,212,1,0,0, + 0,217,220,1,0,0,0,218,216,1,0,0,0,218,219,1,0,0,0,219,221,1,0,0, + 0,220,218,1,0,0,0,221,222,5,15,0,0,222,21,1,0,0,0,223,224,5,27,0, + 0,224,229,3,24,12,0,225,226,5,10,0,0,226,228,3,24,12,0,227,225,1, + 0,0,0,228,231,1,0,0,0,229,227,1,0,0,0,229,230,1,0,0,0,230,23,1,0, + 0,0,231,229,1,0,0,0,232,233,5,48,0,0,233,234,5,3,0,0,234,235,3,54, + 27,0,235,25,1,0,0,0,236,237,5,28,0,0,237,241,5,48,0,0,238,239,5, + 12,0,0,239,240,5,48,0,0,240,242,5,13,0,0,241,238,1,0,0,0,241,242, + 1,0,0,0,242,245,1,0,0,0,243,244,5,10,0,0,244,246,3,54,27,0,245,243, + 1,0,0,0,246,247,1,0,0,0,247,245,1,0,0,0,247,248,1,0,0,0,248,27,1, + 0,0,0,249,250,5,1,0,0,250,255,3,54,27,0,251,252,7,3,0,0,252,254, + 3,54,27,0,253,251,1,0,0,0,254,257,1,0,0,0,255,253,1,0,0,0,255,256, + 1,0,0,0,256,258,1,0,0,0,257,255,1,0,0,0,258,259,5,2,0,0,259,29,1, + 0,0,0,260,261,5,48,0,0,261,262,5,48,0,0,262,264,5,3,0,0,263,265, + 7,4,0,0,264,263,1,0,0,0,264,265,1,0,0,0,265,269,1,0,0,0,266,269, + 5,45,0,0,267,269,5,44,0,0,268,260,1,0,0,0,268,266,1,0,0,0,268,267, + 1,0,0,0,269,31,1,0,0,0,270,276,3,36,18,0,271,276,3,38,19,0,272,276, + 3,44,22,0,273,276,3,48,24,0,274,276,3,50,25,0,275,270,1,0,0,0,275, + 271,1,0,0,0,275,272,1,0,0,0,275,273,1,0,0,0,275,274,1,0,0,0,276, + 33,1,0,0,0,277,279,5,48,0,0,278,280,7,5,0,0,279,278,1,0,0,0,279, + 280,1,0,0,0,280,35,1,0,0,0,281,282,5,32,0,0,282,287,5,48,0,0,283, + 284,5,10,0,0,284,286,5,48,0,0,285,283,1,0,0,0,286,289,1,0,0,0,287, + 285,1,0,0,0,287,288,1,0,0,0,288,37,1,0,0,0,289,287,1,0,0,0,290,291, + 5,29,0,0,291,296,3,42,21,0,292,293,5,10,0,0,293,295,3,42,21,0,294, + 292,1,0,0,0,295,298,1,0,0,0,296,294,1,0,0,0,296,297,1,0,0,0,297, + 39,1,0,0,0,298,296,1,0,0,0,299,301,5,48,0,0,300,302,3,10,5,0,301, + 300,1,0,0,0,301,302,1,0,0,0,302,41,1,0,0,0,303,304,3,40,20,0,304, + 305,5,3,0,0,305,307,3,54,27,0,306,308,3,54,27,0,307,306,1,0,0,0, + 307,308,1,0,0,0,308,43,1,0,0,0,309,310,5,30,0,0,310,315,3,46,23, + 0,311,312,5,10,0,0,312,314,3,46,23,0,313,311,1,0,0,0,314,317,1,0, + 0,0,315,313,1,0,0,0,315,316,1,0,0,0,316,45,1,0,0,0,317,315,1,0,0, + 0,318,320,3,54,27,0,319,321,3,54,27,0,320,319,1,0,0,0,320,321,1, + 0,0,0,321,47,1,0,0,0,322,323,5,48,0,0,323,335,3,12,6,0,324,325,5, + 1,0,0,325,330,3,30,15,0,326,327,5,10,0,0,327,329,3,30,15,0,328,326, + 1,0,0,0,329,332,1,0,0,0,330,328,1,0,0,0,330,331,1,0,0,0,331,333, + 1,0,0,0,332,330,1,0,0,0,333,334,5,2,0,0,334,336,1,0,0,0,335,324, + 1,0,0,0,335,336,1,0,0,0,336,337,1,0,0,0,337,338,5,48,0,0,338,339, + 5,20,0,0,339,340,5,48,0,0,340,49,1,0,0,0,341,342,5,31,0,0,342,343, + 5,48,0,0,343,344,5,20,0,0,344,355,5,48,0,0,345,346,5,33,0,0,346, + 351,5,48,0,0,347,348,5,10,0,0,348,350,5,48,0,0,349,347,1,0,0,0,350, + 353,1,0,0,0,351,349,1,0,0,0,351,352,1,0,0,0,352,355,1,0,0,0,353, + 351,1,0,0,0,354,341,1,0,0,0,354,345,1,0,0,0,355,51,1,0,0,0,356,358, + 5,48,0,0,357,359,5,21,0,0,358,357,1,0,0,0,359,360,1,0,0,0,360,358, + 1,0,0,0,360,361,1,0,0,0,361,365,1,0,0,0,362,365,5,22,0,0,363,365, + 5,23,0,0,364,356,1,0,0,0,364,362,1,0,0,0,364,363,1,0,0,0,365,53, + 1,0,0,0,366,367,6,27,-1,0,367,409,5,46,0,0,368,369,5,18,0,0,369, + 409,3,54,27,12,370,409,5,45,0,0,371,409,5,44,0,0,372,376,5,48,0, + 0,373,375,5,11,0,0,374,373,1,0,0,0,375,378,1,0,0,0,376,374,1,0,0, + 0,376,377,1,0,0,0,377,409,1,0,0,0,378,376,1,0,0,0,379,409,3,52,26, + 0,380,381,5,48,0,0,381,382,5,1,0,0,382,387,3,54,27,0,383,384,5,10, + 0,0,384,386,3,54,27,0,385,383,1,0,0,0,386,389,1,0,0,0,387,385,1, + 0,0,0,387,388,1,0,0,0,388,390,1,0,0,0,389,387,1,0,0,0,390,391,5, + 2,0,0,391,409,1,0,0,0,392,409,3,12,6,0,393,409,3,28,14,0,394,395, + 5,12,0,0,395,396,3,54,27,0,396,397,5,13,0,0,397,409,1,0,0,0,398, + 400,5,48,0,0,399,398,1,0,0,0,399,400,1,0,0,0,400,401,1,0,0,0,401, + 405,3,20,10,0,402,404,5,11,0,0,403,402,1,0,0,0,404,407,1,0,0,0,405, + 403,1,0,0,0,405,406,1,0,0,0,406,409,1,0,0,0,407,405,1,0,0,0,408, + 366,1,0,0,0,408,368,1,0,0,0,408,370,1,0,0,0,408,371,1,0,0,0,408, + 372,1,0,0,0,408,379,1,0,0,0,408,380,1,0,0,0,408,392,1,0,0,0,408, + 393,1,0,0,0,408,394,1,0,0,0,408,399,1,0,0,0,409,427,1,0,0,0,410, + 411,10,16,0,0,411,412,5,24,0,0,412,426,3,54,27,17,413,414,10,15, + 0,0,414,415,7,6,0,0,415,426,3,54,27,16,416,417,10,14,0,0,417,418, + 7,2,0,0,418,426,3,54,27,15,419,420,10,3,0,0,420,421,5,3,0,0,421, + 426,3,54,27,4,422,423,10,2,0,0,423,424,5,16,0,0,424,426,3,54,27, + 3,425,410,1,0,0,0,425,413,1,0,0,0,425,416,1,0,0,0,425,419,1,0,0, + 0,425,422,1,0,0,0,426,429,1,0,0,0,427,425,1,0,0,0,427,428,1,0,0, + 0,428,55,1,0,0,0,429,427,1,0,0,0,50,59,68,83,88,97,103,112,115,125, + 128,131,139,154,161,164,172,184,188,193,196,201,206,218,229,241, + 247,255,264,268,275,279,287,296,301,307,315,320,330,335,351,354, + 360,364,376,387,399,405,408,425,427 + ] + +class AutolevParser ( Parser ): + + grammarFileName = "Autolev.g4" + + atn = ATNDeserializer().deserialize(serializedATN()) + + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] + + sharedContextCache = PredictionContextCache() + + literalNames = [ "", "'['", "']'", "'='", "'+='", "'-='", "':='", + "'*='", "'/='", "'^='", "','", "'''", "'('", "')'", + "'{'", "'}'", "':'", "'+'", "'-'", "';'", "'.'", "'>'", + "'0>'", "'1>>'", "'^'", "'*'", "'/'" ] + + symbolicNames = [ "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "Mass", "Inertia", + "Input", "Output", "Save", "UnitSystem", "Encode", + "Newtonian", "Frames", "Bodies", "Particles", "Points", + "Constants", "Specifieds", "Imaginary", "Variables", + "MotionVariables", "INT", "FLOAT", "EXP", "LINE_COMMENT", + "ID", "WS" ] + + RULE_prog = 0 + RULE_stat = 1 + RULE_assignment = 2 + RULE_equals = 3 + RULE_index = 4 + RULE_diff = 5 + RULE_functionCall = 6 + RULE_varDecl = 7 + RULE_varType = 8 + RULE_varDecl2 = 9 + RULE_ranges = 10 + RULE_massDecl = 11 + RULE_massDecl2 = 12 + RULE_inertiaDecl = 13 + RULE_matrix = 14 + RULE_matrixInOutput = 15 + RULE_codeCommands = 16 + RULE_settings = 17 + RULE_units = 18 + RULE_inputs = 19 + RULE_id_diff = 20 + RULE_inputs2 = 21 + RULE_outputs = 22 + RULE_outputs2 = 23 + RULE_codegen = 24 + RULE_commands = 25 + RULE_vec = 26 + RULE_expr = 27 + + ruleNames = [ "prog", "stat", "assignment", "equals", "index", "diff", + "functionCall", "varDecl", "varType", "varDecl2", "ranges", + "massDecl", "massDecl2", "inertiaDecl", "matrix", "matrixInOutput", + "codeCommands", "settings", "units", "inputs", "id_diff", + "inputs2", "outputs", "outputs2", "codegen", "commands", + "vec", "expr" ] + + EOF = Token.EOF + T__0=1 + T__1=2 + T__2=3 + T__3=4 + T__4=5 + T__5=6 + T__6=7 + T__7=8 + T__8=9 + T__9=10 + T__10=11 + T__11=12 + T__12=13 + T__13=14 + T__14=15 + T__15=16 + T__16=17 + T__17=18 + T__18=19 + T__19=20 + T__20=21 + T__21=22 + T__22=23 + T__23=24 + T__24=25 + T__25=26 + Mass=27 + Inertia=28 + Input=29 + Output=30 + Save=31 + UnitSystem=32 + Encode=33 + Newtonian=34 + Frames=35 + Bodies=36 + Particles=37 + Points=38 + Constants=39 + Specifieds=40 + Imaginary=41 + Variables=42 + MotionVariables=43 + INT=44 + FLOAT=45 + EXP=46 + LINE_COMMENT=47 + ID=48 + WS=49 + + def __init__(self, input:TokenStream, output:TextIO = sys.stdout): + super().__init__(input, output) + self.checkVersion("4.11.1") + self._interp = ParserATNSimulator(self, self.atn, self.decisionsToDFA, self.sharedContextCache) + self._predicates = None + + + + + class ProgContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def stat(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.StatContext) + else: + return self.getTypedRuleContext(AutolevParser.StatContext,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_prog + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterProg" ): + listener.enterProg(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitProg" ): + listener.exitProg(self) + + + + + def prog(self): + + localctx = AutolevParser.ProgContext(self, self._ctx, self.state) + self.enterRule(localctx, 0, self.RULE_prog) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 57 + self._errHandler.sync(self) + _la = self._input.LA(1) + while True: + self.state = 56 + self.stat() + self.state = 59 + self._errHandler.sync(self) + _la = self._input.LA(1) + if not (((_la) & ~0x3f) == 0 and ((1 << _la) & 299067041120256) != 0): + break + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class StatContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def varDecl(self): + return self.getTypedRuleContext(AutolevParser.VarDeclContext,0) + + + def functionCall(self): + return self.getTypedRuleContext(AutolevParser.FunctionCallContext,0) + + + def codeCommands(self): + return self.getTypedRuleContext(AutolevParser.CodeCommandsContext,0) + + + def massDecl(self): + return self.getTypedRuleContext(AutolevParser.MassDeclContext,0) + + + def inertiaDecl(self): + return self.getTypedRuleContext(AutolevParser.InertiaDeclContext,0) + + + def assignment(self): + return self.getTypedRuleContext(AutolevParser.AssignmentContext,0) + + + def settings(self): + return self.getTypedRuleContext(AutolevParser.SettingsContext,0) + + + def getRuleIndex(self): + return AutolevParser.RULE_stat + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterStat" ): + listener.enterStat(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitStat" ): + listener.exitStat(self) + + + + + def stat(self): + + localctx = AutolevParser.StatContext(self, self._ctx, self.state) + self.enterRule(localctx, 2, self.RULE_stat) + try: + self.state = 68 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,1,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 61 + self.varDecl() + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 62 + self.functionCall() + pass + + elif la_ == 3: + self.enterOuterAlt(localctx, 3) + self.state = 63 + self.codeCommands() + pass + + elif la_ == 4: + self.enterOuterAlt(localctx, 4) + self.state = 64 + self.massDecl() + pass + + elif la_ == 5: + self.enterOuterAlt(localctx, 5) + self.state = 65 + self.inertiaDecl() + pass + + elif la_ == 6: + self.enterOuterAlt(localctx, 6) + self.state = 66 + self.assignment() + pass + + elif la_ == 7: + self.enterOuterAlt(localctx, 7) + self.state = 67 + self.settings() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class AssignmentContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return AutolevParser.RULE_assignment + + + def copyFrom(self, ctx:ParserRuleContext): + super().copyFrom(ctx) + + + + class VecAssignContext(AssignmentContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.AssignmentContext + super().__init__(parser) + self.copyFrom(ctx) + + def vec(self): + return self.getTypedRuleContext(AutolevParser.VecContext,0) + + def equals(self): + return self.getTypedRuleContext(AutolevParser.EqualsContext,0) + + def expr(self): + return self.getTypedRuleContext(AutolevParser.ExprContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterVecAssign" ): + listener.enterVecAssign(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitVecAssign" ): + listener.exitVecAssign(self) + + + class RegularAssignContext(AssignmentContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.AssignmentContext + super().__init__(parser) + self.copyFrom(ctx) + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + def equals(self): + return self.getTypedRuleContext(AutolevParser.EqualsContext,0) + + def expr(self): + return self.getTypedRuleContext(AutolevParser.ExprContext,0) + + def diff(self): + return self.getTypedRuleContext(AutolevParser.DiffContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterRegularAssign" ): + listener.enterRegularAssign(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitRegularAssign" ): + listener.exitRegularAssign(self) + + + class IndexAssignContext(AssignmentContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.AssignmentContext + super().__init__(parser) + self.copyFrom(ctx) + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + def index(self): + return self.getTypedRuleContext(AutolevParser.IndexContext,0) + + def equals(self): + return self.getTypedRuleContext(AutolevParser.EqualsContext,0) + + def expr(self): + return self.getTypedRuleContext(AutolevParser.ExprContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterIndexAssign" ): + listener.enterIndexAssign(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitIndexAssign" ): + listener.exitIndexAssign(self) + + + + def assignment(self): + + localctx = AutolevParser.AssignmentContext(self, self._ctx, self.state) + self.enterRule(localctx, 4, self.RULE_assignment) + self._la = 0 # Token type + try: + self.state = 88 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,3,self._ctx) + if la_ == 1: + localctx = AutolevParser.VecAssignContext(self, localctx) + self.enterOuterAlt(localctx, 1) + self.state = 70 + self.vec() + self.state = 71 + self.equals() + self.state = 72 + self.expr(0) + pass + + elif la_ == 2: + localctx = AutolevParser.IndexAssignContext(self, localctx) + self.enterOuterAlt(localctx, 2) + self.state = 74 + self.match(AutolevParser.ID) + self.state = 75 + self.match(AutolevParser.T__0) + self.state = 76 + self.index() + self.state = 77 + self.match(AutolevParser.T__1) + self.state = 78 + self.equals() + self.state = 79 + self.expr(0) + pass + + elif la_ == 3: + localctx = AutolevParser.RegularAssignContext(self, localctx) + self.enterOuterAlt(localctx, 3) + self.state = 81 + self.match(AutolevParser.ID) + self.state = 83 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==11: + self.state = 82 + self.diff() + + + self.state = 85 + self.equals() + self.state = 86 + self.expr(0) + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class EqualsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return AutolevParser.RULE_equals + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterEquals" ): + listener.enterEquals(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitEquals" ): + listener.exitEquals(self) + + + + + def equals(self): + + localctx = AutolevParser.EqualsContext(self, self._ctx, self.state) + self.enterRule(localctx, 6, self.RULE_equals) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 90 + _la = self._input.LA(1) + if not(((_la) & ~0x3f) == 0 and ((1 << _la) & 1016) != 0): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class IndexContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_index + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterIndex" ): + listener.enterIndex(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitIndex" ): + listener.exitIndex(self) + + + + + def index(self): + + localctx = AutolevParser.IndexContext(self, self._ctx, self.state) + self.enterRule(localctx, 8, self.RULE_index) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 92 + self.expr(0) + self.state = 97 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 93 + self.match(AutolevParser.T__9) + self.state = 94 + self.expr(0) + self.state = 99 + self._errHandler.sync(self) + _la = self._input.LA(1) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class DiffContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return AutolevParser.RULE_diff + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterDiff" ): + listener.enterDiff(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitDiff" ): + listener.exitDiff(self) + + + + + def diff(self): + + localctx = AutolevParser.DiffContext(self, self._ctx, self.state) + self.enterRule(localctx, 10, self.RULE_diff) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 101 + self._errHandler.sync(self) + _la = self._input.LA(1) + while True: + self.state = 100 + self.match(AutolevParser.T__10) + self.state = 103 + self._errHandler.sync(self) + _la = self._input.LA(1) + if not (_la==11): + break + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class FunctionCallContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.ID) + else: + return self.getToken(AutolevParser.ID, i) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def Mass(self): + return self.getToken(AutolevParser.Mass, 0) + + def Inertia(self): + return self.getToken(AutolevParser.Inertia, 0) + + def getRuleIndex(self): + return AutolevParser.RULE_functionCall + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterFunctionCall" ): + listener.enterFunctionCall(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitFunctionCall" ): + listener.exitFunctionCall(self) + + + + + def functionCall(self): + + localctx = AutolevParser.FunctionCallContext(self, self._ctx, self.state) + self.enterRule(localctx, 12, self.RULE_functionCall) + self._la = 0 # Token type + try: + self.state = 131 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [48]: + self.enterOuterAlt(localctx, 1) + self.state = 105 + self.match(AutolevParser.ID) + self.state = 106 + self.match(AutolevParser.T__11) + self.state = 115 + self._errHandler.sync(self) + _la = self._input.LA(1) + if ((_la) & ~0x3f) == 0 and ((1 << _la) & 404620694540290) != 0: + self.state = 107 + self.expr(0) + self.state = 112 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 108 + self.match(AutolevParser.T__9) + self.state = 109 + self.expr(0) + self.state = 114 + self._errHandler.sync(self) + _la = self._input.LA(1) + + + + self.state = 117 + self.match(AutolevParser.T__12) + pass + elif token in [27, 28]: + self.enterOuterAlt(localctx, 2) + self.state = 118 + _la = self._input.LA(1) + if not(_la==27 or _la==28): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 119 + self.match(AutolevParser.T__11) + self.state = 128 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==48: + self.state = 120 + self.match(AutolevParser.ID) + self.state = 125 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 121 + self.match(AutolevParser.T__9) + self.state = 122 + self.match(AutolevParser.ID) + self.state = 127 + self._errHandler.sync(self) + _la = self._input.LA(1) + + + + self.state = 130 + self.match(AutolevParser.T__12) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class VarDeclContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def varType(self): + return self.getTypedRuleContext(AutolevParser.VarTypeContext,0) + + + def varDecl2(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.VarDecl2Context) + else: + return self.getTypedRuleContext(AutolevParser.VarDecl2Context,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_varDecl + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterVarDecl" ): + listener.enterVarDecl(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitVarDecl" ): + listener.exitVarDecl(self) + + + + + def varDecl(self): + + localctx = AutolevParser.VarDeclContext(self, self._ctx, self.state) + self.enterRule(localctx, 14, self.RULE_varDecl) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 133 + self.varType() + self.state = 134 + self.varDecl2() + self.state = 139 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 135 + self.match(AutolevParser.T__9) + self.state = 136 + self.varDecl2() + self.state = 141 + self._errHandler.sync(self) + _la = self._input.LA(1) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class VarTypeContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def Newtonian(self): + return self.getToken(AutolevParser.Newtonian, 0) + + def Frames(self): + return self.getToken(AutolevParser.Frames, 0) + + def Bodies(self): + return self.getToken(AutolevParser.Bodies, 0) + + def Particles(self): + return self.getToken(AutolevParser.Particles, 0) + + def Points(self): + return self.getToken(AutolevParser.Points, 0) + + def Constants(self): + return self.getToken(AutolevParser.Constants, 0) + + def Specifieds(self): + return self.getToken(AutolevParser.Specifieds, 0) + + def Imaginary(self): + return self.getToken(AutolevParser.Imaginary, 0) + + def Variables(self): + return self.getToken(AutolevParser.Variables, 0) + + def MotionVariables(self): + return self.getToken(AutolevParser.MotionVariables, 0) + + def getRuleIndex(self): + return AutolevParser.RULE_varType + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterVarType" ): + listener.enterVarType(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitVarType" ): + listener.exitVarType(self) + + + + + def varType(self): + + localctx = AutolevParser.VarTypeContext(self, self._ctx, self.state) + self.enterRule(localctx, 16, self.RULE_varType) + self._la = 0 # Token type + try: + self.state = 164 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [34]: + self.enterOuterAlt(localctx, 1) + self.state = 142 + self.match(AutolevParser.Newtonian) + pass + elif token in [35]: + self.enterOuterAlt(localctx, 2) + self.state = 143 + self.match(AutolevParser.Frames) + pass + elif token in [36]: + self.enterOuterAlt(localctx, 3) + self.state = 144 + self.match(AutolevParser.Bodies) + pass + elif token in [37]: + self.enterOuterAlt(localctx, 4) + self.state = 145 + self.match(AutolevParser.Particles) + pass + elif token in [38]: + self.enterOuterAlt(localctx, 5) + self.state = 146 + self.match(AutolevParser.Points) + pass + elif token in [39]: + self.enterOuterAlt(localctx, 6) + self.state = 147 + self.match(AutolevParser.Constants) + pass + elif token in [40]: + self.enterOuterAlt(localctx, 7) + self.state = 148 + self.match(AutolevParser.Specifieds) + pass + elif token in [41]: + self.enterOuterAlt(localctx, 8) + self.state = 149 + self.match(AutolevParser.Imaginary) + pass + elif token in [42]: + self.enterOuterAlt(localctx, 9) + self.state = 150 + self.match(AutolevParser.Variables) + self.state = 154 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==11: + self.state = 151 + self.match(AutolevParser.T__10) + self.state = 156 + self._errHandler.sync(self) + _la = self._input.LA(1) + + pass + elif token in [43]: + self.enterOuterAlt(localctx, 10) + self.state = 157 + self.match(AutolevParser.MotionVariables) + self.state = 161 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==11: + self.state = 158 + self.match(AutolevParser.T__10) + self.state = 163 + self._errHandler.sync(self) + _la = self._input.LA(1) + + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class VarDecl2Context(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + + def INT(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.INT) + else: + return self.getToken(AutolevParser.INT, i) + + def expr(self): + return self.getTypedRuleContext(AutolevParser.ExprContext,0) + + + def getRuleIndex(self): + return AutolevParser.RULE_varDecl2 + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterVarDecl2" ): + listener.enterVarDecl2(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitVarDecl2" ): + listener.exitVarDecl2(self) + + + + + def varDecl2(self): + + localctx = AutolevParser.VarDecl2Context(self, self._ctx, self.state) + self.enterRule(localctx, 18, self.RULE_varDecl2) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 166 + self.match(AutolevParser.ID) + self.state = 172 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,15,self._ctx) + if la_ == 1: + self.state = 167 + self.match(AutolevParser.T__13) + self.state = 168 + self.match(AutolevParser.INT) + self.state = 169 + self.match(AutolevParser.T__9) + self.state = 170 + self.match(AutolevParser.INT) + self.state = 171 + self.match(AutolevParser.T__14) + + + self.state = 188 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,17,self._ctx) + if la_ == 1: + self.state = 174 + self.match(AutolevParser.T__13) + self.state = 175 + self.match(AutolevParser.INT) + self.state = 176 + self.match(AutolevParser.T__15) + self.state = 177 + self.match(AutolevParser.INT) + self.state = 184 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 178 + self.match(AutolevParser.T__9) + self.state = 179 + self.match(AutolevParser.INT) + self.state = 180 + self.match(AutolevParser.T__15) + self.state = 181 + self.match(AutolevParser.INT) + self.state = 186 + self._errHandler.sync(self) + _la = self._input.LA(1) + + self.state = 187 + self.match(AutolevParser.T__14) + + + self.state = 193 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==14: + self.state = 190 + self.match(AutolevParser.T__13) + self.state = 191 + self.match(AutolevParser.INT) + self.state = 192 + self.match(AutolevParser.T__14) + + + self.state = 196 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==17 or _la==18: + self.state = 195 + _la = self._input.LA(1) + if not(_la==17 or _la==18): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + + + self.state = 201 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==11: + self.state = 198 + self.match(AutolevParser.T__10) + self.state = 203 + self._errHandler.sync(self) + _la = self._input.LA(1) + + self.state = 206 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==3: + self.state = 204 + self.match(AutolevParser.T__2) + self.state = 205 + self.expr(0) + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class RangesContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def INT(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.INT) + else: + return self.getToken(AutolevParser.INT, i) + + def getRuleIndex(self): + return AutolevParser.RULE_ranges + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterRanges" ): + listener.enterRanges(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitRanges" ): + listener.exitRanges(self) + + + + + def ranges(self): + + localctx = AutolevParser.RangesContext(self, self._ctx, self.state) + self.enterRule(localctx, 20, self.RULE_ranges) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 208 + self.match(AutolevParser.T__13) + self.state = 209 + self.match(AutolevParser.INT) + self.state = 210 + self.match(AutolevParser.T__15) + self.state = 211 + self.match(AutolevParser.INT) + self.state = 218 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 212 + self.match(AutolevParser.T__9) + self.state = 213 + self.match(AutolevParser.INT) + self.state = 214 + self.match(AutolevParser.T__15) + self.state = 215 + self.match(AutolevParser.INT) + self.state = 220 + self._errHandler.sync(self) + _la = self._input.LA(1) + + self.state = 221 + self.match(AutolevParser.T__14) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class MassDeclContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def Mass(self): + return self.getToken(AutolevParser.Mass, 0) + + def massDecl2(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.MassDecl2Context) + else: + return self.getTypedRuleContext(AutolevParser.MassDecl2Context,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_massDecl + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterMassDecl" ): + listener.enterMassDecl(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitMassDecl" ): + listener.exitMassDecl(self) + + + + + def massDecl(self): + + localctx = AutolevParser.MassDeclContext(self, self._ctx, self.state) + self.enterRule(localctx, 22, self.RULE_massDecl) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 223 + self.match(AutolevParser.Mass) + self.state = 224 + self.massDecl2() + self.state = 229 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 225 + self.match(AutolevParser.T__9) + self.state = 226 + self.massDecl2() + self.state = 231 + self._errHandler.sync(self) + _la = self._input.LA(1) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class MassDecl2Context(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + + def expr(self): + return self.getTypedRuleContext(AutolevParser.ExprContext,0) + + + def getRuleIndex(self): + return AutolevParser.RULE_massDecl2 + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterMassDecl2" ): + listener.enterMassDecl2(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitMassDecl2" ): + listener.exitMassDecl2(self) + + + + + def massDecl2(self): + + localctx = AutolevParser.MassDecl2Context(self, self._ctx, self.state) + self.enterRule(localctx, 24, self.RULE_massDecl2) + try: + self.enterOuterAlt(localctx, 1) + self.state = 232 + self.match(AutolevParser.ID) + self.state = 233 + self.match(AutolevParser.T__2) + self.state = 234 + self.expr(0) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class InertiaDeclContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def Inertia(self): + return self.getToken(AutolevParser.Inertia, 0) + + def ID(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.ID) + else: + return self.getToken(AutolevParser.ID, i) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_inertiaDecl + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterInertiaDecl" ): + listener.enterInertiaDecl(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitInertiaDecl" ): + listener.exitInertiaDecl(self) + + + + + def inertiaDecl(self): + + localctx = AutolevParser.InertiaDeclContext(self, self._ctx, self.state) + self.enterRule(localctx, 26, self.RULE_inertiaDecl) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 236 + self.match(AutolevParser.Inertia) + self.state = 237 + self.match(AutolevParser.ID) + self.state = 241 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==12: + self.state = 238 + self.match(AutolevParser.T__11) + self.state = 239 + self.match(AutolevParser.ID) + self.state = 240 + self.match(AutolevParser.T__12) + + + self.state = 245 + self._errHandler.sync(self) + _la = self._input.LA(1) + while True: + self.state = 243 + self.match(AutolevParser.T__9) + self.state = 244 + self.expr(0) + self.state = 247 + self._errHandler.sync(self) + _la = self._input.LA(1) + if not (_la==10): + break + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class MatrixContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_matrix + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterMatrix" ): + listener.enterMatrix(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitMatrix" ): + listener.exitMatrix(self) + + + + + def matrix(self): + + localctx = AutolevParser.MatrixContext(self, self._ctx, self.state) + self.enterRule(localctx, 28, self.RULE_matrix) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 249 + self.match(AutolevParser.T__0) + self.state = 250 + self.expr(0) + self.state = 255 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10 or _la==19: + self.state = 251 + _la = self._input.LA(1) + if not(_la==10 or _la==19): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 252 + self.expr(0) + self.state = 257 + self._errHandler.sync(self) + _la = self._input.LA(1) + + self.state = 258 + self.match(AutolevParser.T__1) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class MatrixInOutputContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.ID) + else: + return self.getToken(AutolevParser.ID, i) + + def FLOAT(self): + return self.getToken(AutolevParser.FLOAT, 0) + + def INT(self): + return self.getToken(AutolevParser.INT, 0) + + def getRuleIndex(self): + return AutolevParser.RULE_matrixInOutput + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterMatrixInOutput" ): + listener.enterMatrixInOutput(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitMatrixInOutput" ): + listener.exitMatrixInOutput(self) + + + + + def matrixInOutput(self): + + localctx = AutolevParser.MatrixInOutputContext(self, self._ctx, self.state) + self.enterRule(localctx, 30, self.RULE_matrixInOutput) + self._la = 0 # Token type + try: + self.state = 268 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [48]: + self.enterOuterAlt(localctx, 1) + self.state = 260 + self.match(AutolevParser.ID) + + self.state = 261 + self.match(AutolevParser.ID) + self.state = 262 + self.match(AutolevParser.T__2) + self.state = 264 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==44 or _la==45: + self.state = 263 + _la = self._input.LA(1) + if not(_la==44 or _la==45): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + + + pass + elif token in [45]: + self.enterOuterAlt(localctx, 2) + self.state = 266 + self.match(AutolevParser.FLOAT) + pass + elif token in [44]: + self.enterOuterAlt(localctx, 3) + self.state = 267 + self.match(AutolevParser.INT) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class CodeCommandsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def units(self): + return self.getTypedRuleContext(AutolevParser.UnitsContext,0) + + + def inputs(self): + return self.getTypedRuleContext(AutolevParser.InputsContext,0) + + + def outputs(self): + return self.getTypedRuleContext(AutolevParser.OutputsContext,0) + + + def codegen(self): + return self.getTypedRuleContext(AutolevParser.CodegenContext,0) + + + def commands(self): + return self.getTypedRuleContext(AutolevParser.CommandsContext,0) + + + def getRuleIndex(self): + return AutolevParser.RULE_codeCommands + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterCodeCommands" ): + listener.enterCodeCommands(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitCodeCommands" ): + listener.exitCodeCommands(self) + + + + + def codeCommands(self): + + localctx = AutolevParser.CodeCommandsContext(self, self._ctx, self.state) + self.enterRule(localctx, 32, self.RULE_codeCommands) + try: + self.state = 275 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [32]: + self.enterOuterAlt(localctx, 1) + self.state = 270 + self.units() + pass + elif token in [29]: + self.enterOuterAlt(localctx, 2) + self.state = 271 + self.inputs() + pass + elif token in [30]: + self.enterOuterAlt(localctx, 3) + self.state = 272 + self.outputs() + pass + elif token in [48]: + self.enterOuterAlt(localctx, 4) + self.state = 273 + self.codegen() + pass + elif token in [31, 33]: + self.enterOuterAlt(localctx, 5) + self.state = 274 + self.commands() + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class SettingsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.ID) + else: + return self.getToken(AutolevParser.ID, i) + + def EXP(self): + return self.getToken(AutolevParser.EXP, 0) + + def FLOAT(self): + return self.getToken(AutolevParser.FLOAT, 0) + + def INT(self): + return self.getToken(AutolevParser.INT, 0) + + def getRuleIndex(self): + return AutolevParser.RULE_settings + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterSettings" ): + listener.enterSettings(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitSettings" ): + listener.exitSettings(self) + + + + + def settings(self): + + localctx = AutolevParser.SettingsContext(self, self._ctx, self.state) + self.enterRule(localctx, 34, self.RULE_settings) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 277 + self.match(AutolevParser.ID) + self.state = 279 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,30,self._ctx) + if la_ == 1: + self.state = 278 + _la = self._input.LA(1) + if not(((_la) & ~0x3f) == 0 and ((1 << _la) & 404620279021568) != 0): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class UnitsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def UnitSystem(self): + return self.getToken(AutolevParser.UnitSystem, 0) + + def ID(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.ID) + else: + return self.getToken(AutolevParser.ID, i) + + def getRuleIndex(self): + return AutolevParser.RULE_units + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterUnits" ): + listener.enterUnits(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitUnits" ): + listener.exitUnits(self) + + + + + def units(self): + + localctx = AutolevParser.UnitsContext(self, self._ctx, self.state) + self.enterRule(localctx, 36, self.RULE_units) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 281 + self.match(AutolevParser.UnitSystem) + self.state = 282 + self.match(AutolevParser.ID) + self.state = 287 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 283 + self.match(AutolevParser.T__9) + self.state = 284 + self.match(AutolevParser.ID) + self.state = 289 + self._errHandler.sync(self) + _la = self._input.LA(1) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class InputsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def Input(self): + return self.getToken(AutolevParser.Input, 0) + + def inputs2(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.Inputs2Context) + else: + return self.getTypedRuleContext(AutolevParser.Inputs2Context,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_inputs + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterInputs" ): + listener.enterInputs(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitInputs" ): + listener.exitInputs(self) + + + + + def inputs(self): + + localctx = AutolevParser.InputsContext(self, self._ctx, self.state) + self.enterRule(localctx, 38, self.RULE_inputs) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 290 + self.match(AutolevParser.Input) + self.state = 291 + self.inputs2() + self.state = 296 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 292 + self.match(AutolevParser.T__9) + self.state = 293 + self.inputs2() + self.state = 298 + self._errHandler.sync(self) + _la = self._input.LA(1) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Id_diffContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + + def diff(self): + return self.getTypedRuleContext(AutolevParser.DiffContext,0) + + + def getRuleIndex(self): + return AutolevParser.RULE_id_diff + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterId_diff" ): + listener.enterId_diff(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitId_diff" ): + listener.exitId_diff(self) + + + + + def id_diff(self): + + localctx = AutolevParser.Id_diffContext(self, self._ctx, self.state) + self.enterRule(localctx, 40, self.RULE_id_diff) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 299 + self.match(AutolevParser.ID) + self.state = 301 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==11: + self.state = 300 + self.diff() + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Inputs2Context(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def id_diff(self): + return self.getTypedRuleContext(AutolevParser.Id_diffContext,0) + + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_inputs2 + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterInputs2" ): + listener.enterInputs2(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitInputs2" ): + listener.exitInputs2(self) + + + + + def inputs2(self): + + localctx = AutolevParser.Inputs2Context(self, self._ctx, self.state) + self.enterRule(localctx, 42, self.RULE_inputs2) + try: + self.enterOuterAlt(localctx, 1) + self.state = 303 + self.id_diff() + self.state = 304 + self.match(AutolevParser.T__2) + self.state = 305 + self.expr(0) + self.state = 307 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,34,self._ctx) + if la_ == 1: + self.state = 306 + self.expr(0) + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class OutputsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def Output(self): + return self.getToken(AutolevParser.Output, 0) + + def outputs2(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.Outputs2Context) + else: + return self.getTypedRuleContext(AutolevParser.Outputs2Context,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_outputs + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterOutputs" ): + listener.enterOutputs(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitOutputs" ): + listener.exitOutputs(self) + + + + + def outputs(self): + + localctx = AutolevParser.OutputsContext(self, self._ctx, self.state) + self.enterRule(localctx, 44, self.RULE_outputs) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 309 + self.match(AutolevParser.Output) + self.state = 310 + self.outputs2() + self.state = 315 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 311 + self.match(AutolevParser.T__9) + self.state = 312 + self.outputs2() + self.state = 317 + self._errHandler.sync(self) + _la = self._input.LA(1) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Outputs2Context(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_outputs2 + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterOutputs2" ): + listener.enterOutputs2(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitOutputs2" ): + listener.exitOutputs2(self) + + + + + def outputs2(self): + + localctx = AutolevParser.Outputs2Context(self, self._ctx, self.state) + self.enterRule(localctx, 46, self.RULE_outputs2) + try: + self.enterOuterAlt(localctx, 1) + self.state = 318 + self.expr(0) + self.state = 320 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,36,self._ctx) + if la_ == 1: + self.state = 319 + self.expr(0) + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class CodegenContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.ID) + else: + return self.getToken(AutolevParser.ID, i) + + def functionCall(self): + return self.getTypedRuleContext(AutolevParser.FunctionCallContext,0) + + + def matrixInOutput(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.MatrixInOutputContext) + else: + return self.getTypedRuleContext(AutolevParser.MatrixInOutputContext,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_codegen + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterCodegen" ): + listener.enterCodegen(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitCodegen" ): + listener.exitCodegen(self) + + + + + def codegen(self): + + localctx = AutolevParser.CodegenContext(self, self._ctx, self.state) + self.enterRule(localctx, 48, self.RULE_codegen) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 322 + self.match(AutolevParser.ID) + self.state = 323 + self.functionCall() + self.state = 335 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==1: + self.state = 324 + self.match(AutolevParser.T__0) + self.state = 325 + self.matrixInOutput() + self.state = 330 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 326 + self.match(AutolevParser.T__9) + self.state = 327 + self.matrixInOutput() + self.state = 332 + self._errHandler.sync(self) + _la = self._input.LA(1) + + self.state = 333 + self.match(AutolevParser.T__1) + + + self.state = 337 + self.match(AutolevParser.ID) + self.state = 338 + self.match(AutolevParser.T__19) + self.state = 339 + self.match(AutolevParser.ID) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class CommandsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def Save(self): + return self.getToken(AutolevParser.Save, 0) + + def ID(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.ID) + else: + return self.getToken(AutolevParser.ID, i) + + def Encode(self): + return self.getToken(AutolevParser.Encode, 0) + + def getRuleIndex(self): + return AutolevParser.RULE_commands + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterCommands" ): + listener.enterCommands(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitCommands" ): + listener.exitCommands(self) + + + + + def commands(self): + + localctx = AutolevParser.CommandsContext(self, self._ctx, self.state) + self.enterRule(localctx, 50, self.RULE_commands) + self._la = 0 # Token type + try: + self.state = 354 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [31]: + self.enterOuterAlt(localctx, 1) + self.state = 341 + self.match(AutolevParser.Save) + self.state = 342 + self.match(AutolevParser.ID) + self.state = 343 + self.match(AutolevParser.T__19) + self.state = 344 + self.match(AutolevParser.ID) + pass + elif token in [33]: + self.enterOuterAlt(localctx, 2) + self.state = 345 + self.match(AutolevParser.Encode) + self.state = 346 + self.match(AutolevParser.ID) + self.state = 351 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 347 + self.match(AutolevParser.T__9) + self.state = 348 + self.match(AutolevParser.ID) + self.state = 353 + self._errHandler.sync(self) + _la = self._input.LA(1) + + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class VecContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + + def getRuleIndex(self): + return AutolevParser.RULE_vec + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterVec" ): + listener.enterVec(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitVec" ): + listener.exitVec(self) + + + + + def vec(self): + + localctx = AutolevParser.VecContext(self, self._ctx, self.state) + self.enterRule(localctx, 52, self.RULE_vec) + try: + self.state = 364 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [48]: + self.enterOuterAlt(localctx, 1) + self.state = 356 + self.match(AutolevParser.ID) + self.state = 358 + self._errHandler.sync(self) + _alt = 1 + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt == 1: + self.state = 357 + self.match(AutolevParser.T__20) + + else: + raise NoViableAltException(self) + self.state = 360 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,41,self._ctx) + + pass + elif token in [22]: + self.enterOuterAlt(localctx, 2) + self.state = 362 + self.match(AutolevParser.T__21) + pass + elif token in [23]: + self.enterOuterAlt(localctx, 3) + self.state = 363 + self.match(AutolevParser.T__22) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ExprContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return AutolevParser.RULE_expr + + + def copyFrom(self, ctx:ParserRuleContext): + super().copyFrom(ctx) + + + class ParensContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self): + return self.getTypedRuleContext(AutolevParser.ExprContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterParens" ): + listener.enterParens(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitParens" ): + listener.exitParens(self) + + + class VectorOrDyadicContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def vec(self): + return self.getTypedRuleContext(AutolevParser.VecContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterVectorOrDyadic" ): + listener.enterVectorOrDyadic(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitVectorOrDyadic" ): + listener.exitVectorOrDyadic(self) + + + class ExponentContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterExponent" ): + listener.enterExponent(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitExponent" ): + listener.exitExponent(self) + + + class MulDivContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterMulDiv" ): + listener.enterMulDiv(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitMulDiv" ): + listener.exitMulDiv(self) + + + class AddSubContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterAddSub" ): + listener.enterAddSub(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitAddSub" ): + listener.exitAddSub(self) + + + class FloatContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def FLOAT(self): + return self.getToken(AutolevParser.FLOAT, 0) + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterFloat" ): + listener.enterFloat(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitFloat" ): + listener.exitFloat(self) + + + class IntContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def INT(self): + return self.getToken(AutolevParser.INT, 0) + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterInt" ): + listener.enterInt(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitInt" ): + listener.exitInt(self) + + + class IdEqualsExprContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterIdEqualsExpr" ): + listener.enterIdEqualsExpr(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitIdEqualsExpr" ): + listener.exitIdEqualsExpr(self) + + + class NegativeOneContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self): + return self.getTypedRuleContext(AutolevParser.ExprContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterNegativeOne" ): + listener.enterNegativeOne(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitNegativeOne" ): + listener.exitNegativeOne(self) + + + class FunctionContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def functionCall(self): + return self.getTypedRuleContext(AutolevParser.FunctionCallContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterFunction" ): + listener.enterFunction(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitFunction" ): + listener.exitFunction(self) + + + class RangessContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def ranges(self): + return self.getTypedRuleContext(AutolevParser.RangesContext,0) + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterRangess" ): + listener.enterRangess(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitRangess" ): + listener.exitRangess(self) + + + class ColonContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterColon" ): + listener.enterColon(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitColon" ): + listener.exitColon(self) + + + class IdContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterId" ): + listener.enterId(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitId" ): + listener.exitId(self) + + + class ExpContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def EXP(self): + return self.getToken(AutolevParser.EXP, 0) + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterExp" ): + listener.enterExp(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitExp" ): + listener.exitExp(self) + + + class MatricesContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def matrix(self): + return self.getTypedRuleContext(AutolevParser.MatrixContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterMatrices" ): + listener.enterMatrices(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitMatrices" ): + listener.exitMatrices(self) + + + class IndexingContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterIndexing" ): + listener.enterIndexing(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitIndexing" ): + listener.exitIndexing(self) + + + + def expr(self, _p:int=0): + _parentctx = self._ctx + _parentState = self.state + localctx = AutolevParser.ExprContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 54 + self.enterRecursionRule(localctx, 54, self.RULE_expr, _p) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 408 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,47,self._ctx) + if la_ == 1: + localctx = AutolevParser.ExpContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + + self.state = 367 + self.match(AutolevParser.EXP) + pass + + elif la_ == 2: + localctx = AutolevParser.NegativeOneContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 368 + self.match(AutolevParser.T__17) + self.state = 369 + self.expr(12) + pass + + elif la_ == 3: + localctx = AutolevParser.FloatContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 370 + self.match(AutolevParser.FLOAT) + pass + + elif la_ == 4: + localctx = AutolevParser.IntContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 371 + self.match(AutolevParser.INT) + pass + + elif la_ == 5: + localctx = AutolevParser.IdContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 372 + self.match(AutolevParser.ID) + self.state = 376 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,43,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + self.state = 373 + self.match(AutolevParser.T__10) + self.state = 378 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,43,self._ctx) + + pass + + elif la_ == 6: + localctx = AutolevParser.VectorOrDyadicContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 379 + self.vec() + pass + + elif la_ == 7: + localctx = AutolevParser.IndexingContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 380 + self.match(AutolevParser.ID) + self.state = 381 + self.match(AutolevParser.T__0) + self.state = 382 + self.expr(0) + self.state = 387 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 383 + self.match(AutolevParser.T__9) + self.state = 384 + self.expr(0) + self.state = 389 + self._errHandler.sync(self) + _la = self._input.LA(1) + + self.state = 390 + self.match(AutolevParser.T__1) + pass + + elif la_ == 8: + localctx = AutolevParser.FunctionContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 392 + self.functionCall() + pass + + elif la_ == 9: + localctx = AutolevParser.MatricesContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 393 + self.matrix() + pass + + elif la_ == 10: + localctx = AutolevParser.ParensContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 394 + self.match(AutolevParser.T__11) + self.state = 395 + self.expr(0) + self.state = 396 + self.match(AutolevParser.T__12) + pass + + elif la_ == 11: + localctx = AutolevParser.RangessContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 399 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==48: + self.state = 398 + self.match(AutolevParser.ID) + + + self.state = 401 + self.ranges() + self.state = 405 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,46,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + self.state = 402 + self.match(AutolevParser.T__10) + self.state = 407 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,46,self._ctx) + + pass + + + self._ctx.stop = self._input.LT(-1) + self.state = 427 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,49,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + self.state = 425 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,48,self._ctx) + if la_ == 1: + localctx = AutolevParser.ExponentContext(self, AutolevParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 410 + if not self.precpred(self._ctx, 16): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 16)") + self.state = 411 + self.match(AutolevParser.T__23) + self.state = 412 + self.expr(17) + pass + + elif la_ == 2: + localctx = AutolevParser.MulDivContext(self, AutolevParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 413 + if not self.precpred(self._ctx, 15): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 15)") + self.state = 414 + _la = self._input.LA(1) + if not(_la==25 or _la==26): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 415 + self.expr(16) + pass + + elif la_ == 3: + localctx = AutolevParser.AddSubContext(self, AutolevParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 416 + if not self.precpred(self._ctx, 14): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 14)") + self.state = 417 + _la = self._input.LA(1) + if not(_la==17 or _la==18): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 418 + self.expr(15) + pass + + elif la_ == 4: + localctx = AutolevParser.IdEqualsExprContext(self, AutolevParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 419 + if not self.precpred(self._ctx, 3): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 3)") + self.state = 420 + self.match(AutolevParser.T__2) + self.state = 421 + self.expr(4) + pass + + elif la_ == 5: + localctx = AutolevParser.ColonContext(self, AutolevParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 422 + if not self.precpred(self._ctx, 2): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 2)") + self.state = 423 + self.match(AutolevParser.T__15) + self.state = 424 + self.expr(3) + pass + + + self.state = 429 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,49,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + + def sempred(self, localctx:RuleContext, ruleIndex:int, predIndex:int): + if self._predicates == None: + self._predicates = dict() + self._predicates[27] = self.expr_sempred + pred = self._predicates.get(ruleIndex, None) + if pred is None: + raise Exception("No predicate with index:" + str(ruleIndex)) + else: + return pred(localctx, predIndex) + + def expr_sempred(self, localctx:ExprContext, predIndex:int): + if predIndex == 0: + return self.precpred(self._ctx, 16) + + + if predIndex == 1: + return self.precpred(self._ctx, 15) + + + if predIndex == 2: + return self.precpred(self._ctx, 14) + + + if predIndex == 3: + return self.precpred(self._ctx, 3) + + + if predIndex == 4: + return self.precpred(self._ctx, 2) + + + + + diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/_build_autolev_antlr.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/_build_autolev_antlr.py new file mode 100644 index 0000000000000000000000000000000000000000..8314b2f546c0a18a8e281768b60d66556c852e3b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/_build_autolev_antlr.py @@ -0,0 +1,86 @@ +import os +import subprocess +import glob + +from sympy.utilities.misc import debug + +here = os.path.dirname(__file__) +grammar_file = os.path.abspath(os.path.join(here, "Autolev.g4")) +dir_autolev_antlr = os.path.join(here, "_antlr") + +header = '''\ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt +''' + + +def check_antlr_version(): + debug("Checking antlr4 version...") + + try: + debug(subprocess.check_output(["antlr4"]) + .decode('utf-8').split("\n")[0]) + return True + except (subprocess.CalledProcessError, FileNotFoundError): + debug("The 'antlr4' command line tool is not installed, " + "or not on your PATH.\n" + "> Please refer to the README.md file for more information.") + return False + + +def build_parser(output_dir=dir_autolev_antlr): + check_antlr_version() + + debug("Updating ANTLR-generated code in {}".format(output_dir)) + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + with open(os.path.join(output_dir, "__init__.py"), "w+") as fp: + fp.write(header) + + args = [ + "antlr4", + grammar_file, + "-o", output_dir, + "-no-visitor", + ] + + debug("Running code generation...\n\t$ {}".format(" ".join(args))) + subprocess.check_output(args, cwd=output_dir) + + debug("Applying headers, removing unnecessary files and renaming...") + # Handle case insensitive file systems. If the files are already + # generated, they will be written to autolev* but Autolev*.* won't match them. + for path in (glob.glob(os.path.join(output_dir, "Autolev*.*")) or + glob.glob(os.path.join(output_dir, "autolev*.*"))): + + # Remove files ending in .interp or .tokens as they are not needed. + if not path.endswith(".py"): + os.unlink(path) + continue + + new_path = os.path.join(output_dir, os.path.basename(path).lower()) + with open(path, 'r') as f: + lines = [line.rstrip().replace('AutolevParser import', 'autolevparser import') +'\n' + for line in f] + + os.unlink(path) + + with open(new_path, "w") as out_file: + offset = 0 + while lines[offset].startswith('#'): + offset += 1 + out_file.write(header) + out_file.writelines(lines[offset:]) + + debug("\t{}".format(new_path)) + + return True + + +if __name__ == "__main__": + build_parser() diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/_listener_autolev_antlr.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/_listener_autolev_antlr.py new file mode 100644 index 0000000000000000000000000000000000000000..9ca2f8af88de18036b90788fd29d02707f098213 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/_listener_autolev_antlr.py @@ -0,0 +1,2083 @@ +import collections +import warnings + +from sympy.external import import_module + +autolevparser = import_module('sympy.parsing.autolev._antlr.autolevparser', + import_kwargs={'fromlist': ['AutolevParser']}) +autolevlexer = import_module('sympy.parsing.autolev._antlr.autolevlexer', + import_kwargs={'fromlist': ['AutolevLexer']}) +autolevlistener = import_module('sympy.parsing.autolev._antlr.autolevlistener', + import_kwargs={'fromlist': ['AutolevListener']}) + +AutolevParser = getattr(autolevparser, 'AutolevParser', None) +AutolevLexer = getattr(autolevlexer, 'AutolevLexer', None) +AutolevListener = getattr(autolevlistener, 'AutolevListener', None) + + +def strfunc(z): + if z == 0: + return "" + elif z == 1: + return "_d" + else: + return "_" + "d" * z + +def declare_phy_entities(self, ctx, phy_type, i, j=None): + if phy_type in ("frame", "newtonian"): + declare_frames(self, ctx, i, j) + elif phy_type == "particle": + declare_particles(self, ctx, i, j) + elif phy_type == "point": + declare_points(self, ctx, i, j) + elif phy_type == "bodies": + declare_bodies(self, ctx, i, j) + +def declare_frames(self, ctx, i, j=None): + if "{" in ctx.getText(): + if j: + name1 = ctx.ID().getText().lower() + str(i) + str(j) + else: + name1 = ctx.ID().getText().lower() + str(i) + else: + name1 = ctx.ID().getText().lower() + name2 = "frame_" + name1 + if self.getValue(ctx.parentCtx.varType()) == "newtonian": + self.newtonian = name2 + + self.symbol_table2.update({name1: name2}) + + self.symbol_table.update({name1 + "1>": name2 + ".x"}) + self.symbol_table.update({name1 + "2>": name2 + ".y"}) + self.symbol_table.update({name1 + "3>": name2 + ".z"}) + + self.type2.update({name1: "frame"}) + self.write(name2 + " = " + "_me.ReferenceFrame('" + name1 + "')\n") + +def declare_points(self, ctx, i, j=None): + if "{" in ctx.getText(): + if j: + name1 = ctx.ID().getText().lower() + str(i) + str(j) + else: + name1 = ctx.ID().getText().lower() + str(i) + else: + name1 = ctx.ID().getText().lower() + + name2 = "point_" + name1 + + self.symbol_table2.update({name1: name2}) + self.type2.update({name1: "point"}) + self.write(name2 + " = " + "_me.Point('" + name1 + "')\n") + +def declare_particles(self, ctx, i, j=None): + if "{" in ctx.getText(): + if j: + name1 = ctx.ID().getText().lower() + str(i) + str(j) + else: + name1 = ctx.ID().getText().lower() + str(i) + else: + name1 = ctx.ID().getText().lower() + + name2 = "particle_" + name1 + + self.symbol_table2.update({name1: name2}) + self.type2.update({name1: "particle"}) + self.bodies.update({name1: name2}) + self.write(name2 + " = " + "_me.Particle('" + name1 + "', " + "_me.Point('" + + name1 + "_pt" + "'), " + "_sm.Symbol('m'))\n") + +def declare_bodies(self, ctx, i, j=None): + if "{" in ctx.getText(): + if j: + name1 = ctx.ID().getText().lower() + str(i) + str(j) + else: + name1 = ctx.ID().getText().lower() + str(i) + else: + name1 = ctx.ID().getText().lower() + + name2 = "body_" + name1 + self.bodies.update({name1: name2}) + masscenter = name2 + "_cm" + refFrame = name2 + "_f" + + self.symbol_table2.update({name1: name2}) + self.symbol_table2.update({name1 + "o": masscenter}) + self.symbol_table.update({name1 + "1>": refFrame+".x"}) + self.symbol_table.update({name1 + "2>": refFrame+".y"}) + self.symbol_table.update({name1 + "3>": refFrame+".z"}) + + self.type2.update({name1: "bodies"}) + self.type2.update({name1+"o": "point"}) + + self.write(masscenter + " = " + "_me.Point('" + name1 + "_cm" + "')\n") + if self.newtonian: + self.write(masscenter + ".set_vel(" + self.newtonian + ", " + "0)\n") + self.write(refFrame + " = " + "_me.ReferenceFrame('" + name1 + "_f" + "')\n") + # We set a dummy mass and inertia here. + # They will be reset using the setters later in the code anyway. + self.write(name2 + " = " + "_me.RigidBody('" + name1 + "', " + masscenter + ", " + + refFrame + ", " + "_sm.symbols('m'), (_me.outer(" + refFrame + + ".x," + refFrame + ".x)," + masscenter + "))\n") + +def inertia_func(self, v1, v2, l, frame): + + if self.type2[v1] == "particle": + l.append("_me.inertia_of_point_mass(" + self.bodies[v1] + ".mass, " + self.bodies[v1] + + ".point.pos_from(" + self.symbol_table2[v2] + "), " + frame + ")") + + elif self.type2[v1] == "bodies": + # Inertia has been defined about center of mass. + if self.inertia_point[v1] == v1 + "o": + # Asking point is cm as well + if v2 == self.inertia_point[v1]: + l.append(self.symbol_table2[v1] + ".inertia[0]") + + # Asking point is not cm + else: + l.append(self.bodies[v1] + ".inertia[0]" + " + " + + "_me.inertia_of_point_mass(" + self.bodies[v1] + + ".mass, " + self.bodies[v1] + ".masscenter" + + ".pos_from(" + self.symbol_table2[v2] + + "), " + frame + ")") + + # Inertia has been defined about another point + else: + # Asking point is the defined point + if v2 == self.inertia_point[v1]: + l.append(self.symbol_table2[v1] + ".inertia[0]") + # Asking point is cm + elif v2 == v1 + "o": + l.append(self.bodies[v1] + ".inertia[0]" + " - " + + "_me.inertia_of_point_mass(" + self.bodies[v1] + + ".mass, " + self.bodies[v1] + ".masscenter" + + ".pos_from(" + self.symbol_table2[self.inertia_point[v1]] + + "), " + frame + ")") + # Asking point is some other point + else: + l.append(self.bodies[v1] + ".inertia[0]" + " - " + + "_me.inertia_of_point_mass(" + self.bodies[v1] + + ".mass, " + self.bodies[v1] + ".masscenter" + + ".pos_from(" + self.symbol_table2[self.inertia_point[v1]] + + "), " + frame + ")" + " + " + + "_me.inertia_of_point_mass(" + self.bodies[v1] + + ".mass, " + self.bodies[v1] + ".masscenter" + + ".pos_from(" + self.symbol_table2[v2] + + "), " + frame + ")") + + +def processConstants(self, ctx): + # Process constant declarations of the type: Constants F = 3, g = 9.81 + name = ctx.ID().getText().lower() + if "=" in ctx.getText(): + self.symbol_table.update({name: name}) + # self.inputs.update({self.symbol_table[name]: self.getValue(ctx.getChild(2))}) + self.write(self.symbol_table[name] + " = " + "_sm.S(" + self.getValue(ctx.getChild(2)) + ")\n") + self.type.update({name: "constants"}) + return + + # Constants declarations of the type: Constants A, B + else: + if "{" not in ctx.getText(): + self.symbol_table[name] = name + self.type[name] = "constants" + + # Process constant declarations of the type: Constants C+, D- + if ctx.getChildCount() == 2: + # This is set for declaring nonpositive=True and nonnegative=True + if ctx.getChild(1).getText() == "+": + self.sign[name] = "+" + elif ctx.getChild(1).getText() == "-": + self.sign[name] = "-" + else: + if "{" not in ctx.getText(): + self.sign[name] = "o" + + # Process constant declarations of the type: Constants K{4}, a{1:2, 1:2}, b{1:2} + if "{" in ctx.getText(): + if ":" in ctx.getText(): + num1 = int(ctx.INT(0).getText()) + num2 = int(ctx.INT(1).getText()) + 1 + else: + num1 = 1 + num2 = int(ctx.INT(0).getText()) + 1 + + if ":" in ctx.getText(): + if "," in ctx.getText(): + num3 = int(ctx.INT(2).getText()) + num4 = int(ctx.INT(3).getText()) + 1 + for i in range(num1, num2): + for j in range(num3, num4): + self.symbol_table[name + str(i) + str(j)] = name + str(i) + str(j) + self.type[name + str(i) + str(j)] = "constants" + self.var_list.append(name + str(i) + str(j)) + self.sign[name + str(i) + str(j)] = "o" + else: + for i in range(num1, num2): + self.symbol_table[name + str(i)] = name + str(i) + self.type[name + str(i)] = "constants" + self.var_list.append(name + str(i)) + self.sign[name + str(i)] = "o" + + elif "," in ctx.getText(): + for i in range(1, int(ctx.INT(0).getText()) + 1): + for j in range(1, int(ctx.INT(1).getText()) + 1): + self.symbol_table[name] = name + str(i) + str(j) + self.type[name + str(i) + str(j)] = "constants" + self.var_list.append(name + str(i) + str(j)) + self.sign[name + str(i) + str(j)] = "o" + + else: + for i in range(num1, num2): + self.symbol_table[name + str(i)] = name + str(i) + self.type[name + str(i)] = "constants" + self.var_list.append(name + str(i)) + self.sign[name + str(i)] = "o" + + if "{" not in ctx.getText(): + self.var_list.append(name) + + +def writeConstants(self, ctx): + l1 = list(filter(lambda x: self.sign[x] == "o", self.var_list)) + l2 = list(filter(lambda x: self.sign[x] == "+", self.var_list)) + l3 = list(filter(lambda x: self.sign[x] == "-", self.var_list)) + try: + if self.settings["complex"] == "on": + real = ", real=True" + elif self.settings["complex"] == "off": + real = "" + except Exception: + real = ", real=True" + + if l1: + a = ", ".join(l1) + " = " + "_sm.symbols(" + "'" +\ + " ".join(l1) + "'" + real + ")\n" + self.write(a) + if l2: + a = ", ".join(l2) + " = " + "_sm.symbols(" + "'" +\ + " ".join(l2) + "'" + real + ", nonnegative=True)\n" + self.write(a) + if l3: + a = ", ".join(l3) + " = " + "_sm.symbols(" + "'" + \ + " ".join(l3) + "'" + real + ", nonpositive=True)\n" + self.write(a) + self.var_list = [] + + +def processVariables(self, ctx): + # Specified F = x*N1> + y*N2> + name = ctx.ID().getText().lower() + if "=" in ctx.getText(): + text = name + "'"*(ctx.getChildCount()-3) + self.write(text + " = " + self.getValue(ctx.expr()) + "\n") + return + + # Process variables of the type: Variables qA, qB + if ctx.getChildCount() == 1: + self.symbol_table[name] = name + if self.getValue(ctx.parentCtx.getChild(0)) in ("variable", "specified", "motionvariable", "motionvariable'"): + self.type.update({name: self.getValue(ctx.parentCtx.getChild(0))}) + + self.var_list.append(name) + self.sign[name] = 0 + + # Process variables of the type: Variables x', y'' + elif "'" in ctx.getText() and "{" not in ctx.getText(): + if ctx.getText().count("'") > self.maxDegree: + self.maxDegree = ctx.getText().count("'") + for i in range(ctx.getChildCount()): + self.sign[name + strfunc(i)] = i + self.symbol_table[name + "'"*i] = name + strfunc(i) + if self.getValue(ctx.parentCtx.getChild(0)) in ("variable", "specified", "motionvariable", "motionvariable'"): + self.type.update({name + "'"*i: self.getValue(ctx.parentCtx.getChild(0))}) + self.var_list.append(name + strfunc(i)) + + elif "{" in ctx.getText(): + # Process variables of the type: Variables x{3}, y{2} + + if "'" in ctx.getText(): + dash_count = ctx.getText().count("'") + if dash_count > self.maxDegree: + self.maxDegree = dash_count + + if ":" in ctx.getText(): + # Variables C{1:2, 1:2} + if "," in ctx.getText(): + num1 = int(ctx.INT(0).getText()) + num2 = int(ctx.INT(1).getText()) + 1 + num3 = int(ctx.INT(2).getText()) + num4 = int(ctx.INT(3).getText()) + 1 + # Variables C{1:2} + else: + num1 = int(ctx.INT(0).getText()) + num2 = int(ctx.INT(1).getText()) + 1 + + # Variables C{1,3} + elif "," in ctx.getText(): + num1 = 1 + num2 = int(ctx.INT(0).getText()) + 1 + num3 = 1 + num4 = int(ctx.INT(1).getText()) + 1 + else: + num1 = 1 + num2 = int(ctx.INT(0).getText()) + 1 + + for i in range(num1, num2): + try: + for j in range(num3, num4): + try: + for z in range(dash_count+1): + self.symbol_table.update({name + str(i) + str(j) + "'"*z: name + str(i) + str(j) + strfunc(z)}) + if self.getValue(ctx.parentCtx.getChild(0)) in ("variable", "specified", "motionvariable", "motionvariable'"): + self.type.update({name + str(i) + str(j) + "'"*z: self.getValue(ctx.parentCtx.getChild(0))}) + self.var_list.append(name + str(i) + str(j) + strfunc(z)) + self.sign.update({name + str(i) + str(j) + strfunc(z): z}) + if dash_count > self.maxDegree: + self.maxDegree = dash_count + except Exception: + self.symbol_table.update({name + str(i) + str(j): name + str(i) + str(j)}) + if self.getValue(ctx.parentCtx.getChild(0)) in ("variable", "specified", "motionvariable", "motionvariable'"): + self.type.update({name + str(i) + str(j): self.getValue(ctx.parentCtx.getChild(0))}) + self.var_list.append(name + str(i) + str(j)) + self.sign.update({name + str(i) + str(j): 0}) + except Exception: + try: + for z in range(dash_count+1): + self.symbol_table.update({name + str(i) + "'"*z: name + str(i) + strfunc(z)}) + if self.getValue(ctx.parentCtx.getChild(0)) in ("variable", "specified", "motionvariable", "motionvariable'"): + self.type.update({name + str(i) + "'"*z: self.getValue(ctx.parentCtx.getChild(0))}) + self.var_list.append(name + str(i) + strfunc(z)) + self.sign.update({name + str(i) + strfunc(z): z}) + if dash_count > self.maxDegree: + self.maxDegree = dash_count + except Exception: + self.symbol_table.update({name + str(i): name + str(i)}) + if self.getValue(ctx.parentCtx.getChild(0)) in ("variable", "specified", "motionvariable", "motionvariable'"): + self.type.update({name + str(i): self.getValue(ctx.parentCtx.getChild(0))}) + self.var_list.append(name + str(i)) + self.sign.update({name + str(i): 0}) + +def writeVariables(self, ctx): + #print(self.sign) + #print(self.symbol_table) + if self.var_list: + for i in range(self.maxDegree+1): + if i == 0: + j = "" + t = "" + else: + j = str(i) + t = ", " + l = [] + for k in list(filter(lambda x: self.sign[x] == i, self.var_list)): + if i == 0: + l.append(k) + if i == 1: + l.append(k[:-1]) + if i > 1: + l.append(k[:-2]) + a = ", ".join(list(filter(lambda x: self.sign[x] == i, self.var_list))) + " = " +\ + "_me.dynamicsymbols(" + "'" + " ".join(l) + "'" + t + j + ")\n" + l = [] + self.write(a) + self.maxDegree = 0 + self.var_list = [] + +def processImaginary(self, ctx): + name = ctx.ID().getText().lower() + self.symbol_table[name] = name + self.type[name] = "imaginary" + self.var_list.append(name) + + +def writeImaginary(self, ctx): + a = ", ".join(self.var_list) + " = " + "_sm.symbols(" + "'" + \ + " ".join(self.var_list) + "')\n" + b = ", ".join(self.var_list) + " = " + "_sm.I\n" + self.write(a) + self.write(b) + self.var_list = [] + +if AutolevListener: + class MyListener(AutolevListener): # type: ignore + def __init__(self, include_numeric=False): + # Stores data in tree nodes(tree annotation). Especially useful for expr reconstruction. + self.tree_property = {} + + # Stores the declared variables, constants etc as they are declared in Autolev and SymPy + # {"": ""}. + self.symbol_table = collections.OrderedDict() + + # Similar to symbol_table. Used for storing Physical entities like Frames, Points, + # Particles, Bodies etc + self.symbol_table2 = collections.OrderedDict() + + # Used to store nonpositive, nonnegative etc for constants and number of "'"s (order of diff) + # in variables. + self.sign = {} + + # Simple list used as a store to pass around variables between the 'process' and 'write' + # methods. + self.var_list = [] + + # Stores the type of a declared variable (constants, variables, specifieds etc) + self.type = collections.OrderedDict() + + # Similar to self.type. Used for storing the type of Physical entities like Frames, Points, + # Particles, Bodies etc + self.type2 = collections.OrderedDict() + + # These lists are used to distinguish matrix, numeric and vector expressions. + self.matrix_expr = [] + self.numeric_expr = [] + self.vector_expr = [] + self.fr_expr = [] + + self.output_code = [] + + # Stores the variables and their rhs for substituting upon the Autolev command EXPLICIT. + self.explicit = collections.OrderedDict() + + # Write code to import common dependencies. + self.output_code.append("import sympy.physics.mechanics as _me\n") + self.output_code.append("import sympy as _sm\n") + self.output_code.append("import math as m\n") + self.output_code.append("import numpy as _np\n") + self.output_code.append("\n") + + # Just a store for the max degree variable in a line. + self.maxDegree = 0 + + # Stores the input parameters which are then used for codegen and numerical analysis. + self.inputs = collections.OrderedDict() + # Stores the variables which appear in Output Autolev commands. + self.outputs = [] + # Stores the settings specified by the user. Ex: Complex on/off, Degrees on/off + self.settings = {} + # Boolean which changes the behaviour of some expression reconstruction + # when parsing Input Autolev commands. + self.in_inputs = False + self.in_outputs = False + + # Stores for the physical entities. + self.newtonian = None + self.bodies = collections.OrderedDict() + self.constants = [] + self.forces = collections.OrderedDict() + self.q_ind = [] + self.q_dep = [] + self.u_ind = [] + self.u_dep = [] + self.kd_eqs = [] + self.dependent_variables = [] + self.kd_equivalents = collections.OrderedDict() + self.kd_equivalents2 = collections.OrderedDict() + self.kd_eqs_supplied = None + self.kane_type = "no_args" + self.inertia_point = collections.OrderedDict() + self.kane_parsed = False + self.t = False + + # PyDy ode code will be included only if this flag is set to True. + self.include_numeric = include_numeric + + def write(self, string): + self.output_code.append(string) + + def getValue(self, node): + return self.tree_property[node] + + def setValue(self, node, value): + self.tree_property[node] = value + + def getSymbolTable(self): + return self.symbol_table + + def getType(self): + return self.type + + def exitVarDecl(self, ctx): + # This event method handles variable declarations. The parse tree node varDecl contains + # one or more varDecl2 nodes. Eg varDecl for 'Constants a{1:2, 1:2}, b{1:2}' has two varDecl2 + # nodes(one for a{1:2, 1:2} and one for b{1:2}). + + # Variable declarations are processed and stored in the event method exitVarDecl2. + # This stored information is used to write the final SymPy output code in the exitVarDecl event method. + + # determine the type of declaration + if self.getValue(ctx.varType()) == "constant": + writeConstants(self, ctx) + elif self.getValue(ctx.varType()) in\ + ("variable", "motionvariable", "motionvariable'", "specified"): + writeVariables(self, ctx) + elif self.getValue(ctx.varType()) == "imaginary": + writeImaginary(self, ctx) + + def exitVarType(self, ctx): + # Annotate the varType tree node with the type of the variable declaration. + name = ctx.getChild(0).getText().lower() + if name[-1] == "s" and name != "bodies": + self.setValue(ctx, name[:-1]) + else: + self.setValue(ctx, name) + + def exitVarDecl2(self, ctx): + # Variable declarations are processed and stored in the event method exitVarDecl2. + # This stored information is used to write the final SymPy output code in the exitVarDecl event method. + # This is the case for constants, variables, specifieds etc. + + # This isn't the case for all types of declarations though. For instance + # Frames A, B, C, N cannot be defined on one line in SymPy. So we do not append A, B, C, N + # to a var_list or use exitVarDecl. exitVarDecl2 directly writes out to the file. + + # determine the type of declaration + if self.getValue(ctx.parentCtx.varType()) == "constant": + processConstants(self, ctx) + + elif self.getValue(ctx.parentCtx.varType()) in \ + ("variable", "motionvariable", "motionvariable'", "specified"): + processVariables(self, ctx) + + elif self.getValue(ctx.parentCtx.varType()) == "imaginary": + processImaginary(self, ctx) + + elif self.getValue(ctx.parentCtx.varType()) in ("frame", "newtonian", "point", "particle", "bodies"): + if "{" in ctx.getText(): + if ":" in ctx.getText() and "," not in ctx.getText(): + num1 = int(ctx.INT(0).getText()) + num2 = int(ctx.INT(1).getText()) + 1 + elif ":" not in ctx.getText() and "," in ctx.getText(): + num1 = 1 + num2 = int(ctx.INT(0).getText()) + 1 + num3 = 1 + num4 = int(ctx.INT(1).getText()) + 1 + elif ":" in ctx.getText() and "," in ctx.getText(): + num1 = int(ctx.INT(0).getText()) + num2 = int(ctx.INT(1).getText()) + 1 + num3 = int(ctx.INT(2).getText()) + num4 = int(ctx.INT(3).getText()) + 1 + else: + num1 = 1 + num2 = int(ctx.INT(0).getText()) + 1 + else: + num1 = 1 + num2 = 2 + for i in range(num1, num2): + try: + for j in range(num3, num4): + declare_phy_entities(self, ctx, self.getValue(ctx.parentCtx.varType()), i, j) + except Exception: + declare_phy_entities(self, ctx, self.getValue(ctx.parentCtx.varType()), i) + # ================== Subrules of parser rule expr (Start) ====================== # + + def exitId(self, ctx): + # Tree annotation for ID which is a labeled subrule of the parser rule expr. + # A_C + python_keywords = ["and", "as", "assert", "break", "class", "continue", "def", "del", "elif", "else", "except",\ + "exec", "finally", "for", "from", "global", "if", "import", "in", "is", "lambda", "not", "or", "pass", "print",\ + "raise", "return", "try", "while", "with", "yield"] + + if ctx.ID().getText().lower() in python_keywords: + warnings.warn("Python keywords must not be used as identifiers. Please refer to the list of keywords at https://docs.python.org/2.5/ref/keywords.html", + SyntaxWarning) + + if "_" in ctx.ID().getText() and ctx.ID().getText().count('_') == 1: + e1, e2 = ctx.ID().getText().lower().split('_') + try: + if self.type2[e1] == "frame": + e1 = self.symbol_table2[e1] + elif self.type2[e1] == "bodies": + e1 = self.symbol_table2[e1] + "_f" + if self.type2[e2] == "frame": + e2 = self.symbol_table2[e2] + elif self.type2[e2] == "bodies": + e2 = self.symbol_table2[e2] + "_f" + + self.setValue(ctx, e1 + ".dcm(" + e2 + ")") + except Exception: + self.setValue(ctx, ctx.ID().getText().lower()) + else: + # Reserved constant Pi + if ctx.ID().getText().lower() == "pi": + self.setValue(ctx, "_sm.pi") + self.numeric_expr.append(ctx) + + # Reserved variable T (for time) + elif ctx.ID().getText().lower() == "t": + self.setValue(ctx, "_me.dynamicsymbols._t") + if not self.in_inputs and not self.in_outputs: + self.t = True + + else: + idText = ctx.ID().getText().lower() + "'"*(ctx.getChildCount() - 1) + if idText in self.type.keys() and self.type[idText] == "matrix": + self.matrix_expr.append(ctx) + if self.in_inputs: + try: + self.setValue(ctx, self.symbol_table[idText]) + except Exception: + self.setValue(ctx, idText.lower()) + else: + try: + self.setValue(ctx, self.symbol_table[idText]) + except Exception: + pass + + def exitInt(self, ctx): + # Tree annotation for int which is a labeled subrule of the parser rule expr. + int_text = ctx.INT().getText() + self.setValue(ctx, int_text) + self.numeric_expr.append(ctx) + + def exitFloat(self, ctx): + # Tree annotation for float which is a labeled subrule of the parser rule expr. + floatText = ctx.FLOAT().getText() + self.setValue(ctx, floatText) + self.numeric_expr.append(ctx) + + def exitAddSub(self, ctx): + # Tree annotation for AddSub which is a labeled subrule of the parser rule expr. + # The subrule is expr = expr (+|-) expr + if ctx.expr(0) in self.matrix_expr or ctx.expr(1) in self.matrix_expr: + self.matrix_expr.append(ctx) + if ctx.expr(0) in self.vector_expr or ctx.expr(1) in self.vector_expr: + self.vector_expr.append(ctx) + if ctx.expr(0) in self.numeric_expr and ctx.expr(1) in self.numeric_expr: + self.numeric_expr.append(ctx) + self.setValue(ctx, self.getValue(ctx.expr(0)) + ctx.getChild(1).getText() + + self.getValue(ctx.expr(1))) + + def exitMulDiv(self, ctx): + # Tree annotation for MulDiv which is a labeled subrule of the parser rule expr. + # The subrule is expr = expr (*|/) expr + try: + if ctx.expr(0) in self.vector_expr and ctx.expr(1) in self.vector_expr: + self.setValue(ctx, "_me.outer(" + self.getValue(ctx.expr(0)) + ", " + + self.getValue(ctx.expr(1)) + ")") + else: + if ctx.expr(0) in self.matrix_expr or ctx.expr(1) in self.matrix_expr: + self.matrix_expr.append(ctx) + if ctx.expr(0) in self.vector_expr or ctx.expr(1) in self.vector_expr: + self.vector_expr.append(ctx) + if ctx.expr(0) in self.numeric_expr and ctx.expr(1) in self.numeric_expr: + self.numeric_expr.append(ctx) + self.setValue(ctx, self.getValue(ctx.expr(0)) + ctx.getChild(1).getText() + + self.getValue(ctx.expr(1))) + except Exception: + pass + + def exitNegativeOne(self, ctx): + # Tree annotation for negativeOne which is a labeled subrule of the parser rule expr. + self.setValue(ctx, "-1*" + self.getValue(ctx.getChild(1))) + if ctx.getChild(1) in self.matrix_expr: + self.matrix_expr.append(ctx) + if ctx.getChild(1) in self.numeric_expr: + self.numeric_expr.append(ctx) + + def exitParens(self, ctx): + # Tree annotation for parens which is a labeled subrule of the parser rule expr. + # The subrule is expr = '(' expr ')' + if ctx.expr() in self.matrix_expr: + self.matrix_expr.append(ctx) + if ctx.expr() in self.vector_expr: + self.vector_expr.append(ctx) + if ctx.expr() in self.numeric_expr: + self.numeric_expr.append(ctx) + self.setValue(ctx, "(" + self.getValue(ctx.expr()) + ")") + + def exitExponent(self, ctx): + # Tree annotation for Exponent which is a labeled subrule of the parser rule expr. + # The subrule is expr = expr ^ expr + if ctx.expr(0) in self.matrix_expr or ctx.expr(1) in self.matrix_expr: + self.matrix_expr.append(ctx) + if ctx.expr(0) in self.vector_expr or ctx.expr(1) in self.vector_expr: + self.vector_expr.append(ctx) + if ctx.expr(0) in self.numeric_expr and ctx.expr(1) in self.numeric_expr: + self.numeric_expr.append(ctx) + self.setValue(ctx, self.getValue(ctx.expr(0)) + "**" + self.getValue(ctx.expr(1))) + + def exitExp(self, ctx): + s = ctx.EXP().getText()[ctx.EXP().getText().index('E')+1:] + if "-" in s: + s = s[0] + s[1:].lstrip("0") + else: + s = s.lstrip("0") + self.setValue(ctx, ctx.EXP().getText()[:ctx.EXP().getText().index('E')] + + "*10**(" + s + ")") + + def exitFunction(self, ctx): + # Tree annotation for function which is a labeled subrule of the parser rule expr. + + # The difference between this and FunctionCall is that this is used for non standalone functions + # appearing in expressions and assignments. + # Eg: + # When we come across a standalone function say Expand(E, n:m) then it is categorized as FunctionCall + # which is a parser rule in itself under rule stat. exitFunctionCall() takes care of it and writes to the file. + # + # On the other hand, while we come across E_diff = D(E, y), we annotate the tree node + # of the function D(E, y) with the SymPy equivalent in exitFunction(). + # In this case it is the method exitAssignment() that writes the code to the file and not exitFunction(). + + ch = ctx.getChild(0) + func_name = ch.getChild(0).getText().lower() + + # Expand(y, n:m) * + if func_name == "expand": + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + # _sm.Matrix([i.expand() for i in z]).reshape(z.shape[0], z.shape[1]) + self.setValue(ctx, "_sm.Matrix([i.expand() for i in " + expr + "])" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + self.setValue(ctx, "(" + expr + ")" + "." + "expand()") + + # Factor(y, x) * + elif func_name == "factor": + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + self.setValue(ctx, "_sm.Matrix([_sm.factor(i, " + self.getValue(ch.expr(1)) + ") for i in " + + expr + "])" + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + self.setValue(ctx, "_sm.factor(" + "(" + expr + ")" + + ", " + self.getValue(ch.expr(1)) + ")") + + # D(y, x) + elif func_name == "d": + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + self.setValue(ctx, "_sm.Matrix([i.diff(" + self.getValue(ch.expr(1)) + ") for i in " + + expr + "])" + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + if ch.getChildCount() == 8: + frame = self.symbol_table2[ch.expr(2).getText().lower()] + self.setValue(ctx, "(" + expr + ")" + "." + "diff(" + self.getValue(ch.expr(1)) + + ", " + frame + ")") + else: + self.setValue(ctx, "(" + expr + ")" + "." + "diff(" + + self.getValue(ch.expr(1)) + ")") + + # Dt(y) + elif func_name == "dt": + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.vector_expr: + text = "dt(" + else: + text = "diff(_sm.Symbol('t')" + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + self.setValue(ctx, "_sm.Matrix([i." + text + + ") for i in " + expr + "])" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + if ch.getChildCount() == 6: + frame = self.symbol_table2[ch.expr(1).getText().lower()] + self.setValue(ctx, "(" + expr + ")" + "." + "dt(" + + frame + ")") + else: + self.setValue(ctx, "(" + expr + ")" + "." + text + ")") + + # Explicit(EXPRESS(IMPLICIT>,C)) + elif func_name == "explicit": + if ch.expr(0) in self.vector_expr: + self.vector_expr.append(ctx) + expr = self.getValue(ch.expr(0)) + if self.explicit.keys(): + explicit_list = [] + for i in self.explicit.keys(): + explicit_list.append(i + ":" + self.explicit[i]) + self.setValue(ctx, "(" + expr + ")" + ".subs({" + ", ".join(explicit_list) + "})") + else: + self.setValue(ctx, expr) + + # Taylor(y, 0:2, w=a, x=0) + # TODO: Currently only works with symbols. Make it work for dynamicsymbols. + elif func_name == "taylor": + exp = self.getValue(ch.expr(0)) + order = self.getValue(ch.expr(1).expr(1)) + x = (ch.getChildCount()-6)//2 + l = [] + for i in range(x): + index = 2 + i + child = ch.expr(index) + l.append(".series(" + self.getValue(child.getChild(0)) + + ", " + self.getValue(child.getChild(2)) + + ", " + order + ").removeO()") + self.setValue(ctx, "(" + exp + ")" + "".join(l)) + + # Evaluate(y, a=x, b=2) + elif func_name == "evaluate": + expr = self.getValue(ch.expr(0)) + l = [] + x = (ch.getChildCount()-4)//2 + for i in range(x): + index = 1 + i + child = ch.expr(index) + l.append(self.getValue(child.getChild(0)) + ":" + + self.getValue(child.getChild(2))) + + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + self.setValue(ctx, "_sm.Matrix([i.subs({" + ",".join(l) + "}) for i in " + + expr + "])" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + if self.explicit: + explicit_list = [] + for i in self.explicit.keys(): + explicit_list.append(i + ":" + self.explicit[i]) + self.setValue(ctx, "(" + expr + ")" + ".subs({" + ",".join(explicit_list) + + "}).subs({" + ",".join(l) + "})") + else: + self.setValue(ctx, "(" + expr + ")" + ".subs({" + ",".join(l) + "})") + + # Polynomial([a, b, c], x) + elif func_name == "polynomial": + self.setValue(ctx, "_sm.Poly(" + self.getValue(ch.expr(0)) + ", " + + self.getValue(ch.expr(1)) + ")") + + # Roots(Poly, x, 2) + # Roots([1; 2; 3; 4]) + elif func_name == "roots": + self.matrix_expr.append(ctx) + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.setValue(ctx, "[i.evalf() for i in " + "_sm.solve(" + + "_sm.Poly(" + expr + ", " + "x),x)]") + else: + self.setValue(ctx, "[i.evalf() for i in " + "_sm.solve(" + + expr + ", " + self.getValue(ch.expr(1)) + ")]") + + # Transpose(A), Inv(A) + elif func_name in ("transpose", "inv", "inverse"): + self.matrix_expr.append(ctx) + if func_name == "transpose": + e = ".T" + elif func_name in ("inv", "inverse"): + e = "**(-1)" + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + e) + + # Eig(A) + elif func_name == "eig": + # "_sm.Matrix([i.evalf() for i in " + + self.setValue(ctx, "_sm.Matrix([i.evalf() for i in (" + + self.getValue(ch.expr(0)) + ").eigenvals().keys()])") + + # Diagmat(n, m, x) + # Diagmat(3, 1) + elif func_name == "diagmat": + self.matrix_expr.append(ctx) + if ch.getChildCount() == 6: + l = [] + for i in range(int(self.getValue(ch.expr(0)))): + l.append(self.getValue(ch.expr(1)) + ",") + + self.setValue(ctx, "_sm.diag(" + ("".join(l))[:-1] + ")") + + elif ch.getChildCount() == 8: + # _sm.Matrix([x if i==j else 0 for i in range(n) for j in range(m)]).reshape(n, m) + n = self.getValue(ch.expr(0)) + m = self.getValue(ch.expr(1)) + x = self.getValue(ch.expr(2)) + self.setValue(ctx, "_sm.Matrix([" + x + " if i==j else 0 for i in range(" + + n + ") for j in range(" + m + ")]).reshape(" + n + ", " + m + ")") + + # Cols(A) + # Cols(A, 1) + # Cols(A, 1, 2:4, 3) + elif func_name in ("cols", "rows"): + self.matrix_expr.append(ctx) + if func_name == "cols": + e1 = ".cols" + e2 = ".T." + else: + e1 = ".rows" + e2 = "." + if ch.getChildCount() == 4: + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + e1) + elif ch.getChildCount() == 6: + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + + e1[:-1] + "(" + str(int(self.getValue(ch.expr(1))) - 1) + ")") + else: + l = [] + for i in range(4, ch.getChildCount()): + try: + if ch.getChild(i).getChildCount() > 1 and ch.getChild(i).getChild(1).getText() == ":": + for j in range(int(ch.getChild(i).getChild(0).getText()), + int(ch.getChild(i).getChild(2).getText())+1): + l.append("(" + self.getValue(ch.getChild(2)) + ")" + e2 + + "row(" + str(j-1) + ")") + else: + l.append("(" + self.getValue(ch.getChild(2)) + ")" + e2 + + "row(" + str(int(ch.getChild(i).getText())-1) + ")") + except Exception: + pass + self.setValue(ctx, "_sm.Matrix([" + ",".join(l) + "])") + + # Det(A) Trace(A) + elif func_name in ["det", "trace"]: + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + "." + + func_name + "()") + + # Element(A, 2, 3) + elif func_name == "element": + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + "[" + + str(int(self.getValue(ch.expr(1)))-1) + "," + + str(int(self.getValue(ch.expr(2)))-1) + "]") + + elif func_name in \ + ["cos", "sin", "tan", "cosh", "sinh", "tanh", "acos", "asin", "atan", + "log", "exp", "sqrt", "factorial", "floor", "sign"]: + self.setValue(ctx, "_sm." + func_name + "(" + self.getValue(ch.expr(0)) + ")") + + elif func_name == "ceil": + self.setValue(ctx, "_sm.ceiling" + "(" + self.getValue(ch.expr(0)) + ")") + + elif func_name == "sqr": + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + + ")" + "**2") + + elif func_name == "log10": + self.setValue(ctx, "_sm.log" + + "(" + self.getValue(ch.expr(0)) + ", 10)") + + elif func_name == "atan2": + self.setValue(ctx, "_sm.atan2" + "(" + self.getValue(ch.expr(0)) + ", " + + self.getValue(ch.expr(1)) + ")") + + elif func_name in ["int", "round"]: + self.setValue(ctx, func_name + + "(" + self.getValue(ch.expr(0)) + ")") + + elif func_name == "abs": + self.setValue(ctx, "_sm.Abs(" + self.getValue(ch.expr(0)) + ")") + + elif func_name in ["max", "min"]: + # max(x, y, z) + l = [] + for i in range(1, ch.getChildCount()): + if ch.getChild(i) in self.tree_property.keys(): + l.append(self.getValue(ch.getChild(i))) + elif ch.getChild(i).getText() in [",", "(", ")"]: + l.append(ch.getChild(i).getText()) + self.setValue(ctx, "_sm." + ch.getChild(0).getText().capitalize() + "".join(l)) + + # Coef(y, x) + elif func_name == "coef": + #A41_A53=COEF([RHS(U4);RHS(U5)],[U1,U2,U3]) + if ch.expr(0) in self.matrix_expr and ch.expr(1) in self.matrix_expr: + icount = jcount = 0 + for i in range(ch.expr(0).getChild(0).getChildCount()): + try: + ch.expr(0).getChild(0).getChild(i).getRuleIndex() + icount+=1 + except Exception: + pass + for j in range(ch.expr(1).getChild(0).getChildCount()): + try: + ch.expr(1).getChild(0).getChild(j).getRuleIndex() + jcount+=1 + except Exception: + pass + l = [] + for i in range(icount): + for j in range(jcount): + # a41_a53[i,j] = u4.expand().coeff(u1) + l.append(self.getValue(ch.expr(0).getChild(0).expr(i)) + ".expand().coeff(" + + self.getValue(ch.expr(1).getChild(0).expr(j)) + ")") + self.setValue(ctx, "_sm.Matrix([" + ", ".join(l) + "]).reshape(" + str(icount) + ", " + str(jcount) + ")") + else: + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + + ")" + ".expand().coeff(" + self.getValue(ch.expr(1)) + ")") + + # Exclude(y, x) Include(y, x) + elif func_name in ("exclude", "include"): + if func_name == "exclude": + e = "0" + else: + e = "1" + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + self.setValue(ctx, "_sm.Matrix([i.collect(" + self.getValue(ch.expr(1)) + "])" + + ".coeff(" + self.getValue(ch.expr(1)) + "," + e + ")" + "for i in " + expr + ")" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + self.setValue(ctx, "(" + expr + + ")" + ".collect(" + self.getValue(ch.expr(1)) + ")" + + ".coeff(" + self.getValue(ch.expr(1)) + "," + e + ")") + + # RHS(y) + elif func_name == "rhs": + self.setValue(ctx, self.explicit[self.getValue(ch.expr(0))]) + + # Arrange(y, n, x) * + elif func_name == "arrange": + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + self.setValue(ctx, "_sm.Matrix([i.collect(" + self.getValue(ch.expr(2)) + + ")" + "for i in " + expr + "])"+ + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + self.setValue(ctx, "(" + expr + + ")" + ".collect(" + self.getValue(ch.expr(2)) + ")") + + # Replace(y, sin(x)=3) + elif func_name == "replace": + l = [] + for i in range(1, ch.getChildCount()): + try: + if ch.getChild(i).getChild(1).getText() == "=": + l.append(self.getValue(ch.getChild(i).getChild(0)) + + ":" + self.getValue(ch.getChild(i).getChild(2))) + except Exception: + pass + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + self.setValue(ctx, "_sm.Matrix([i.subs({" + ",".join(l) + "}) for i in " + + expr + "])" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + + ".subs({" + ",".join(l) + "})") + + # Dot(Loop>, N1>) + elif func_name == "dot": + l = [] + num = (ch.expr(1).getChild(0).getChildCount()-1)//2 + if ch.expr(1) in self.matrix_expr: + for i in range(num): + l.append("_me.dot(" + self.getValue(ch.expr(0)) + ", " + self.getValue(ch.expr(1).getChild(0).expr(i)) + ")") + self.setValue(ctx, "_sm.Matrix([" + ",".join(l) + "]).reshape(" + str(num) + ", " + "1)") + else: + self.setValue(ctx, "_me.dot(" + self.getValue(ch.expr(0)) + ", " + self.getValue(ch.expr(1)) + ")") + # Cross(w_A_N>, P_NA_AB>) + elif func_name == "cross": + self.vector_expr.append(ctx) + self.setValue(ctx, "_me.cross(" + self.getValue(ch.expr(0)) + ", " + self.getValue(ch.expr(1)) + ")") + + # Mag(P_O_Q>) + elif func_name == "mag": + self.setValue(ctx, self.getValue(ch.expr(0)) + "." + "magnitude()") + + # MATRIX(A, I_R>>) + elif func_name == "matrix": + if self.type2[ch.expr(0).getText().lower()] == "frame": + text = "" + elif self.type2[ch.expr(0).getText().lower()] == "bodies": + text = "_f" + self.setValue(ctx, "(" + self.getValue(ch.expr(1)) + ")" + ".to_matrix(" + + self.symbol_table2[ch.expr(0).getText().lower()] + text + ")") + + # VECTOR(A, ROWS(EIGVECS,1)) + elif func_name == "vector": + if self.type2[ch.expr(0).getText().lower()] == "frame": + text = "" + elif self.type2[ch.expr(0).getText().lower()] == "bodies": + text = "_f" + v = self.getValue(ch.expr(1)) + f = self.symbol_table2[ch.expr(0).getText().lower()] + text + self.setValue(ctx, v + "[0]*" + f + ".x +" + v + "[1]*" + f + ".y +" + + v + "[2]*" + f + ".z") + + # Express(A2>, B) + # Here I am dealing with all the Inertia commands as I expect the users to use Inertia + # commands only with Express because SymPy needs the Reference frame to be specified unlike Autolev. + elif func_name == "express": + self.vector_expr.append(ctx) + if self.type2[ch.expr(1).getText().lower()] == "frame": + frame = self.symbol_table2[ch.expr(1).getText().lower()] + else: + frame = self.symbol_table2[ch.expr(1).getText().lower()] + "_f" + if ch.expr(0).getText().lower() == "1>>": + self.setValue(ctx, "_me.inertia(" + frame + ", 1, 1, 1)") + + elif '_' in ch.expr(0).getText().lower() and ch.expr(0).getText().lower().count('_') == 2\ + and ch.expr(0).getText().lower()[0] == "i" and ch.expr(0).getText().lower()[-2:] == ">>": + v1 = ch.expr(0).getText().lower()[:-2].split('_')[1] + v2 = ch.expr(0).getText().lower()[:-2].split('_')[2] + l = [] + inertia_func(self, v1, v2, l, frame) + self.setValue(ctx, " + ".join(l)) + + elif ch.expr(0).getChild(0).getChild(0).getText().lower() == "inertia": + if ch.expr(0).getChild(0).getChildCount() == 4: + l = [] + v2 = ch.expr(0).getChild(0).ID(0).getText().lower() + for v1 in self.bodies: + inertia_func(self, v1, v2, l, frame) + self.setValue(ctx, " + ".join(l)) + + else: + l = [] + l2 = [] + v2 = ch.expr(0).getChild(0).ID(0).getText().lower() + for i in range(1, (ch.expr(0).getChild(0).getChildCount()-2)//2): + l2.append(ch.expr(0).getChild(0).ID(i).getText().lower()) + for v1 in l2: + inertia_func(self, v1, v2, l, frame) + self.setValue(ctx, " + ".join(l)) + + else: + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + ".express(" + + self.symbol_table2[ch.expr(1).getText().lower()] + ")") + # CM(P) + elif func_name == "cm": + if self.type2[ch.expr(0).getText().lower()] == "point": + text = "" + else: + text = ".point" + if ch.getChildCount() == 4: + self.setValue(ctx, "_me.functions.center_of_mass(" + self.symbol_table2[ch.expr(0).getText().lower()] + + text + "," + ", ".join(self.bodies.values()) + ")") + else: + bodies = [] + for i in range(1, (ch.getChildCount()-1)//2): + bodies.append(self.symbol_table2[ch.expr(i).getText().lower()]) + self.setValue(ctx, "_me.functions.center_of_mass(" + self.symbol_table2[ch.expr(0).getText().lower()] + + text + "," + ", ".join(bodies) + ")") + + # PARTIALS(V_P1_E>,U1) + elif func_name == "partials": + speeds = [] + for i in range(1, (ch.getChildCount()-1)//2): + if self.kd_equivalents2: + speeds.append(self.kd_equivalents2[self.symbol_table[ch.expr(i).getText().lower()]]) + else: + speeds.append(self.symbol_table[ch.expr(i).getText().lower()]) + v1, v2, v3 = ch.expr(0).getText().lower().replace(">","").split('_') + if self.type2[v2] == "point": + point = self.symbol_table2[v2] + elif self.type2[v2] == "particle": + point = self.symbol_table2[v2] + ".point" + frame = self.symbol_table2[v3] + self.setValue(ctx, point + ".partial_velocity(" + frame + ", " + ",".join(speeds) + ")") + + # UnitVec(A1>+A2>+A3>) + elif func_name == "unitvec": + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + ".normalize()") + + # Units(deg, rad) + elif func_name == "units": + if ch.expr(0).getText().lower() == "deg" and ch.expr(1).getText().lower() == "rad": + factor = 0.0174533 + elif ch.expr(0).getText().lower() == "rad" and ch.expr(1).getText().lower() == "deg": + factor = 57.2958 + self.setValue(ctx, str(factor)) + # Mass(A) + elif func_name == "mass": + l = [] + try: + ch.ID(0).getText().lower() + for i in range((ch.getChildCount()-1)//2): + l.append(self.symbol_table2[ch.ID(i).getText().lower()] + ".mass") + self.setValue(ctx, "+".join(l)) + except Exception: + for i in self.bodies.keys(): + l.append(self.bodies[i] + ".mass") + self.setValue(ctx, "+".join(l)) + + # Fr() FrStar() + # _me.KanesMethod(n, q_ind, u_ind, kd, velocity_constraints).kanes_equations(pl, fl)[0] + elif func_name in ["fr", "frstar"]: + if not self.kane_parsed: + if self.kd_eqs: + for i in self.kd_eqs: + self.q_ind.append(self.symbol_table[i.strip().split('-')[0].replace("'","")]) + self.u_ind.append(self.symbol_table[i.strip().split('-')[1].replace("'","")]) + + for i in range(len(self.kd_eqs)): + self.kd_eqs[i] = self.symbol_table[self.kd_eqs[i].strip().split('-')[0]] + " - " +\ + self.symbol_table[self.kd_eqs[i].strip().split('-')[1]] + + # Do all of this if kd_eqs are not specified + if not self.kd_eqs: + self.kd_eqs_supplied = False + self.matrix_expr.append(ctx) + for i in self.type.keys(): + if self.type[i] == "motionvariable": + if self.sign[self.symbol_table[i.lower()]] == 0: + self.q_ind.append(self.symbol_table[i.lower()]) + elif self.sign[self.symbol_table[i.lower()]] == 1: + name = "u_" + self.symbol_table[i.lower()] + self.symbol_table.update({name: name}) + self.write(name + " = " + "_me.dynamicsymbols('" + name + "')\n") + if self.symbol_table[i.lower()] not in self.dependent_variables: + self.u_ind.append(name) + self.kd_equivalents.update({name: self.symbol_table[i.lower()]}) + else: + self.u_dep.append(name) + self.kd_equivalents.update({name: self.symbol_table[i.lower()]}) + + for i in self.kd_equivalents.keys(): + self.kd_eqs.append(self.kd_equivalents[i] + "-" + i) + + if not self.u_ind and not self.kd_eqs: + self.u_ind = self.q_ind.copy() + self.q_ind = [] + + # deal with velocity constraints + if self.dependent_variables: + for i in self.dependent_variables: + self.u_dep.append(i) + if i in self.u_ind: + self.u_ind.remove(i) + + + self.u_dep[:] = [i for i in self.u_dep if i not in self.kd_equivalents.values()] + + force_list = [] + for i in self.forces.keys(): + force_list.append("(" + i + "," + self.forces[i] + ")") + if self.u_dep: + u_dep_text = ", u_dependent=[" + ", ".join(self.u_dep) + "]" + else: + u_dep_text = "" + if self.dependent_variables: + velocity_constraints_text = ", velocity_constraints = velocity_constraints" + else: + velocity_constraints_text = "" + if ctx.parentCtx not in self.fr_expr: + self.write("kd_eqs = [" + ", ".join(self.kd_eqs) + "]\n") + self.write("forceList = " + "[" + ", ".join(force_list) + "]\n") + self.write("kane = _me.KanesMethod(" + self.newtonian + ", " + "q_ind=[" + + ",".join(self.q_ind) + "], " + "u_ind=[" + + ", ".join(self.u_ind) + "]" + u_dep_text + ", " + + "kd_eqs = kd_eqs" + velocity_constraints_text + ")\n") + self.write("fr, frstar = kane." + "kanes_equations([" + + ", ".join(self.bodies.values()) + "], forceList)\n") + self.fr_expr.append(ctx.parentCtx) + self.kane_parsed = True + self.setValue(ctx, func_name) + + def exitMatrices(self, ctx): + # Tree annotation for Matrices which is a labeled subrule of the parser rule expr. + + # MO = [a, b; c, d] + # we generate _sm.Matrix([a, b, c, d]).reshape(2, 2) + # The reshape values are determined by counting the "," and ";" in the Autolev matrix + + # Eg: + # [1, 2, 3; 4, 5, 6; 7, 8, 9; 10, 11, 12] + # semicolon_count = 3 and rows = 3+1 = 4 + # comma_count = 8 and cols = 8/rows + 1 = 8/4 + 1 = 3 + + # TODO** Parse block matrices + self.matrix_expr.append(ctx) + l = [] + semicolon_count = 0 + comma_count = 0 + for i in range(ctx.matrix().getChildCount()): + child = ctx.matrix().getChild(i) + if child == AutolevParser.ExprContext: + l.append(self.getValue(child)) + elif child.getText() == ";": + semicolon_count += 1 + l.append(",") + elif child.getText() == ",": + comma_count += 1 + l.append(",") + else: + try: + try: + l.append(self.getValue(child)) + except Exception: + l.append(self.symbol_table[child.getText().lower()]) + except Exception: + l.append(child.getText().lower()) + num_of_rows = semicolon_count + 1 + num_of_cols = (comma_count//num_of_rows) + 1 + + self.setValue(ctx, "_sm.Matrix(" + "".join(l) + ")" + ".reshape(" + + str(num_of_rows) + ", " + str(num_of_cols) + ")") + + def exitVectorOrDyadic(self, ctx): + self.vector_expr.append(ctx) + ch = ctx.vec() + + if ch.getChild(0).getText() == "0>": + self.setValue(ctx, "0") + + elif ch.getChild(0).getText() == "1>>": + self.setValue(ctx, "1>>") + + elif "_" in ch.ID().getText() and ch.ID().getText().count('_') == 2: + vec_text = ch.getText().lower() + v1, v2, v3 = ch.ID().getText().lower().split('_') + + if v1 == "p": + if self.type2[v2] == "point": + e2 = self.symbol_table2[v2] + elif self.type2[v2] == "particle": + e2 = self.symbol_table2[v2] + ".point" + if self.type2[v3] == "point": + e3 = self.symbol_table2[v3] + elif self.type2[v3] == "particle": + e3 = self.symbol_table2[v3] + ".point" + get_vec = e3 + ".pos_from(" + e2 + ")" + self.setValue(ctx, get_vec) + + elif v1 in ("w", "alf"): + if v1 == "w": + text = ".ang_vel_in(" + elif v1 == "alf": + text = ".ang_acc_in(" + if self.type2[v2] == "bodies": + e2 = self.symbol_table2[v2] + "_f" + elif self.type2[v2] == "frame": + e2 = self.symbol_table2[v2] + if self.type2[v3] == "bodies": + e3 = self.symbol_table2[v3] + "_f" + elif self.type2[v3] == "frame": + e3 = self.symbol_table2[v3] + get_vec = e2 + text + e3 + ")" + self.setValue(ctx, get_vec) + + elif v1 in ("v", "a"): + if v1 == "v": + text = ".vel(" + elif v1 == "a": + text = ".acc(" + if self.type2[v2] == "point": + e2 = self.symbol_table2[v2] + elif self.type2[v2] == "particle": + e2 = self.symbol_table2[v2] + ".point" + get_vec = e2 + text + self.symbol_table2[v3] + ")" + self.setValue(ctx, get_vec) + + else: + self.setValue(ctx, vec_text.replace(">", "")) + + else: + vec_text = ch.getText().lower() + name = self.symbol_table[vec_text] + self.setValue(ctx, name) + + def exitIndexing(self, ctx): + if ctx.getChildCount() == 4: + try: + int_text = str(int(self.getValue(ctx.getChild(2))) - 1) + except Exception: + int_text = self.getValue(ctx.getChild(2)) + " - 1" + self.setValue(ctx, ctx.ID().getText().lower() + "[" + int_text + "]") + elif ctx.getChildCount() == 6: + try: + int_text1 = str(int(self.getValue(ctx.getChild(2))) - 1) + except Exception: + int_text1 = self.getValue(ctx.getChild(2)) + " - 1" + try: + int_text2 = str(int(self.getValue(ctx.getChild(4))) - 1) + except Exception: + int_text2 = self.getValue(ctx.getChild(2)) + " - 1" + self.setValue(ctx, ctx.ID().getText().lower() + "[" + int_text1 + ", " + int_text2 + "]") + + + # ================== Subrules of parser rule expr (End) ====================== # + + def exitRegularAssign(self, ctx): + # Handle assignments of type ID = expr + if ctx.equals().getText() in ["=", "+=", "-=", "*=", "/="]: + equals = ctx.equals().getText() + elif ctx.equals().getText() == ":=": + equals = " = " + elif ctx.equals().getText() == "^=": + equals = "**=" + + try: + a = ctx.ID().getText().lower() + "'"*ctx.diff().getText().count("'") + except Exception: + a = ctx.ID().getText().lower() + + if a in self.type.keys() and self.type[a] in ("motionvariable", "motionvariable'") and\ + self.type[ctx.expr().getText().lower()] in ("motionvariable", "motionvariable'"): + b = ctx.expr().getText().lower() + if "'" in b and "'" not in a: + a, b = b, a + if not self.kane_parsed: + self.kd_eqs.append(a + "-" + b) + self.kd_equivalents.update({self.symbol_table[a]: + self.symbol_table[b]}) + self.kd_equivalents2.update({self.symbol_table[b]: + self.symbol_table[a]}) + + if a in self.symbol_table.keys() and a in self.type.keys() and self.type[a] in ("variable", "motionvariable"): + self.explicit.update({self.symbol_table[a]: self.getValue(ctx.expr())}) + + else: + if ctx.expr() in self.matrix_expr: + self.type.update({a: "matrix"}) + + try: + b = self.symbol_table[a] + except KeyError: + self.symbol_table[a] = a + + if "_" in a and a.count("_") == 1: + e1, e2 = a.split('_') + if e1 in self.type2.keys() and self.type2[e1] in ("frame", "bodies")\ + and e2 in self.type2.keys() and self.type2[e2] in ("frame", "bodies"): + if self.type2[e1] == "bodies": + t1 = "_f" + else: + t1 = "" + if self.type2[e2] == "bodies": + t2 = "_f" + else: + t2 = "" + + self.write(self.symbol_table2[e2] + t2 + ".orient(" + self.symbol_table2[e1] + + t1 + ", 'DCM', " + self.getValue(ctx.expr()) + ")\n") + else: + self.write(self.symbol_table[a] + " " + equals + " " + + self.getValue(ctx.expr()) + "\n") + else: + self.write(self.symbol_table[a] + " " + equals + " " + + self.getValue(ctx.expr()) + "\n") + + def exitIndexAssign(self, ctx): + # Handle assignments of type ID[index] = expr + if ctx.equals().getText() in ["=", "+=", "-=", "*=", "/="]: + equals = ctx.equals().getText() + elif ctx.equals().getText() == ":=": + equals = " = " + elif ctx.equals().getText() == "^=": + equals = "**=" + + text = ctx.ID().getText().lower() + self.type.update({text: "matrix"}) + # Handle assignments of type ID[2] = expr + if ctx.index().getChildCount() == 1: + if ctx.index().getChild(0).getText() == "1": + self.type.update({text: "matrix"}) + self.symbol_table.update({text: text}) + self.write(text + " = " + "_sm.Matrix([[0]])\n") + self.write(text + "[0] = " + self.getValue(ctx.expr()) + "\n") + else: + # m = m.row_insert(m.shape[0], _sm.Matrix([[0]])) + self.write(text + " = " + text + + ".row_insert(" + text + ".shape[0]" + ", " + "_sm.Matrix([[0]])" + ")\n") + self.write(text + "[" + text + ".shape[0]-1" + "] = " + self.getValue(ctx.expr()) + "\n") + + # Handle assignments of type ID[2, 2] = expr + elif ctx.index().getChildCount() == 3: + l = [] + try: + l.append(str(int(self.getValue(ctx.index().getChild(0)))-1)) + except Exception: + l.append(self.getValue(ctx.index().getChild(0)) + "-1") + l.append(",") + try: + l.append(str(int(self.getValue(ctx.index().getChild(2)))-1)) + except Exception: + l.append(self.getValue(ctx.index().getChild(2)) + "-1") + self.write(self.symbol_table[ctx.ID().getText().lower()] + + "[" + "".join(l) + "]" + " " + equals + " " + self.getValue(ctx.expr()) + "\n") + + def exitVecAssign(self, ctx): + # Handle assignments of the type vec = expr + ch = ctx.vec() + vec_text = ch.getText().lower() + + if "_" in ch.ID().getText(): + num = ch.ID().getText().count('_') + + if num == 2: + v1, v2, v3 = ch.ID().getText().lower().split('_') + + if v1 == "p": + if self.type2[v2] == "point": + e2 = self.symbol_table2[v2] + elif self.type2[v2] == "particle": + e2 = self.symbol_table2[v2] + ".point" + if self.type2[v3] == "point": + e3 = self.symbol_table2[v3] + elif self.type2[v3] == "particle": + e3 = self.symbol_table2[v3] + ".point" + # ab.set_pos(na, la*a.x) + self.write(e3 + ".set_pos(" + e2 + ", " + self.getValue(ctx.expr()) + ")\n") + + elif v1 in ("w", "alf"): + if v1 == "w": + text = ".set_ang_vel(" + elif v1 == "alf": + text = ".set_ang_acc(" + # a.set_ang_vel(n, qad*a.z) + if self.type2[v2] == "bodies": + e2 = self.symbol_table2[v2] + "_f" + else: + e2 = self.symbol_table2[v2] + if self.type2[v3] == "bodies": + e3 = self.symbol_table2[v3] + "_f" + else: + e3 = self.symbol_table2[v3] + self.write(e2 + text + e3 + ", " + self.getValue(ctx.expr()) + ")\n") + + elif v1 in ("v", "a"): + if v1 == "v": + text = ".set_vel(" + elif v1 == "a": + text = ".set_acc(" + if self.type2[v2] == "point": + e2 = self.symbol_table2[v2] + elif self.type2[v2] == "particle": + e2 = self.symbol_table2[v2] + ".point" + self.write(e2 + text + self.symbol_table2[v3] + + ", " + self.getValue(ctx.expr()) + ")\n") + elif v1 == "i": + if v2 in self.type2.keys() and self.type2[v2] == "bodies": + self.write(self.symbol_table2[v2] + ".inertia = (" + self.getValue(ctx.expr()) + + ", " + self.symbol_table2[v3] + ")\n") + self.inertia_point.update({v2: v3}) + elif v2 in self.type2.keys() and self.type2[v2] == "particle": + self.write(ch.ID().getText().lower() + " = " + self.getValue(ctx.expr()) + "\n") + else: + self.write(ch.ID().getText().lower() + " = " + self.getValue(ctx.expr()) + "\n") + else: + self.write(ch.ID().getText().lower() + " = " + self.getValue(ctx.expr()) + "\n") + + elif num == 1: + v1, v2 = ch.ID().getText().lower().split('_') + + if v1 in ("force", "torque"): + if self.type2[v2] in ("point", "frame"): + e2 = self.symbol_table2[v2] + elif self.type2[v2] == "particle": + e2 = self.symbol_table2[v2] + ".point" + self.symbol_table.update({vec_text: ch.ID().getText().lower()}) + + if e2 in self.forces.keys(): + self.forces[e2] = self.forces[e2] + " + " + self.getValue(ctx.expr()) + else: + self.forces.update({e2: self.getValue(ctx.expr())}) + self.write(ch.ID().getText().lower() + " = " + self.forces[e2] + "\n") + + else: + name = ch.ID().getText().lower() + self.symbol_table.update({vec_text: name}) + self.write(ch.ID().getText().lower() + " = " + self.getValue(ctx.expr()) + "\n") + else: + name = ch.ID().getText().lower() + self.symbol_table.update({vec_text: name}) + self.write(name + " " + ctx.getChild(1).getText() + " " + self.getValue(ctx.expr()) + "\n") + else: + name = ch.ID().getText().lower() + self.symbol_table.update({vec_text: name}) + self.write(name + " " + ctx.getChild(1).getText() + " " + self.getValue(ctx.expr()) + "\n") + + def enterInputs2(self, ctx): + self.in_inputs = True + + # Inputs + def exitInputs2(self, ctx): + # Stores numerical values given by the input command which + # are used for codegen and numerical analysis. + if ctx.getChildCount() == 3: + try: + self.inputs.update({self.symbol_table[ctx.id_diff().getText().lower()]: self.getValue(ctx.expr(0))}) + except Exception: + self.inputs.update({ctx.id_diff().getText().lower(): self.getValue(ctx.expr(0))}) + elif ctx.getChildCount() == 4: + try: + self.inputs.update({self.symbol_table[ctx.id_diff().getText().lower()]: + (self.getValue(ctx.expr(0)), self.getValue(ctx.expr(1)))}) + except Exception: + self.inputs.update({ctx.id_diff().getText().lower(): + (self.getValue(ctx.expr(0)), self.getValue(ctx.expr(1)))}) + + self.in_inputs = False + + def enterOutputs(self, ctx): + self.in_outputs = True + def exitOutputs(self, ctx): + self.in_outputs = False + + def exitOutputs2(self, ctx): + try: + if "[" in ctx.expr(1).getText(): + self.outputs.append(self.symbol_table[ctx.expr(0).getText().lower()] + + ctx.expr(1).getText().lower()) + else: + self.outputs.append(self.symbol_table[ctx.expr(0).getText().lower()]) + + except Exception: + pass + + # Code commands + def exitCodegen(self, ctx): + # Handles the CODE() command ie the solvers and the codgen part. + # Uses linsolve for the algebraic solvers and nsolve for non linear solvers. + + if ctx.functionCall().getChild(0).getText().lower() == "algebraic": + matrix_name = self.getValue(ctx.functionCall().expr(0)) + e = [] + d = [] + for i in range(1, (ctx.functionCall().getChildCount()-2)//2): + a = self.getValue(ctx.functionCall().expr(i)) + e.append(a) + + for i in self.inputs.keys(): + d.append(i + ":" + self.inputs[i]) + self.write(matrix_name + "_list" + " = " + "[]\n") + self.write("for i in " + matrix_name + ": " + matrix_name + + "_list" + ".append(i.subs({" + ", ".join(d) + "}))\n") + self.write("print(_sm.linsolve(" + matrix_name + "_list" + ", " + ",".join(e) + "))\n") + + elif ctx.functionCall().getChild(0).getText().lower() == "nonlinear": + e = [] + d = [] + guess = [] + for i in range(1, (ctx.functionCall().getChildCount()-2)//2): + a = self.getValue(ctx.functionCall().expr(i)) + e.append(a) + #print(self.inputs) + for i in self.inputs.keys(): + if i in self.symbol_table.keys(): + if type(self.inputs[i]) is tuple: + j, z = self.inputs[i] + else: + j = self.inputs[i] + z = "" + if i not in e: + if z == "deg": + d.append(i + ":" + "_np.deg2rad(" + j + ")") + else: + d.append(i + ":" + j) + else: + if z == "deg": + guess.append("_np.deg2rad(" + j + ")") + else: + guess.append(j) + + self.write("matrix_list" + " = " + "[]\n") + self.write("for i in " + self.getValue(ctx.functionCall().expr(0)) + ":") + self.write("matrix_list" + ".append(i.subs({" + ", ".join(d) + "}))\n") + self.write("print(_sm.nsolve(matrix_list," + "(" + ",".join(e) + ")" + + ",(" + ",".join(guess) + ")" + "))\n") + + elif ctx.functionCall().getChild(0).getText().lower() in ["ode", "dynamics"] and self.include_numeric: + if self.kane_type == "no_args": + for i in self.symbol_table.keys(): + try: + if self.type[i] == "constants" or self.type[self.symbol_table[i]] == "constants": + self.constants.append(self.symbol_table[i]) + except Exception: + pass + q_add_u = self.q_ind + self.q_dep + self.u_ind + self.u_dep + x0 = [] + for i in q_add_u: + try: + if i in self.inputs.keys(): + if type(self.inputs[i]) is tuple: + if self.inputs[i][1] == "deg": + x0.append(i + ":" + "_np.deg2rad(" + self.inputs[i][0] + ")") + else: + x0.append(i + ":" + self.inputs[i][0]) + else: + x0.append(i + ":" + self.inputs[i]) + elif self.kd_equivalents[i] in self.inputs.keys(): + if type(self.inputs[self.kd_equivalents[i]]) is tuple: + x0.append(i + ":" + self.inputs[self.kd_equivalents[i]][0]) + else: + x0.append(i + ":" + self.inputs[self.kd_equivalents[i]]) + except Exception: + pass + + # numerical constants + numerical_constants = [] + for i in self.constants: + if i in self.inputs.keys(): + if type(self.inputs[i]) is tuple: + numerical_constants.append(self.inputs[i][0]) + else: + numerical_constants.append(self.inputs[i]) + + # t = linspace + t_final = self.inputs["tfinal"] + integ_stp = self.inputs["integstp"] + + self.write("from pydy.system import System\n") + const_list = [] + if numerical_constants: + for i in range(len(self.constants)): + const_list.append(self.constants[i] + ":" + numerical_constants[i]) + specifieds = [] + if self.t: + specifieds.append("_me.dynamicsymbols('t')" + ":" + "lambda x, t: t") + + for i in self.inputs: + if i in self.symbol_table.keys() and self.symbol_table[i] not in\ + self.constants + self.q_ind + self.q_dep + self.u_ind + self.u_dep: + specifieds.append(self.symbol_table[i] + ":" + self.inputs[i]) + + self.write("sys = System(kane, constants = {" + ", ".join(const_list) + "},\n" + + "specifieds={" + ", ".join(specifieds) + "},\n" + + "initial_conditions={" + ", ".join(x0) + "},\n" + + "times = _np.linspace(0.0, " + str(t_final) + ", " + str(t_final) + + "/" + str(integ_stp) + "))\n\ny=sys.integrate()\n") + + # For outputs other than qs and us. + other_outputs = [] + for i in self.outputs: + if i not in q_add_u: + if "[" in i: + other_outputs.append((i[:-3] + i[-2], i[:-3] + "[" + str(int(i[-2])-1) + "]")) + else: + other_outputs.append((i, i)) + + for i in other_outputs: + self.write(i[0] + "_out" + " = " + "[]\n") + if other_outputs: + self.write("for i in y:\n") + self.write(" q_u_dict = dict(zip(sys.coordinates+sys.speeds, i))\n") + for i in other_outputs: + self.write(" "*4 + i[0] + "_out" + ".append(" + i[1] + ".subs(q_u_dict)" + + ".subs(sys.constants).evalf())\n") + + # Standalone function calls (used for dual functions) + def exitFunctionCall(self, ctx): + # Basically deals with standalone function calls ie functions which are not a part of + # expressions and assignments. Autolev Dual functions can both appear in standalone + # function calls and also on the right hand side as part of expr or assignment. + + # Dual functions are indicated by a * in the comments below + + # Checks if the function is a statement on its own + if ctx.parentCtx.getRuleIndex() == AutolevParser.RULE_stat: + func_name = ctx.getChild(0).getText().lower() + # Expand(E, n:m) * + if func_name == "expand": + # If the first argument is a pre declared variable. + expr = self.getValue(ctx.expr(0)) + symbol = self.symbol_table[ctx.expr(0).getText().lower()] + if ctx.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.write(symbol + " = " + "_sm.Matrix([i.expand() for i in " + expr + "])" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])\n") + else: + self.write(symbol + " = " + symbol + "." + "expand()\n") + + # Factor(E, x) * + elif func_name == "factor": + expr = self.getValue(ctx.expr(0)) + symbol = self.symbol_table[ctx.expr(0).getText().lower()] + if ctx.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.write(symbol + " = " + "_sm.Matrix([_sm.factor(i," + self.getValue(ctx.expr(1)) + + ") for i in " + expr + "])" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])\n") + else: + self.write(expr + " = " + "_sm.factor(" + expr + ", " + + self.getValue(ctx.expr(1)) + ")\n") + + # Solve(Zero, x, y) + elif func_name == "solve": + l = [] + l2 = [] + num = 0 + for i in range(1, ctx.getChildCount()): + if ctx.getChild(i).getText() == ",": + num+=1 + try: + l.append(self.getValue(ctx.getChild(i))) + except Exception: + l.append(ctx.getChild(i).getText()) + + if i != 2: + try: + l2.append(self.getValue(ctx.getChild(i))) + except Exception: + pass + + for i in l2: + self.explicit.update({i: "_sm.solve" + "".join(l) + "[" + i + "]"}) + + self.write("print(_sm.solve" + "".join(l) + ")\n") + + # Arrange(y, n, x) * + elif func_name == "arrange": + expr = self.getValue(ctx.expr(0)) + symbol = self.symbol_table[ctx.expr(0).getText().lower()] + + if ctx.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.write(symbol + " = " + "_sm.Matrix([i.collect(" + self.getValue(ctx.expr(2)) + + ")" + "for i in " + expr + "])" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])\n") + else: + self.write(self.getValue(ctx.expr(0)) + ".collect(" + + self.getValue(ctx.expr(2)) + ")\n") + + # Eig(M, EigenValue, EigenVec) + elif func_name == "eig": + self.symbol_table.update({ctx.expr(1).getText().lower(): ctx.expr(1).getText().lower()}) + self.symbol_table.update({ctx.expr(2).getText().lower(): ctx.expr(2).getText().lower()}) + # _sm.Matrix([i.evalf() for i in (i_s_so).eigenvals().keys()]) + self.write(ctx.expr(1).getText().lower() + " = " + + "_sm.Matrix([i.evalf() for i in " + + "(" + self.getValue(ctx.expr(0)) + ")" + ".eigenvals().keys()])\n") + # _sm.Matrix([i[2][0].evalf() for i in (i_s_o).eigenvects()]).reshape(i_s_o.shape[0], i_s_o.shape[1]) + self.write(ctx.expr(2).getText().lower() + " = " + + "_sm.Matrix([i[2][0].evalf() for i in " + "(" + self.getValue(ctx.expr(0)) + ")" + + ".eigenvects()]).reshape(" + self.getValue(ctx.expr(0)) + ".shape[0], " + + self.getValue(ctx.expr(0)) + ".shape[1])\n") + + # Simprot(N, A, 3, qA) + elif func_name == "simprot": + # A.orient(N, 'Axis', qA, N.z) + if self.type2[ctx.expr(0).getText().lower()] == "frame": + frame1 = self.symbol_table2[ctx.expr(0).getText().lower()] + elif self.type2[ctx.expr(0).getText().lower()] == "bodies": + frame1 = self.symbol_table2[ctx.expr(0).getText().lower()] + "_f" + if self.type2[ctx.expr(1).getText().lower()] == "frame": + frame2 = self.symbol_table2[ctx.expr(1).getText().lower()] + elif self.type2[ctx.expr(1).getText().lower()] == "bodies": + frame2 = self.symbol_table2[ctx.expr(1).getText().lower()] + "_f" + e2 = "" + if ctx.expr(2).getText()[0] == "-": + e2 = "-1*" + if ctx.expr(2).getText() in ("1", "-1"): + e = frame1 + ".x" + elif ctx.expr(2).getText() in ("2", "-2"): + e = frame1 + ".y" + elif ctx.expr(2).getText() in ("3", "-3"): + e = frame1 + ".z" + else: + e = self.getValue(ctx.expr(2)) + e2 = "" + + if "degrees" in self.settings.keys() and self.settings["degrees"] == "off": + value = self.getValue(ctx.expr(3)) + else: + if ctx.expr(3) in self.numeric_expr: + value = "_np.deg2rad(" + self.getValue(ctx.expr(3)) + ")" + else: + value = self.getValue(ctx.expr(3)) + self.write(frame2 + ".orient(" + frame1 + + ", " + "'Axis'" + ", " + "[" + value + + ", " + e2 + e + "]" + ")\n") + + # Express(A2>, B) * + elif func_name == "express": + if self.type2[ctx.expr(1).getText().lower()] == "bodies": + f = "_f" + else: + f = "" + + if '_' in ctx.expr(0).getText().lower() and ctx.expr(0).getText().count('_') == 2: + vec = ctx.expr(0).getText().lower().replace(">", "").split('_') + v1 = self.symbol_table2[vec[1]] + v2 = self.symbol_table2[vec[2]] + if vec[0] == "p": + self.write(v2 + ".set_pos(" + v1 + ", " + "(" + self.getValue(ctx.expr(0)) + + ")" + ".express(" + self.symbol_table2[ctx.expr(1).getText().lower()] + f + "))\n") + elif vec[0] == "v": + self.write(v1 + ".set_vel(" + v2 + ", " + "(" + self.getValue(ctx.expr(0)) + + ")" + ".express(" + self.symbol_table2[ctx.expr(1).getText().lower()] + f + "))\n") + elif vec[0] == "a": + self.write(v1 + ".set_acc(" + v2 + ", " + "(" + self.getValue(ctx.expr(0)) + + ")" + ".express(" + self.symbol_table2[ctx.expr(1).getText().lower()] + f + "))\n") + else: + self.write(self.getValue(ctx.expr(0)) + " = " + "(" + self.getValue(ctx.expr(0)) + ")" + ".express(" + + self.symbol_table2[ctx.expr(1).getText().lower()] + f + ")\n") + else: + self.write(self.getValue(ctx.expr(0)) + " = " + "(" + self.getValue(ctx.expr(0)) + ")" + ".express(" + + self.symbol_table2[ctx.expr(1).getText().lower()] + f + ")\n") + + # Angvel(A, B) + elif func_name == "angvel": + self.write("print(" + self.symbol_table2[ctx.expr(1).getText().lower()] + + ".ang_vel_in(" + self.symbol_table2[ctx.expr(0).getText().lower()] + "))\n") + + # v2pts(N, A, O, P) + elif func_name in ("v2pts", "a2pts", "v2pt", "a1pt"): + if func_name == "v2pts": + text = ".v2pt_theory(" + elif func_name == "a2pts": + text = ".a2pt_theory(" + elif func_name == "v1pt": + text = ".v1pt_theory(" + elif func_name == "a1pt": + text = ".a1pt_theory(" + if self.type2[ctx.expr(1).getText().lower()] == "frame": + frame = self.symbol_table2[ctx.expr(1).getText().lower()] + elif self.type2[ctx.expr(1).getText().lower()] == "bodies": + frame = self.symbol_table2[ctx.expr(1).getText().lower()] + "_f" + expr_list = [] + for i in range(2, 4): + if self.type2[ctx.expr(i).getText().lower()] == "point": + expr_list.append(self.symbol_table2[ctx.expr(i).getText().lower()]) + elif self.type2[ctx.expr(i).getText().lower()] == "particle": + expr_list.append(self.symbol_table2[ctx.expr(i).getText().lower()] + ".point") + + self.write(expr_list[1] + text + expr_list[0] + + "," + self.symbol_table2[ctx.expr(0).getText().lower()] + "," + + frame + ")\n") + + # Gravity(g*N1>) + elif func_name == "gravity": + for i in self.bodies.keys(): + if self.type2[i] == "bodies": + e = self.symbol_table2[i] + ".masscenter" + elif self.type2[i] == "particle": + e = self.symbol_table2[i] + ".point" + if e in self.forces.keys(): + self.forces[e] = self.forces[e] + self.symbol_table2[i] +\ + ".mass*(" + self.getValue(ctx.expr(0)) + ")" + else: + self.forces.update({e: self.symbol_table2[i] + + ".mass*(" + self.getValue(ctx.expr(0)) + ")"}) + self.write("force_" + i + " = " + self.forces[e] + "\n") + + # Explicit(EXPRESS(IMPLICIT>,C)) + elif func_name == "explicit": + if ctx.expr(0) in self.vector_expr: + self.vector_expr.append(ctx) + expr = self.getValue(ctx.expr(0)) + if self.explicit.keys(): + explicit_list = [] + for i in self.explicit.keys(): + explicit_list.append(i + ":" + self.explicit[i]) + if '_' in ctx.expr(0).getText().lower() and ctx.expr(0).getText().count('_') == 2: + vec = ctx.expr(0).getText().lower().replace(">", "").split('_') + v1 = self.symbol_table2[vec[1]] + v2 = self.symbol_table2[vec[2]] + if vec[0] == "p": + self.write(v2 + ".set_pos(" + v1 + ", " + "(" + expr + + ")" + ".subs({" + ", ".join(explicit_list) + "}))\n") + elif vec[0] == "v": + self.write(v2 + ".set_vel(" + v1 + ", " + "(" + expr + + ")" + ".subs({" + ", ".join(explicit_list) + "}))\n") + elif vec[0] == "a": + self.write(v2 + ".set_acc(" + v1 + ", " + "(" + expr + + ")" + ".subs({" + ", ".join(explicit_list) + "}))\n") + else: + self.write(expr + " = " + "(" + expr + ")" + ".subs({" + ", ".join(explicit_list) + "})\n") + else: + self.write(expr + " = " + "(" + expr + ")" + ".subs({" + ", ".join(explicit_list) + "})\n") + + # Force(O/Q, -k*Stretch*Uvec>) + elif func_name in ("force", "torque"): + + if "/" in ctx.expr(0).getText().lower(): + p1 = ctx.expr(0).getText().lower().split('/')[0] + p2 = ctx.expr(0).getText().lower().split('/')[1] + if self.type2[p1] in ("point", "frame"): + pt1 = self.symbol_table2[p1] + elif self.type2[p1] == "particle": + pt1 = self.symbol_table2[p1] + ".point" + if self.type2[p2] in ("point", "frame"): + pt2 = self.symbol_table2[p2] + elif self.type2[p2] == "particle": + pt2 = self.symbol_table2[p2] + ".point" + if pt1 in self.forces.keys(): + self.forces[pt1] = self.forces[pt1] + " + -1*("+self.getValue(ctx.expr(1)) + ")" + self.write("force_" + p1 + " = " + self.forces[pt1] + "\n") + else: + self.forces.update({pt1: "-1*("+self.getValue(ctx.expr(1)) + ")"}) + self.write("force_" + p1 + " = " + self.forces[pt1] + "\n") + if pt2 in self.forces.keys(): + self.forces[pt2] = self.forces[pt2] + "+ " + self.getValue(ctx.expr(1)) + self.write("force_" + p2 + " = " + self.forces[pt2] + "\n") + else: + self.forces.update({pt2: self.getValue(ctx.expr(1))}) + self.write("force_" + p2 + " = " + self.forces[pt2] + "\n") + + elif ctx.expr(0).getChildCount() == 1: + p1 = ctx.expr(0).getText().lower() + if self.type2[p1] in ("point", "frame"): + pt1 = self.symbol_table2[p1] + elif self.type2[p1] == "particle": + pt1 = self.symbol_table2[p1] + ".point" + if pt1 in self.forces.keys(): + self.forces[pt1] = self.forces[pt1] + "+ -1*(" + self.getValue(ctx.expr(1)) + ")" + else: + self.forces.update({pt1: "-1*(" + self.getValue(ctx.expr(1)) + ")"}) + + # Constrain(Dependent[qB]) + elif func_name == "constrain": + if ctx.getChild(2).getChild(0).getText().lower() == "dependent": + self.write("velocity_constraints = [i for i in dependent]\n") + x = (ctx.expr(0).getChildCount()-2)//2 + for i in range(x): + self.dependent_variables.append(self.getValue(ctx.expr(0).expr(i))) + + # Kane() + elif func_name == "kane": + if ctx.getChildCount() == 3: + self.kane_type = "no_args" + + # Settings + def exitSettings(self, ctx): + # Stores settings like Complex on/off, Degrees on/off etc in self.settings. + try: + self.settings.update({ctx.getChild(0).getText().lower(): + ctx.getChild(1).getText().lower()}) + except Exception: + pass + + def exitMassDecl2(self, ctx): + # Used for declaring the masses of particles and rigidbodies. + particle = self.symbol_table2[ctx.getChild(0).getText().lower()] + if ctx.getText().count("=") == 2: + if ctx.expr().expr(1) in self.numeric_expr: + e = "_sm.S(" + self.getValue(ctx.expr().expr(1)) + ")" + else: + e = self.getValue(ctx.expr().expr(1)) + self.symbol_table.update({ctx.expr().expr(0).getText().lower(): ctx.expr().expr(0).getText().lower()}) + self.write(ctx.expr().expr(0).getText().lower() + " = " + e + "\n") + mass = ctx.expr().expr(0).getText().lower() + else: + try: + if ctx.expr() in self.numeric_expr: + mass = "_sm.S(" + self.getValue(ctx.expr()) + ")" + else: + mass = self.getValue(ctx.expr()) + except Exception: + a_text = ctx.expr().getText().lower() + self.symbol_table.update({a_text: a_text}) + self.type.update({a_text: "constants"}) + self.write(a_text + " = " + "_sm.symbols('" + a_text + "')\n") + mass = a_text + + self.write(particle + ".mass = " + mass + "\n") + + def exitInertiaDecl(self, ctx): + inertia_list = [] + try: + ctx.ID(1).getText() + num = 5 + except Exception: + num = 2 + for i in range((ctx.getChildCount()-num)//2): + try: + if ctx.expr(i) in self.numeric_expr: + inertia_list.append("_sm.S(" + self.getValue(ctx.expr(i)) + ")") + else: + inertia_list.append(self.getValue(ctx.expr(i))) + except Exception: + a_text = ctx.expr(i).getText().lower() + self.symbol_table.update({a_text: a_text}) + self.type.update({a_text: "constants"}) + self.write(a_text + " = " + "_sm.symbols('" + a_text + "')\n") + inertia_list.append(a_text) + + if len(inertia_list) < 6: + for i in range(6-len(inertia_list)): + inertia_list.append("0") + # body_a.inertia = (_me.inertia(body_a, I1, I2, I3, 0, 0, 0), body_a_cm) + try: + frame = self.symbol_table2[ctx.ID(1).getText().lower()] + point = self.symbol_table2[ctx.ID(0).getText().lower().split('_')[1]] + body = self.symbol_table2[ctx.ID(0).getText().lower().split('_')[0]] + self.inertia_point.update({ctx.ID(0).getText().lower().split('_')[0] + : ctx.ID(0).getText().lower().split('_')[1]}) + self.write(body + ".inertia" + " = " + "(_me.inertia(" + frame + ", " + + ", ".join(inertia_list) + "), " + point + ")\n") + + except Exception: + body_name = self.symbol_table2[ctx.ID(0).getText().lower()] + body_name_cm = body_name + "_cm" + self.inertia_point.update({ctx.ID(0).getText().lower(): ctx.ID(0).getText().lower() + "o"}) + self.write(body_name + ".inertia" + " = " + "(_me.inertia(" + body_name + "_f" + ", " + + ", ".join(inertia_list) + "), " + body_name_cm + ")\n") diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/_parse_autolev_antlr.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/_parse_autolev_antlr.py new file mode 100644 index 0000000000000000000000000000000000000000..e43924aac30903ade996b31921d3960afae90284 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/_parse_autolev_antlr.py @@ -0,0 +1,38 @@ +from importlib.metadata import version +from sympy.external import import_module + + +autolevparser = import_module('sympy.parsing.autolev._antlr.autolevparser', + import_kwargs={'fromlist': ['AutolevParser']}) +autolevlexer = import_module('sympy.parsing.autolev._antlr.autolevlexer', + import_kwargs={'fromlist': ['AutolevLexer']}) +autolevlistener = import_module('sympy.parsing.autolev._antlr.autolevlistener', + import_kwargs={'fromlist': ['AutolevListener']}) + +AutolevParser = getattr(autolevparser, 'AutolevParser', None) +AutolevLexer = getattr(autolevlexer, 'AutolevLexer', None) +AutolevListener = getattr(autolevlistener, 'AutolevListener', None) + + +def parse_autolev(autolev_code, include_numeric): + antlr4 = import_module('antlr4') + if not antlr4 or not version('antlr4-python3-runtime').startswith('4.11'): + raise ImportError("Autolev parsing requires the antlr4 Python package," + " provided by pip (antlr4-python3-runtime)" + " conda (antlr-python-runtime), version 4.11") + try: + l = autolev_code.readlines() + input_stream = antlr4.InputStream("".join(l)) + except Exception: + input_stream = antlr4.InputStream(autolev_code) + + if AutolevListener: + from ._listener_autolev_antlr import MyListener + lexer = AutolevLexer(input_stream) + token_stream = antlr4.CommonTokenStream(lexer) + parser = AutolevParser(token_stream) + tree = parser.prog() + my_listener = MyListener(include_numeric) + walker = antlr4.ParseTreeWalker() + walker.walk(my_listener, tree) + return "".join(my_listener.output_code) diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/README.txt b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/README.txt new file mode 100644 index 0000000000000000000000000000000000000000..946b006bac33544fadd2dc6d24c22240c8fbc8e4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/README.txt @@ -0,0 +1,9 @@ +# parsing/tests/test_autolev.py uses the .al files in this directory as inputs and checks +# the equivalence of the parser generated codes and the respective .py files. + +# By default, this directory contains tests for all rules of the parser. + +# Additional tests consisting of full physics examples shall be made available soon in +# the form of another repository. One shall be able to copy the contents of that repo +# to this folder and use those tests after uncommenting the respective code in +# parsing/tests/test_autolev.py. diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/chaos_pendulum.al b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/chaos_pendulum.al new file mode 100644 index 0000000000000000000000000000000000000000..3bbb4d51b853bfd759df38d666a42adc1cbea190 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/chaos_pendulum.al @@ -0,0 +1,33 @@ +CONSTANTS G,LB,W,H +MOTIONVARIABLES' THETA'',PHI'',OMEGA',ALPHA' +NEWTONIAN N +BODIES A,B +SIMPROT(N,A,2,THETA) +SIMPROT(A,B,3,PHI) +POINT O +LA = (LB-H/2)/2 +P_O_AO> = LA*A3> +P_O_BO> = LB*A3> +OMEGA = THETA' +ALPHA = PHI' +W_A_N> = OMEGA*N2> +W_B_A> = ALPHA*A3> +V_O_N> = 0> +V2PTS(N, A, O, AO) +V2PTS(N, A, O, BO) +MASS A=MA, B=MB +IAXX = 1/12*MA*(2*LA)^2 +IAYY = IAXX +IAZZ = 0 +IBXX = 1/12*MB*H^2 +IBYY = 1/12*MB*(W^2+H^2) +IBZZ = 1/12*MB*W^2 +INERTIA A, IAXX, IAYY, IAZZ +INERTIA B, IBXX, IBYY, IBZZ +GRAVITY(G*N3>) +ZERO = FR() + FRSTAR() +KANE() +INPUT LB=0.2,H=0.1,W=0.2,MA=0.01,MB=0.1,G=9.81 +INPUT THETA = 90 DEG, PHI = 0.5 DEG, OMEGA=0, ALPHA=0 +INPUT TFINAL=10, INTEGSTP=0.02 +CODE DYNAMICS() some_filename.c diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/chaos_pendulum.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/chaos_pendulum.py new file mode 100644 index 0000000000000000000000000000000000000000..4435635720bb38f40366f55bb3ace0f6f6899284 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/chaos_pendulum.py @@ -0,0 +1,55 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +g, lb, w, h = _sm.symbols('g lb w h', real=True) +theta, phi, omega, alpha = _me.dynamicsymbols('theta phi omega alpha') +theta_d, phi_d, omega_d, alpha_d = _me.dynamicsymbols('theta_ phi_ omega_ alpha_', 1) +theta_dd, phi_dd = _me.dynamicsymbols('theta_ phi_', 2) +frame_n = _me.ReferenceFrame('n') +body_a_cm = _me.Point('a_cm') +body_a_cm.set_vel(frame_n, 0) +body_a_f = _me.ReferenceFrame('a_f') +body_a = _me.RigidBody('a', body_a_cm, body_a_f, _sm.symbols('m'), (_me.outer(body_a_f.x,body_a_f.x),body_a_cm)) +body_b_cm = _me.Point('b_cm') +body_b_cm.set_vel(frame_n, 0) +body_b_f = _me.ReferenceFrame('b_f') +body_b = _me.RigidBody('b', body_b_cm, body_b_f, _sm.symbols('m'), (_me.outer(body_b_f.x,body_b_f.x),body_b_cm)) +body_a_f.orient(frame_n, 'Axis', [theta, frame_n.y]) +body_b_f.orient(body_a_f, 'Axis', [phi, body_a_f.z]) +point_o = _me.Point('o') +la = (lb-h/2)/2 +body_a_cm.set_pos(point_o, la*body_a_f.z) +body_b_cm.set_pos(point_o, lb*body_a_f.z) +body_a_f.set_ang_vel(frame_n, omega*frame_n.y) +body_b_f.set_ang_vel(body_a_f, alpha*body_a_f.z) +point_o.set_vel(frame_n, 0) +body_a_cm.v2pt_theory(point_o,frame_n,body_a_f) +body_b_cm.v2pt_theory(point_o,frame_n,body_a_f) +ma = _sm.symbols('ma') +body_a.mass = ma +mb = _sm.symbols('mb') +body_b.mass = mb +iaxx = 1/12*ma*(2*la)**2 +iayy = iaxx +iazz = 0 +ibxx = 1/12*mb*h**2 +ibyy = 1/12*mb*(w**2+h**2) +ibzz = 1/12*mb*w**2 +body_a.inertia = (_me.inertia(body_a_f, iaxx, iayy, iazz, 0, 0, 0), body_a_cm) +body_b.inertia = (_me.inertia(body_b_f, ibxx, ibyy, ibzz, 0, 0, 0), body_b_cm) +force_a = body_a.mass*(g*frame_n.z) +force_b = body_b.mass*(g*frame_n.z) +kd_eqs = [theta_d - omega, phi_d - alpha] +forceList = [(body_a.masscenter,body_a.mass*(g*frame_n.z)), (body_b.masscenter,body_b.mass*(g*frame_n.z))] +kane = _me.KanesMethod(frame_n, q_ind=[theta,phi], u_ind=[omega, alpha], kd_eqs = kd_eqs) +fr, frstar = kane.kanes_equations([body_a, body_b], forceList) +zero = fr+frstar +from pydy.system import System +sys = System(kane, constants = {g:9.81, lb:0.2, w:0.2, h:0.1, ma:0.01, mb:0.1}, +specifieds={}, +initial_conditions={theta:_np.deg2rad(90), phi:_np.deg2rad(0.5), omega:0, alpha:0}, +times = _np.linspace(0.0, 10, 10/0.02)) + +y=sys.integrate() diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/double_pendulum.al b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/double_pendulum.al new file mode 100644 index 0000000000000000000000000000000000000000..0b6d72a072e093a6cb048a0b7976041ee9c2f4f3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/double_pendulum.al @@ -0,0 +1,25 @@ +MOTIONVARIABLES' Q{2}', U{2}' +CONSTANTS L,M,G +NEWTONIAN N +FRAMES A,B +SIMPROT(N, A, 3, Q1) +SIMPROT(N, B, 3, Q2) +W_A_N>=U1*N3> +W_B_N>=U2*N3> +POINT O +PARTICLES P,R +P_O_P> = L*A1> +P_P_R> = L*B1> +V_O_N> = 0> +V2PTS(N, A, O, P) +V2PTS(N, B, P, R) +MASS P=M, R=M +Q1' = U1 +Q2' = U2 +GRAVITY(G*N1>) +ZERO = FR() + FRSTAR() +KANE() +INPUT M=1,G=9.81,L=1 +INPUT Q1=.1,Q2=.2,U1=0,U2=0 +INPUT TFINAL=10, INTEGSTP=.01 +CODE DYNAMICS() some_filename.c diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/double_pendulum.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/double_pendulum.py new file mode 100644 index 0000000000000000000000000000000000000000..12c73c3b4b198399f4c45f5e00d556c859caff74 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/double_pendulum.py @@ -0,0 +1,39 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +q1, q2, u1, u2 = _me.dynamicsymbols('q1 q2 u1 u2') +q1_d, q2_d, u1_d, u2_d = _me.dynamicsymbols('q1_ q2_ u1_ u2_', 1) +l, m, g = _sm.symbols('l m g', real=True) +frame_n = _me.ReferenceFrame('n') +frame_a = _me.ReferenceFrame('a') +frame_b = _me.ReferenceFrame('b') +frame_a.orient(frame_n, 'Axis', [q1, frame_n.z]) +frame_b.orient(frame_n, 'Axis', [q2, frame_n.z]) +frame_a.set_ang_vel(frame_n, u1*frame_n.z) +frame_b.set_ang_vel(frame_n, u2*frame_n.z) +point_o = _me.Point('o') +particle_p = _me.Particle('p', _me.Point('p_pt'), _sm.Symbol('m')) +particle_r = _me.Particle('r', _me.Point('r_pt'), _sm.Symbol('m')) +particle_p.point.set_pos(point_o, l*frame_a.x) +particle_r.point.set_pos(particle_p.point, l*frame_b.x) +point_o.set_vel(frame_n, 0) +particle_p.point.v2pt_theory(point_o,frame_n,frame_a) +particle_r.point.v2pt_theory(particle_p.point,frame_n,frame_b) +particle_p.mass = m +particle_r.mass = m +force_p = particle_p.mass*(g*frame_n.x) +force_r = particle_r.mass*(g*frame_n.x) +kd_eqs = [q1_d - u1, q2_d - u2] +forceList = [(particle_p.point,particle_p.mass*(g*frame_n.x)), (particle_r.point,particle_r.mass*(g*frame_n.x))] +kane = _me.KanesMethod(frame_n, q_ind=[q1,q2], u_ind=[u1, u2], kd_eqs = kd_eqs) +fr, frstar = kane.kanes_equations([particle_p, particle_r], forceList) +zero = fr+frstar +from pydy.system import System +sys = System(kane, constants = {l:1, m:1, g:9.81}, +specifieds={}, +initial_conditions={q1:.1, q2:.2, u1:0, u2:0}, +times = _np.linspace(0.0, 10, 10/.01)) + +y=sys.integrate() diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/mass_spring_damper.al b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/mass_spring_damper.al new file mode 100644 index 0000000000000000000000000000000000000000..4892e5ca8cb18cad6b14a2a37cbdc1f7fb8217ac --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/mass_spring_damper.al @@ -0,0 +1,19 @@ +CONSTANTS M,K,B,G +MOTIONVARIABLES' POSITION',SPEED' +VARIABLES O +FORCE = O*SIN(T) +NEWTONIAN CEILING +POINTS ORIGIN +V_ORIGIN_CEILING> = 0> +PARTICLES BLOCK +P_ORIGIN_BLOCK> = POSITION*CEILING1> +MASS BLOCK=M +V_BLOCK_CEILING>=SPEED*CEILING1> +POSITION' = SPEED +FORCE_MAGNITUDE = M*G-K*POSITION-B*SPEED+FORCE +FORCE_BLOCK>=EXPLICIT(FORCE_MAGNITUDE*CEILING1>) +ZERO = FR() + FRSTAR() +KANE() +INPUT TFINAL=10.0, INTEGSTP=0.01 +INPUT M=1.0, K=1.0, B=0.2, G=9.8, POSITION=0.1, SPEED=-1.0, O=2 +CODE DYNAMICS() dummy_file.c diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/mass_spring_damper.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/mass_spring_damper.py new file mode 100644 index 0000000000000000000000000000000000000000..8a5baab9642ff140e0ee81027a1e8f9152d7050c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/mass_spring_damper.py @@ -0,0 +1,31 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +m, k, b, g = _sm.symbols('m k b g', real=True) +position, speed = _me.dynamicsymbols('position speed') +position_d, speed_d = _me.dynamicsymbols('position_ speed_', 1) +o = _me.dynamicsymbols('o') +force = o*_sm.sin(_me.dynamicsymbols._t) +frame_ceiling = _me.ReferenceFrame('ceiling') +point_origin = _me.Point('origin') +point_origin.set_vel(frame_ceiling, 0) +particle_block = _me.Particle('block', _me.Point('block_pt'), _sm.Symbol('m')) +particle_block.point.set_pos(point_origin, position*frame_ceiling.x) +particle_block.mass = m +particle_block.point.set_vel(frame_ceiling, speed*frame_ceiling.x) +force_magnitude = m*g-k*position-b*speed+force +force_block = (force_magnitude*frame_ceiling.x).subs({position_d:speed}) +kd_eqs = [position_d - speed] +forceList = [(particle_block.point,(force_magnitude*frame_ceiling.x).subs({position_d:speed}))] +kane = _me.KanesMethod(frame_ceiling, q_ind=[position], u_ind=[speed], kd_eqs = kd_eqs) +fr, frstar = kane.kanes_equations([particle_block], forceList) +zero = fr+frstar +from pydy.system import System +sys = System(kane, constants = {m:1.0, k:1.0, b:0.2, g:9.8}, +specifieds={_me.dynamicsymbols('t'):lambda x, t: t, o:2}, +initial_conditions={position:0.1, speed:-1*1.0}, +times = _np.linspace(0.0, 10.0, 10.0/0.01)) + +y=sys.integrate() diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/non_min_pendulum.al b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/non_min_pendulum.al new file mode 100644 index 0000000000000000000000000000000000000000..74f5062d80926db7acd634a04759abce857087e5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/non_min_pendulum.al @@ -0,0 +1,20 @@ +MOTIONVARIABLES' Q{2}'' +CONSTANTS L,M,G +NEWTONIAN N +POINT PN +V_PN_N> = 0> +THETA1 = ATAN(Q2/Q1) +FRAMES A +SIMPROT(N, A, 3, THETA1) +PARTICLES P +P_PN_P> = Q1*N1>+Q2*N2> +MASS P=M +V_P_N>=DT(P_P_PN>, N) +F_V = DOT(EXPRESS(V_P_N>,A), A1>) +GRAVITY(G*N1>) +DEPENDENT[1] = F_V +CONSTRAIN(DEPENDENT[Q1']) +ZERO=FR()+FRSTAR() +F_C = MAG(P_P_PN>)-L +CONFIG[1]=F_C +ZERO[2]=CONFIG[1] diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/non_min_pendulum.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/non_min_pendulum.py new file mode 100644 index 0000000000000000000000000000000000000000..fc972ebd518e77da5e1902c149f2699979865e7f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/non_min_pendulum.py @@ -0,0 +1,36 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +q1, q2 = _me.dynamicsymbols('q1 q2') +q1_d, q2_d = _me.dynamicsymbols('q1_ q2_', 1) +q1_dd, q2_dd = _me.dynamicsymbols('q1_ q2_', 2) +l, m, g = _sm.symbols('l m g', real=True) +frame_n = _me.ReferenceFrame('n') +point_pn = _me.Point('pn') +point_pn.set_vel(frame_n, 0) +theta1 = _sm.atan(q2/q1) +frame_a = _me.ReferenceFrame('a') +frame_a.orient(frame_n, 'Axis', [theta1, frame_n.z]) +particle_p = _me.Particle('p', _me.Point('p_pt'), _sm.Symbol('m')) +particle_p.point.set_pos(point_pn, q1*frame_n.x+q2*frame_n.y) +particle_p.mass = m +particle_p.point.set_vel(frame_n, (point_pn.pos_from(particle_p.point)).dt(frame_n)) +f_v = _me.dot((particle_p.point.vel(frame_n)).express(frame_a), frame_a.x) +force_p = particle_p.mass*(g*frame_n.x) +dependent = _sm.Matrix([[0]]) +dependent[0] = f_v +velocity_constraints = [i for i in dependent] +u_q1_d = _me.dynamicsymbols('u_q1_d') +u_q2_d = _me.dynamicsymbols('u_q2_d') +kd_eqs = [q1_d-u_q1_d, q2_d-u_q2_d] +forceList = [(particle_p.point,particle_p.mass*(g*frame_n.x))] +kane = _me.KanesMethod(frame_n, q_ind=[q1,q2], u_ind=[u_q2_d], u_dependent=[u_q1_d], kd_eqs = kd_eqs, velocity_constraints = velocity_constraints) +fr, frstar = kane.kanes_equations([particle_p], forceList) +zero = fr+frstar +f_c = point_pn.pos_from(particle_p.point).magnitude()-l +config = _sm.Matrix([[0]]) +config[0] = f_c +zero = zero.row_insert(zero.shape[0], _sm.Matrix([[0]])) +zero[zero.shape[0]-1] = config[0] diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest1.al b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest1.al new file mode 100644 index 0000000000000000000000000000000000000000..457e79fd646677c0decdc69f921bc05e9e0dcf51 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest1.al @@ -0,0 +1,8 @@ +% ruletest1.al +CONSTANTS F = 3, G = 9.81 +CONSTANTS A, B +CONSTANTS S, S1, S2+, S3+, S4- +CONSTANTS K{4}, L{1:3}, P{1:2,1:3} +CONSTANTS C{2,3} +E1 = A*F + S2 - G +E2 = F^2 + K3*K2*G diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest1.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest1.py new file mode 100644 index 0000000000000000000000000000000000000000..8466392ac930f13f2419c9c04eef9dcc2884e9bd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest1.py @@ -0,0 +1,15 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +f = _sm.S(3) +g = _sm.S(9.81) +a, b = _sm.symbols('a b', real=True) +s, s1 = _sm.symbols('s s1', real=True) +s2, s3 = _sm.symbols('s2 s3', real=True, nonnegative=True) +s4 = _sm.symbols('s4', real=True, nonpositive=True) +k1, k2, k3, k4, l1, l2, l3, p11, p12, p13, p21, p22, p23 = _sm.symbols('k1 k2 k3 k4 l1 l2 l3 p11 p12 p13 p21 p22 p23', real=True) +c11, c12, c13, c21, c22, c23 = _sm.symbols('c11 c12 c13 c21 c22 c23', real=True) +e1 = a*f+s2-g +e2 = f**2+k3*k2*g diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest10.al b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest10.al new file mode 100644 index 0000000000000000000000000000000000000000..9d5f76f063c43bcb5e2a8d4f29619a6952abf9e5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest10.al @@ -0,0 +1,58 @@ +% ruletest10.al + +VARIABLES X,Y +COMPLEX ON +CONSTANTS A,B +E = A*(B*X+Y)^2 +M = [E;E] +EXPAND(E) +EXPAND(M) +FACTOR(E,X) +FACTOR(M,X) + +EQN[1] = A*X + B*Y +EQN[2] = 2*A*X - 3*B*Y +SOLVE(EQN, X, Y) +RHS_Y = RHS(Y) +E = (X+Y)^2 + 2*X^2 +ARRANGE(E, 2, X) + +CONSTANTS A,B,C +M = [A,B;C,0] +M2 = EVALUATE(M,A=1,B=2,C=3) +EIG(M2, EIGVALUE, EIGVEC) + +NEWTONIAN N +FRAMES A +SIMPROT(N, A, N1>, X) +DEGREES OFF +SIMPROT(N, A, N1>, PI/2) + +CONSTANTS C{3} +V> = C1*A1> + C2*A2> + C3*A3> +POINTS O, P +P_P_O> = C1*A1> +EXPRESS(V>,N) +EXPRESS(P_P_O>,N) +W_A_N> = C3*A3> +ANGVEL(A,N) + +V2PTS(N,A,O,P) +PARTICLES P{2} +V2PTS(N,A,P1,P2) +A2PTS(N,A,P1,P) + +BODIES B{2} +CONSTANT G +GRAVITY(G*N1>) + +VARIABLE Z +V> = X*A1> + Y*A3> +P_P_O> = X*A1> + Y*A2> +X = 2*Z +Y = Z +EXPLICIT(V>) +EXPLICIT(P_P_O>) + +FORCE(O/P1, X*Y*A1>) +FORCE(P2, X*Y*A1>) diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest10.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest10.py new file mode 100644 index 0000000000000000000000000000000000000000..2b9674e47d5f6132c5a79a33b9d8d55a131942d6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest10.py @@ -0,0 +1,64 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +x, y = _me.dynamicsymbols('x y') +a, b = _sm.symbols('a b', real=True) +e = a*(b*x+y)**2 +m = _sm.Matrix([e,e]).reshape(2, 1) +e = e.expand() +m = _sm.Matrix([i.expand() for i in m]).reshape((m).shape[0], (m).shape[1]) +e = _sm.factor(e, x) +m = _sm.Matrix([_sm.factor(i,x) for i in m]).reshape((m).shape[0], (m).shape[1]) +eqn = _sm.Matrix([[0]]) +eqn[0] = a*x+b*y +eqn = eqn.row_insert(eqn.shape[0], _sm.Matrix([[0]])) +eqn[eqn.shape[0]-1] = 2*a*x-3*b*y +print(_sm.solve(eqn,x,y)) +rhs_y = _sm.solve(eqn,x,y)[y] +e = (x+y)**2+2*x**2 +e.collect(x) +a, b, c = _sm.symbols('a b c', real=True) +m = _sm.Matrix([a,b,c,0]).reshape(2, 2) +m2 = _sm.Matrix([i.subs({a:1,b:2,c:3}) for i in m]).reshape((m).shape[0], (m).shape[1]) +eigvalue = _sm.Matrix([i.evalf() for i in (m2).eigenvals().keys()]) +eigvec = _sm.Matrix([i[2][0].evalf() for i in (m2).eigenvects()]).reshape(m2.shape[0], m2.shape[1]) +frame_n = _me.ReferenceFrame('n') +frame_a = _me.ReferenceFrame('a') +frame_a.orient(frame_n, 'Axis', [x, frame_n.x]) +frame_a.orient(frame_n, 'Axis', [_sm.pi/2, frame_n.x]) +c1, c2, c3 = _sm.symbols('c1 c2 c3', real=True) +v = c1*frame_a.x+c2*frame_a.y+c3*frame_a.z +point_o = _me.Point('o') +point_p = _me.Point('p') +point_o.set_pos(point_p, c1*frame_a.x) +v = (v).express(frame_n) +point_o.set_pos(point_p, (point_o.pos_from(point_p)).express(frame_n)) +frame_a.set_ang_vel(frame_n, c3*frame_a.z) +print(frame_n.ang_vel_in(frame_a)) +point_p.v2pt_theory(point_o,frame_n,frame_a) +particle_p1 = _me.Particle('p1', _me.Point('p1_pt'), _sm.Symbol('m')) +particle_p2 = _me.Particle('p2', _me.Point('p2_pt'), _sm.Symbol('m')) +particle_p2.point.v2pt_theory(particle_p1.point,frame_n,frame_a) +point_p.a2pt_theory(particle_p1.point,frame_n,frame_a) +body_b1_cm = _me.Point('b1_cm') +body_b1_cm.set_vel(frame_n, 0) +body_b1_f = _me.ReferenceFrame('b1_f') +body_b1 = _me.RigidBody('b1', body_b1_cm, body_b1_f, _sm.symbols('m'), (_me.outer(body_b1_f.x,body_b1_f.x),body_b1_cm)) +body_b2_cm = _me.Point('b2_cm') +body_b2_cm.set_vel(frame_n, 0) +body_b2_f = _me.ReferenceFrame('b2_f') +body_b2 = _me.RigidBody('b2', body_b2_cm, body_b2_f, _sm.symbols('m'), (_me.outer(body_b2_f.x,body_b2_f.x),body_b2_cm)) +g = _sm.symbols('g', real=True) +force_p1 = particle_p1.mass*(g*frame_n.x) +force_p2 = particle_p2.mass*(g*frame_n.x) +force_b1 = body_b1.mass*(g*frame_n.x) +force_b2 = body_b2.mass*(g*frame_n.x) +z = _me.dynamicsymbols('z') +v = x*frame_a.x+y*frame_a.z +point_o.set_pos(point_p, x*frame_a.x+y*frame_a.y) +v = (v).subs({x:2*z, y:z}) +point_o.set_pos(point_p, (point_o.pos_from(point_p)).subs({x:2*z, y:z})) +force_o = -1*(x*y*frame_a.x) +force_p1 = particle_p1.mass*(g*frame_n.x)+ x*y*frame_a.x diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest11.al b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest11.al new file mode 100644 index 0000000000000000000000000000000000000000..60934c1ca563024828110bfe984a90d5686b89e4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest11.al @@ -0,0 +1,6 @@ +VARIABLES X, Y +CONSTANTS A{1:2, 1:2}, B{1:2} +EQN[1] = A11*x + A12*y - B1 +EQN[2] = A21*x + A22*y - B2 +INPUT A11=2, A12=5, A21=3, A22=4, B1=7, B2=6 +CODE ALGEBRAIC(EQN, X, Y) some_filename.c diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest11.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest11.py new file mode 100644 index 0000000000000000000000000000000000000000..4ec2397ea96261d7b582d1f699e3897caae88f20 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest11.py @@ -0,0 +1,14 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +x, y = _me.dynamicsymbols('x y') +a11, a12, a21, a22, b1, b2 = _sm.symbols('a11 a12 a21 a22 b1 b2', real=True) +eqn = _sm.Matrix([[0]]) +eqn[0] = a11*x+a12*y-b1 +eqn = eqn.row_insert(eqn.shape[0], _sm.Matrix([[0]])) +eqn[eqn.shape[0]-1] = a21*x+a22*y-b2 +eqn_list = [] +for i in eqn: eqn_list.append(i.subs({a11:2, a12:5, a21:3, a22:4, b1:7, b2:6})) +print(_sm.linsolve(eqn_list, x,y)) diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest12.al b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest12.al new file mode 100644 index 0000000000000000000000000000000000000000..f147f55afd1438436767960e0487d5d9e7161c8f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest12.al @@ -0,0 +1,7 @@ +VARIABLES X,Y +CONSTANTS A,B,R +EQN[1] = A*X^3+B*Y^2-R +EQN[2] = A*SIN(X)^2 + B*COS(2*Y) - R^2 +INPUT A=2.0, B=3.0, R=1.0 +INPUT X = 30 DEG, Y = 3.14 +CODE NONLINEAR(EQN,X,Y) some_filename.c diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest12.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest12.py new file mode 100644 index 0000000000000000000000000000000000000000..3d7d996fa649f796a536dba20c1a36554acd8046 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest12.py @@ -0,0 +1,14 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +x, y = _me.dynamicsymbols('x y') +a, b, r = _sm.symbols('a b r', real=True) +eqn = _sm.Matrix([[0]]) +eqn[0] = a*x**3+b*y**2-r +eqn = eqn.row_insert(eqn.shape[0], _sm.Matrix([[0]])) +eqn[eqn.shape[0]-1] = a*_sm.sin(x)**2+b*_sm.cos(2*y)-r**2 +matrix_list = [] +for i in eqn:matrix_list.append(i.subs({a:2.0, b:3.0, r:1.0})) +print(_sm.nsolve(matrix_list,(x,y),(_np.deg2rad(30),3.14))) diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest2.al b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest2.al new file mode 100644 index 0000000000000000000000000000000000000000..17937e58bd20a9fb82f44ccd05f0c081a1aa6c9b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest2.al @@ -0,0 +1,12 @@ +% ruletest2.al +VARIABLES X1,X2 +SPECIFIED F1 = X1*X2 + 3*X1^2 +SPECIFIED F2=X1*T+X2*T^2 +VARIABLE X', Y'' +MOTIONVARIABLES Q{3}, U{2} +VARIABLES P{2}' +VARIABLE W{3}', R{2}'' +VARIABLES C{1:2, 1:2} +VARIABLES D{1,3} +VARIABLES J{1:2} +IMAGINARY N diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest2.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest2.py new file mode 100644 index 0000000000000000000000000000000000000000..31c1d9974c2292466b805b91f8254bffaa94e2ac --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest2.py @@ -0,0 +1,22 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +x1, x2 = _me.dynamicsymbols('x1 x2') +f1 = x1*x2+3*x1**2 +f2 = x1*_me.dynamicsymbols._t+x2*_me.dynamicsymbols._t**2 +x, y = _me.dynamicsymbols('x y') +x_d, y_d = _me.dynamicsymbols('x_ y_', 1) +y_dd = _me.dynamicsymbols('y_', 2) +q1, q2, q3, u1, u2 = _me.dynamicsymbols('q1 q2 q3 u1 u2') +p1, p2 = _me.dynamicsymbols('p1 p2') +p1_d, p2_d = _me.dynamicsymbols('p1_ p2_', 1) +w1, w2, w3, r1, r2 = _me.dynamicsymbols('w1 w2 w3 r1 r2') +w1_d, w2_d, w3_d, r1_d, r2_d = _me.dynamicsymbols('w1_ w2_ w3_ r1_ r2_', 1) +r1_dd, r2_dd = _me.dynamicsymbols('r1_ r2_', 2) +c11, c12, c21, c22 = _me.dynamicsymbols('c11 c12 c21 c22') +d11, d12, d13 = _me.dynamicsymbols('d11 d12 d13') +j1, j2 = _me.dynamicsymbols('j1 j2') +n = _sm.symbols('n') +n = _sm.I diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest3.al b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest3.al new file mode 100644 index 0000000000000000000000000000000000000000..f263f1802ebca2725481dd5fdd3540bf8e9f11bf --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest3.al @@ -0,0 +1,25 @@ +% ruletest3.al +FRAMES A, B +NEWTONIAN N + +VARIABLES X{3} +CONSTANTS L + +V1> = X1*A1> + X2*A2> + X3*A3> +V2> = X1*B1> + X2*B2> + X3*B3> +V3> = X1*N1> + X2*N2> + X3*N3> + +V> = V1> + V2> + V3> + +POINTS C, D +POINTS PO{3} + +PARTICLES L +PARTICLES P{3} + +BODIES S +BODIES R{2} + +V4> = X1*S1> + X2*S2> + X3*S3> + +P_C_SO> = L*N1> diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest3.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest3.py new file mode 100644 index 0000000000000000000000000000000000000000..23f79aa571337f200b3ff4d56b5747f7704985c0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest3.py @@ -0,0 +1,37 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +frame_a = _me.ReferenceFrame('a') +frame_b = _me.ReferenceFrame('b') +frame_n = _me.ReferenceFrame('n') +x1, x2, x3 = _me.dynamicsymbols('x1 x2 x3') +l = _sm.symbols('l', real=True) +v1 = x1*frame_a.x+x2*frame_a.y+x3*frame_a.z +v2 = x1*frame_b.x+x2*frame_b.y+x3*frame_b.z +v3 = x1*frame_n.x+x2*frame_n.y+x3*frame_n.z +v = v1+v2+v3 +point_c = _me.Point('c') +point_d = _me.Point('d') +point_po1 = _me.Point('po1') +point_po2 = _me.Point('po2') +point_po3 = _me.Point('po3') +particle_l = _me.Particle('l', _me.Point('l_pt'), _sm.Symbol('m')) +particle_p1 = _me.Particle('p1', _me.Point('p1_pt'), _sm.Symbol('m')) +particle_p2 = _me.Particle('p2', _me.Point('p2_pt'), _sm.Symbol('m')) +particle_p3 = _me.Particle('p3', _me.Point('p3_pt'), _sm.Symbol('m')) +body_s_cm = _me.Point('s_cm') +body_s_cm.set_vel(frame_n, 0) +body_s_f = _me.ReferenceFrame('s_f') +body_s = _me.RigidBody('s', body_s_cm, body_s_f, _sm.symbols('m'), (_me.outer(body_s_f.x,body_s_f.x),body_s_cm)) +body_r1_cm = _me.Point('r1_cm') +body_r1_cm.set_vel(frame_n, 0) +body_r1_f = _me.ReferenceFrame('r1_f') +body_r1 = _me.RigidBody('r1', body_r1_cm, body_r1_f, _sm.symbols('m'), (_me.outer(body_r1_f.x,body_r1_f.x),body_r1_cm)) +body_r2_cm = _me.Point('r2_cm') +body_r2_cm.set_vel(frame_n, 0) +body_r2_f = _me.ReferenceFrame('r2_f') +body_r2 = _me.RigidBody('r2', body_r2_cm, body_r2_f, _sm.symbols('m'), (_me.outer(body_r2_f.x,body_r2_f.x),body_r2_cm)) +v4 = x1*body_s_f.x+x2*body_s_f.y+x3*body_s_f.z +body_s_cm.set_pos(point_c, l*frame_n.x) diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest4.al b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest4.al new file mode 100644 index 0000000000000000000000000000000000000000..7302bd7724bad9b763c75fe4230faa42b5070408 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest4.al @@ -0,0 +1,20 @@ +% ruletest4.al + +FRAMES A, B +MOTIONVARIABLES Q{3} +SIMPROT(A, B, 1, Q3) +DCM = A_B +M = DCM*3 - A_B + +VARIABLES R +CIRCLE_AREA = PI*R^2 + +VARIABLES U, A +VARIABLES X, Y +S = U*T - 1/2*A*T^2 + +EXPR1 = 2*A*0.5 - 1.25 + 0.25 +EXPR2 = -X^2 + Y^2 + 0.25*(X+Y)^2 +EXPR3 = 0.5E-10 + +DYADIC>> = A1>*A1> + A2>*A2> + A3>*A3> diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest4.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest4.py new file mode 100644 index 0000000000000000000000000000000000000000..74b18543e04d6c9e42dd569d2152040c13ae0899 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest4.py @@ -0,0 +1,20 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +frame_a = _me.ReferenceFrame('a') +frame_b = _me.ReferenceFrame('b') +q1, q2, q3 = _me.dynamicsymbols('q1 q2 q3') +frame_b.orient(frame_a, 'Axis', [q3, frame_a.x]) +dcm = frame_a.dcm(frame_b) +m = dcm*3-frame_a.dcm(frame_b) +r = _me.dynamicsymbols('r') +circle_area = _sm.pi*r**2 +u, a = _me.dynamicsymbols('u a') +x, y = _me.dynamicsymbols('x y') +s = u*_me.dynamicsymbols._t-1/2*a*_me.dynamicsymbols._t**2 +expr1 = 2*a*0.5-1.25+0.25 +expr2 = -1*x**2+y**2+0.25*(x+y)**2 +expr3 = 0.5*10**(-10) +dyadic = _me.outer(frame_a.x, frame_a.x)+_me.outer(frame_a.y, frame_a.y)+_me.outer(frame_a.z, frame_a.z) diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest5.al b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest5.al new file mode 100644 index 0000000000000000000000000000000000000000..a859dc8bb1f0251af14809681d995c59b31377ba --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest5.al @@ -0,0 +1,32 @@ +% ruletest5.al +VARIABLES X', Y' + +E1 = (X+Y)^2 + (X-Y)^3 +E2 = (X-Y)^2 +E3 = X^2 + Y^2 + 2*X*Y + +M1 = [E1;E2] +M2 = [(X+Y)^2,(X-Y)^2] +M3 = M1 + [X;Y] + +AM = EXPAND(M1) +CM = EXPAND([(X+Y)^2,(X-Y)^2]) +EM = EXPAND(M1 + [X;Y]) +F = EXPAND(E1) +G = EXPAND(E2) + +A = FACTOR(E3, X) +BM = FACTOR(M1, X) +CM = FACTOR(M1 + [X;Y], X) + +A = D(E3, X) +B = D(E3, Y) +CM = D(M2, X) +DM = D(M1 + [X;Y], X) +FRAMES A, B +A_B = [1,0,0;1,0,0;1,0,0] +V1> = X*A1> + Y*A2> + X*Y*A3> +E> = D(V1>, X, B) +FM = DT(M1) +GM = DT([(X+Y)^2,(X-Y)^2]) +H> = DT(V1>, B) diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest5.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest5.py new file mode 100644 index 0000000000000000000000000000000000000000..93684435b402f5b56e2f4a5c3c81500208556423 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest5.py @@ -0,0 +1,33 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +x, y = _me.dynamicsymbols('x y') +x_d, y_d = _me.dynamicsymbols('x_ y_', 1) +e1 = (x+y)**2+(x-y)**3 +e2 = (x-y)**2 +e3 = x**2+y**2+2*x*y +m1 = _sm.Matrix([e1,e2]).reshape(2, 1) +m2 = _sm.Matrix([(x+y)**2,(x-y)**2]).reshape(1, 2) +m3 = m1+_sm.Matrix([x,y]).reshape(2, 1) +am = _sm.Matrix([i.expand() for i in m1]).reshape((m1).shape[0], (m1).shape[1]) +cm = _sm.Matrix([i.expand() for i in _sm.Matrix([(x+y)**2,(x-y)**2]).reshape(1, 2)]).reshape((_sm.Matrix([(x+y)**2,(x-y)**2]).reshape(1, 2)).shape[0], (_sm.Matrix([(x+y)**2,(x-y)**2]).reshape(1, 2)).shape[1]) +em = _sm.Matrix([i.expand() for i in m1+_sm.Matrix([x,y]).reshape(2, 1)]).reshape((m1+_sm.Matrix([x,y]).reshape(2, 1)).shape[0], (m1+_sm.Matrix([x,y]).reshape(2, 1)).shape[1]) +f = (e1).expand() +g = (e2).expand() +a = _sm.factor((e3), x) +bm = _sm.Matrix([_sm.factor(i, x) for i in m1]).reshape((m1).shape[0], (m1).shape[1]) +cm = _sm.Matrix([_sm.factor(i, x) for i in m1+_sm.Matrix([x,y]).reshape(2, 1)]).reshape((m1+_sm.Matrix([x,y]).reshape(2, 1)).shape[0], (m1+_sm.Matrix([x,y]).reshape(2, 1)).shape[1]) +a = (e3).diff(x) +b = (e3).diff(y) +cm = _sm.Matrix([i.diff(x) for i in m2]).reshape((m2).shape[0], (m2).shape[1]) +dm = _sm.Matrix([i.diff(x) for i in m1+_sm.Matrix([x,y]).reshape(2, 1)]).reshape((m1+_sm.Matrix([x,y]).reshape(2, 1)).shape[0], (m1+_sm.Matrix([x,y]).reshape(2, 1)).shape[1]) +frame_a = _me.ReferenceFrame('a') +frame_b = _me.ReferenceFrame('b') +frame_b.orient(frame_a, 'DCM', _sm.Matrix([1,0,0,1,0,0,1,0,0]).reshape(3, 3)) +v1 = x*frame_a.x+y*frame_a.y+x*y*frame_a.z +e = (v1).diff(x, frame_b) +fm = _sm.Matrix([i.diff(_sm.Symbol('t')) for i in m1]).reshape((m1).shape[0], (m1).shape[1]) +gm = _sm.Matrix([i.diff(_sm.Symbol('t')) for i in _sm.Matrix([(x+y)**2,(x-y)**2]).reshape(1, 2)]).reshape((_sm.Matrix([(x+y)**2,(x-y)**2]).reshape(1, 2)).shape[0], (_sm.Matrix([(x+y)**2,(x-y)**2]).reshape(1, 2)).shape[1]) +h = (v1).dt(frame_b) diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest6.al b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest6.al new file mode 100644 index 0000000000000000000000000000000000000000..7ec3ba61590e77772ae631237df048b932fe778c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest6.al @@ -0,0 +1,41 @@ +% ruletest6.al +VARIABLES Q{2} +VARIABLES X,Y,Z +Q1 = X^2 + Y^2 +Q2 = X-Y +E = Q1 + Q2 +A = EXPLICIT(E) +E2 = COS(X) +E3 = COS(X*Y) +A = TAYLOR(E2, 0:2, X=0) +B = TAYLOR(E3, 0:2, X=0, Y=0) + +E = EXPAND((X+Y)^2) +A = EVALUATE(E, X=1, Y=Z) +BM = EVALUATE([E;2*E], X=1, Y=Z) + +E = Q1 + Q2 +A = EVALUATE(E, X=2, Y=Z^2) + +CONSTANTS J,K,L +P1 = POLYNOMIAL([J,K,L],X) +P2 = POLYNOMIAL(J*X+K,X,1) + +ROOT1 = ROOTS(P1, X, 2) +ROOT2 = ROOTS([1;2;3]) + +M = [1,2,3,4;5,6,7,8;9,10,11,12;13,14,15,16] + +AM = TRANSPOSE(M) + M +BM = EIG(M) +C1 = DIAGMAT(4, 1) +C2 = DIAGMAT(3, 4, 2) +DM = INV(M+C1) +E = DET(M+C1) + TRACE([1,0;0,1]) +F = ELEMENT(M, 2, 3) + +A = COLS(M) +BM = COLS(M, 1) +CM = COLS(M, 1, 2:4, 3) +DM = ROWS(M, 1) +EM = ROWS(M, 1, 2:4, 3) diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest6.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest6.py new file mode 100644 index 0000000000000000000000000000000000000000..85f1a0b49518bb0ae5766cbe91b9c24a1b8e9c20 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest6.py @@ -0,0 +1,36 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +q1, q2 = _me.dynamicsymbols('q1 q2') +x, y, z = _me.dynamicsymbols('x y z') +e = q1+q2 +a = (e).subs({q1:x**2+y**2, q2:x-y}) +e2 = _sm.cos(x) +e3 = _sm.cos(x*y) +a = (e2).series(x, 0, 2).removeO() +b = (e3).series(x, 0, 2).removeO().series(y, 0, 2).removeO() +e = ((x+y)**2).expand() +a = (e).subs({q1:x**2+y**2,q2:x-y}).subs({x:1,y:z}) +bm = _sm.Matrix([i.subs({x:1,y:z}) for i in _sm.Matrix([e,2*e]).reshape(2, 1)]).reshape((_sm.Matrix([e,2*e]).reshape(2, 1)).shape[0], (_sm.Matrix([e,2*e]).reshape(2, 1)).shape[1]) +e = q1+q2 +a = (e).subs({q1:x**2+y**2,q2:x-y}).subs({x:2,y:z**2}) +j, k, l = _sm.symbols('j k l', real=True) +p1 = _sm.Poly(_sm.Matrix([j,k,l]).reshape(1, 3), x) +p2 = _sm.Poly(j*x+k, x) +root1 = [i.evalf() for i in _sm.solve(p1, x)] +root2 = [i.evalf() for i in _sm.solve(_sm.Poly(_sm.Matrix([1,2,3]).reshape(3, 1), x),x)] +m = _sm.Matrix([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]).reshape(4, 4) +am = (m).T+m +bm = _sm.Matrix([i.evalf() for i in (m).eigenvals().keys()]) +c1 = _sm.diag(1,1,1,1) +c2 = _sm.Matrix([2 if i==j else 0 for i in range(3) for j in range(4)]).reshape(3, 4) +dm = (m+c1)**(-1) +e = (m+c1).det()+(_sm.Matrix([1,0,0,1]).reshape(2, 2)).trace() +f = (m)[1,2] +a = (m).cols +bm = (m).col(0) +cm = _sm.Matrix([(m).T.row(0),(m).T.row(1),(m).T.row(2),(m).T.row(3),(m).T.row(2)]) +dm = (m).row(0) +em = _sm.Matrix([(m).row(0),(m).row(1),(m).row(2),(m).row(3),(m).row(2)]) diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest7.al b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest7.al new file mode 100644 index 0000000000000000000000000000000000000000..2904a602f589645d22e1d3d378d077dd6a1ec27e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest7.al @@ -0,0 +1,39 @@ +% ruletest7.al +VARIABLES X', Y' +E = COS(X) + SIN(X) + TAN(X)& ++ COSH(X) + SINH(X) + TANH(X)& ++ ACOS(X) + ASIN(X) + ATAN(X)& ++ LOG(X) + EXP(X) + SQRT(X)& ++ FACTORIAL(X) + CEIL(X) +& +FLOOR(X) + SIGN(X) + +E = SQR(X) + LOG10(X) + +A = ABS(-1) + INT(1.5) + ROUND(1.9) + +E1 = 2*X + 3*Y +E2 = X + Y + +AM = COEF([E1;E2], [X,Y]) +B = COEF(E1, X) +C = COEF(E2, Y) +D1 = EXCLUDE(E1, X) +D2 = INCLUDE(E1, X) +FM = ARRANGE([E1,E2],2,X) +F = ARRANGE(E1, 2, Y) +G = REPLACE(E1, X=2*X) +GM = REPLACE([E1;E2], X=3) + +FRAMES A, B +VARIABLES THETA +SIMPROT(A,B,3,THETA) +V1> = 2*A1> - 3*A2> + A3> +V2> = B1> + B2> + B3> +A = DOT(V1>, V2>) +BM = DOT(V1>, [V2>;2*V2>]) +C> = CROSS(V1>,V2>) +D = MAG(2*V1>) + MAG(3*V1>) +DYADIC>> = 3*A1>*A1> + A2>*A2> + 2*A3>*A3> +AM = MATRIX(B, DYADIC>>) +M = [1;2;3] +V> = VECTOR(A, M) diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest7.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest7.py new file mode 100644 index 0000000000000000000000000000000000000000..19147856dc3b0d451184a6bb539c1c331f61a6d2 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest7.py @@ -0,0 +1,35 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +x, y = _me.dynamicsymbols('x y') +x_d, y_d = _me.dynamicsymbols('x_ y_', 1) +e = _sm.cos(x)+_sm.sin(x)+_sm.tan(x)+_sm.cosh(x)+_sm.sinh(x)+_sm.tanh(x)+_sm.acos(x)+_sm.asin(x)+_sm.atan(x)+_sm.log(x)+_sm.exp(x)+_sm.sqrt(x)+_sm.factorial(x)+_sm.ceiling(x)+_sm.floor(x)+_sm.sign(x) +e = (x)**2+_sm.log(x, 10) +a = _sm.Abs(-1*1)+int(1.5)+round(1.9) +e1 = 2*x+3*y +e2 = x+y +am = _sm.Matrix([e1.expand().coeff(x), e1.expand().coeff(y), e2.expand().coeff(x), e2.expand().coeff(y)]).reshape(2, 2) +b = (e1).expand().coeff(x) +c = (e2).expand().coeff(y) +d1 = (e1).collect(x).coeff(x,0) +d2 = (e1).collect(x).coeff(x,1) +fm = _sm.Matrix([i.collect(x)for i in _sm.Matrix([e1,e2]).reshape(1, 2)]).reshape((_sm.Matrix([e1,e2]).reshape(1, 2)).shape[0], (_sm.Matrix([e1,e2]).reshape(1, 2)).shape[1]) +f = (e1).collect(y) +g = (e1).subs({x:2*x}) +gm = _sm.Matrix([i.subs({x:3}) for i in _sm.Matrix([e1,e2]).reshape(2, 1)]).reshape((_sm.Matrix([e1,e2]).reshape(2, 1)).shape[0], (_sm.Matrix([e1,e2]).reshape(2, 1)).shape[1]) +frame_a = _me.ReferenceFrame('a') +frame_b = _me.ReferenceFrame('b') +theta = _me.dynamicsymbols('theta') +frame_b.orient(frame_a, 'Axis', [theta, frame_a.z]) +v1 = 2*frame_a.x-3*frame_a.y+frame_a.z +v2 = frame_b.x+frame_b.y+frame_b.z +a = _me.dot(v1, v2) +bm = _sm.Matrix([_me.dot(v1, v2),_me.dot(v1, 2*v2)]).reshape(2, 1) +c = _me.cross(v1, v2) +d = 2*v1.magnitude()+3*v1.magnitude() +dyadic = _me.outer(3*frame_a.x, frame_a.x)+_me.outer(frame_a.y, frame_a.y)+_me.outer(2*frame_a.z, frame_a.z) +am = (dyadic).to_matrix(frame_b) +m = _sm.Matrix([1,2,3]).reshape(3, 1) +v = m[0]*frame_a.x +m[1]*frame_a.y +m[2]*frame_a.z diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest8.al b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest8.al new file mode 100644 index 0000000000000000000000000000000000000000..4b2462c51e6730f46bf60b4b21ab6cfbf1993640 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest8.al @@ -0,0 +1,38 @@ +% ruletest8.al +FRAMES A +CONSTANTS C{3} +A>> = EXPRESS(1>>,A) +PARTICLES P1, P2 +BODIES R +R_A = [1,1,1;1,1,0;0,0,1] +POINT O +MASS P1=M1, P2=M2, R=MR +INERTIA R, I1, I2, I3 +P_P1_O> = C1*A1> +P_P2_O> = C2*A2> +P_RO_O> = C3*A3> +A>> = EXPRESS(I_P1_O>>, A) +A>> = EXPRESS(I_P2_O>>, A) +A>> = EXPRESS(I_R_O>>, A) +A>> = EXPRESS(INERTIA(O), A) +A>> = EXPRESS(INERTIA(O, P1, R), A) +A>> = EXPRESS(I_R_O>>, A) +A>> = EXPRESS(I_R_RO>>, A) + +P_P1_P2> = C1*A1> + C2*A2> +P_P1_RO> = C3*A1> +P_P2_RO> = C3*A2> + +B> = CM(O) +B> = CM(O, P1, R) +B> = CM(P1) + +MOTIONVARIABLES U{3} +V> = U1*A1> + U2*A2> + U3*A3> +U> = UNITVEC(V> + C1*A1>) +V_P1_A> = U1*A1> +A> = PARTIALS(V_P1_A>, U1) + +M = MASS(P1,R) +M = MASS(P2) +M = MASS() \ No newline at end of file diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest8.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest8.py new file mode 100644 index 0000000000000000000000000000000000000000..6809c47138e40027c700536e807ca7cfa5f468d7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest8.py @@ -0,0 +1,49 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +frame_a = _me.ReferenceFrame('a') +c1, c2, c3 = _sm.symbols('c1 c2 c3', real=True) +a = _me.inertia(frame_a, 1, 1, 1) +particle_p1 = _me.Particle('p1', _me.Point('p1_pt'), _sm.Symbol('m')) +particle_p2 = _me.Particle('p2', _me.Point('p2_pt'), _sm.Symbol('m')) +body_r_cm = _me.Point('r_cm') +body_r_f = _me.ReferenceFrame('r_f') +body_r = _me.RigidBody('r', body_r_cm, body_r_f, _sm.symbols('m'), (_me.outer(body_r_f.x,body_r_f.x),body_r_cm)) +frame_a.orient(body_r_f, 'DCM', _sm.Matrix([1,1,1,1,1,0,0,0,1]).reshape(3, 3)) +point_o = _me.Point('o') +m1 = _sm.symbols('m1') +particle_p1.mass = m1 +m2 = _sm.symbols('m2') +particle_p2.mass = m2 +mr = _sm.symbols('mr') +body_r.mass = mr +i1 = _sm.symbols('i1') +i2 = _sm.symbols('i2') +i3 = _sm.symbols('i3') +body_r.inertia = (_me.inertia(body_r_f, i1, i2, i3, 0, 0, 0), body_r_cm) +point_o.set_pos(particle_p1.point, c1*frame_a.x) +point_o.set_pos(particle_p2.point, c2*frame_a.y) +point_o.set_pos(body_r_cm, c3*frame_a.z) +a = _me.inertia_of_point_mass(particle_p1.mass, particle_p1.point.pos_from(point_o), frame_a) +a = _me.inertia_of_point_mass(particle_p2.mass, particle_p2.point.pos_from(point_o), frame_a) +a = body_r.inertia[0] + _me.inertia_of_point_mass(body_r.mass, body_r.masscenter.pos_from(point_o), frame_a) +a = _me.inertia_of_point_mass(particle_p1.mass, particle_p1.point.pos_from(point_o), frame_a) + _me.inertia_of_point_mass(particle_p2.mass, particle_p2.point.pos_from(point_o), frame_a) + body_r.inertia[0] + _me.inertia_of_point_mass(body_r.mass, body_r.masscenter.pos_from(point_o), frame_a) +a = _me.inertia_of_point_mass(particle_p1.mass, particle_p1.point.pos_from(point_o), frame_a) + body_r.inertia[0] + _me.inertia_of_point_mass(body_r.mass, body_r.masscenter.pos_from(point_o), frame_a) +a = body_r.inertia[0] + _me.inertia_of_point_mass(body_r.mass, body_r.masscenter.pos_from(point_o), frame_a) +a = body_r.inertia[0] +particle_p2.point.set_pos(particle_p1.point, c1*frame_a.x+c2*frame_a.y) +body_r_cm.set_pos(particle_p1.point, c3*frame_a.x) +body_r_cm.set_pos(particle_p2.point, c3*frame_a.y) +b = _me.functions.center_of_mass(point_o,particle_p1, particle_p2, body_r) +b = _me.functions.center_of_mass(point_o,particle_p1, body_r) +b = _me.functions.center_of_mass(particle_p1.point,particle_p1, particle_p2, body_r) +u1, u2, u3 = _me.dynamicsymbols('u1 u2 u3') +v = u1*frame_a.x+u2*frame_a.y+u3*frame_a.z +u = (v+c1*frame_a.x).normalize() +particle_p1.point.set_vel(frame_a, u1*frame_a.x) +a = particle_p1.point.partial_velocity(frame_a, u1) +m = particle_p1.mass+body_r.mass +m = particle_p2.mass +m = particle_p1.mass+particle_p2.mass+body_r.mass diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest9.al b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest9.al new file mode 100644 index 0000000000000000000000000000000000000000..df5c70f05b76fc215f829672e281491b0c96c6a6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest9.al @@ -0,0 +1,54 @@ +% ruletest9.al +NEWTONIAN N +FRAMES A +A> = 0> +D>> = EXPRESS(1>>, A) + +POINTS PO{2} +PARTICLES P{2} +MOTIONVARIABLES' C{3}' +BODIES R +P_P1_PO2> = C1*A1> +V> = 2*P_P1_PO2> + C2*A2> + +W_A_N> = C3*A3> +V> = 2*W_A_N> + C2*A2> +W_R_N> = C3*A3> +V> = 2*W_R_N> + C2*A2> + +ALF_A_N> = DT(W_A_N>, A) +V> = 2*ALF_A_N> + C2*A2> + +V_P1_A> = C1*A1> + C3*A2> +A_RO_N> = C2*A2> +V_A> = CROSS(A_RO_N>, V_P1_A>) + +X_B_C> = V_A> +X_B_D> = 2*X_B_C> +A_B_C_D_E> = X_B_D>*2 + +A_B_C = 2*C1*C2*C3 +A_B_C += 2*C1 +A_B_C := 3*C1 + +MOTIONVARIABLES' Q{2}', U{2}' +Q1' = U1 +Q2' = U2 + +VARIABLES X'', Y'' +SPECIFIED YY +Y'' = X*X'^2 + 1 +YY = X*X'^2 + 1 + +M[1] = 2*X +M[2] = 2*Y +A = 2*M[1] + +M = [1,2,3;4,5,6;7,8,9] +M[1, 2] = 5 +A = M[1, 2]*2 + +FORCE_RO> = Q1*N1> +TORQUE_A> = Q2*N3> +FORCE_RO> = Q2*N2> +F> = FORCE_RO>*2 diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest9.py b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest9.py new file mode 100644 index 0000000000000000000000000000000000000000..09d8ae4ee8385bde5c38b946458a43c8ffdaa9b8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/autolev/test-examples/ruletest9.py @@ -0,0 +1,55 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +frame_n = _me.ReferenceFrame('n') +frame_a = _me.ReferenceFrame('a') +a = 0 +d = _me.inertia(frame_a, 1, 1, 1) +point_po1 = _me.Point('po1') +point_po2 = _me.Point('po2') +particle_p1 = _me.Particle('p1', _me.Point('p1_pt'), _sm.Symbol('m')) +particle_p2 = _me.Particle('p2', _me.Point('p2_pt'), _sm.Symbol('m')) +c1, c2, c3 = _me.dynamicsymbols('c1 c2 c3') +c1_d, c2_d, c3_d = _me.dynamicsymbols('c1_ c2_ c3_', 1) +body_r_cm = _me.Point('r_cm') +body_r_cm.set_vel(frame_n, 0) +body_r_f = _me.ReferenceFrame('r_f') +body_r = _me.RigidBody('r', body_r_cm, body_r_f, _sm.symbols('m'), (_me.outer(body_r_f.x,body_r_f.x),body_r_cm)) +point_po2.set_pos(particle_p1.point, c1*frame_a.x) +v = 2*point_po2.pos_from(particle_p1.point)+c2*frame_a.y +frame_a.set_ang_vel(frame_n, c3*frame_a.z) +v = 2*frame_a.ang_vel_in(frame_n)+c2*frame_a.y +body_r_f.set_ang_vel(frame_n, c3*frame_a.z) +v = 2*body_r_f.ang_vel_in(frame_n)+c2*frame_a.y +frame_a.set_ang_acc(frame_n, (frame_a.ang_vel_in(frame_n)).dt(frame_a)) +v = 2*frame_a.ang_acc_in(frame_n)+c2*frame_a.y +particle_p1.point.set_vel(frame_a, c1*frame_a.x+c3*frame_a.y) +body_r_cm.set_acc(frame_n, c2*frame_a.y) +v_a = _me.cross(body_r_cm.acc(frame_n), particle_p1.point.vel(frame_a)) +x_b_c = v_a +x_b_d = 2*x_b_c +a_b_c_d_e = x_b_d*2 +a_b_c = 2*c1*c2*c3 +a_b_c += 2*c1 +a_b_c = 3*c1 +q1, q2, u1, u2 = _me.dynamicsymbols('q1 q2 u1 u2') +q1_d, q2_d, u1_d, u2_d = _me.dynamicsymbols('q1_ q2_ u1_ u2_', 1) +x, y = _me.dynamicsymbols('x y') +x_d, y_d = _me.dynamicsymbols('x_ y_', 1) +x_dd, y_dd = _me.dynamicsymbols('x_ y_', 2) +yy = _me.dynamicsymbols('yy') +yy = x*x_d**2+1 +m = _sm.Matrix([[0]]) +m[0] = 2*x +m = m.row_insert(m.shape[0], _sm.Matrix([[0]])) +m[m.shape[0]-1] = 2*y +a = 2*m[0] +m = _sm.Matrix([1,2,3,4,5,6,7,8,9]).reshape(3, 3) +m[0,1] = 5 +a = m[0, 1]*2 +force_ro = q1*frame_n.x +torque_a = q2*frame_n.z +force_ro = q1*frame_n.x + q2*frame_n.y +f = force_ro*2 diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/c/__init__.py b/.venv/lib/python3.13/site-packages/sympy/parsing/c/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18d3d5301cb001c78fc4a9bc04b25aa36f282a93 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/c/__init__.py @@ -0,0 +1 @@ +"""Used for translating C source code into a SymPy expression""" diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/c/c_parser.py b/.venv/lib/python3.13/site-packages/sympy/parsing/c/c_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..9e7223f8351205272e803773589649fcf1902f15 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/c/c_parser.py @@ -0,0 +1,1059 @@ +from sympy.external import import_module +import os + +cin = import_module('clang.cindex', import_kwargs = {'fromlist': ['cindex']}) + +""" +This module contains all the necessary Classes and Function used to Parse C and +C++ code into SymPy expression +The module serves as a backend for SymPyExpression to parse C code +It is also dependent on Clang's AST and SymPy's Codegen AST. +The module only supports the features currently supported by the Clang and +codegen AST which will be updated as the development of codegen AST and this +module progresses. +You might find unexpected bugs and exceptions while using the module, feel free +to report them to the SymPy Issue Tracker + +Features Supported +================== + +- Variable Declarations (integers and reals) +- Assignment (using integer & floating literal and function calls) +- Function Definitions and Declaration +- Function Calls +- Compound statements, Return statements + +Notes +===== + +The module is dependent on an external dependency which needs to be installed +to use the features of this module. + +Clang: The C and C++ compiler which is used to extract an AST from the provided +C source code. + +References +========== + +.. [1] https://github.com/sympy/sympy/issues +.. [2] https://clang.llvm.org/docs/ +.. [3] https://clang.llvm.org/docs/IntroductionToTheClangAST.html + +""" + +if cin: + from sympy.codegen.ast import (Variable, Integer, Float, + FunctionPrototype, FunctionDefinition, FunctionCall, + none, Return, Assignment, intc, int8, int16, int64, + uint8, uint16, uint32, uint64, float32, float64, float80, + aug_assign, bool_, While, CodeBlock) + from sympy.codegen.cnodes import (PreDecrement, PostDecrement, + PreIncrement, PostIncrement) + from sympy.core import Add, Mod, Mul, Pow, Rel + from sympy.logic.boolalg import And, as_Boolean, Not, Or + from sympy.core.symbol import Symbol + from sympy.core.sympify import sympify + from sympy.logic.boolalg import (false, true) + import sys + import tempfile + + class BaseParser: + """Base Class for the C parser""" + + def __init__(self): + """Initializes the Base parser creating a Clang AST index""" + self.index = cin.Index.create() + + def diagnostics(self, out): + """Diagostics function for the Clang AST""" + for diag in self.tu.diagnostics: + # tu = translation unit + print('%s %s (line %s, col %s) %s' % ( + { + 4: 'FATAL', + 3: 'ERROR', + 2: 'WARNING', + 1: 'NOTE', + 0: 'IGNORED', + }[diag.severity], + diag.location.file, + diag.location.line, + diag.location.column, + diag.spelling + ), file=out) + + class CCodeConverter(BaseParser): + """The Code Convereter for Clang AST + + The converter object takes the C source code or file as input and + converts them to SymPy Expressions. + """ + + def __init__(self): + """Initializes the code converter""" + super().__init__() + self._py_nodes = [] + self._data_types = { + "void": { + cin.TypeKind.VOID: none + }, + "bool": { + cin.TypeKind.BOOL: bool_ + }, + "int": { + cin.TypeKind.SCHAR: int8, + cin.TypeKind.SHORT: int16, + cin.TypeKind.INT: intc, + cin.TypeKind.LONG: int64, + cin.TypeKind.UCHAR: uint8, + cin.TypeKind.USHORT: uint16, + cin.TypeKind.UINT: uint32, + cin.TypeKind.ULONG: uint64 + }, + "float": { + cin.TypeKind.FLOAT: float32, + cin.TypeKind.DOUBLE: float64, + cin.TypeKind.LONGDOUBLE: float80 + } + } + + def parse(self, filename, flags): + """Function to parse a file with C source code + + It takes the filename as an attribute and creates a Clang AST + Translation Unit parsing the file. + Then the transformation function is called on the translation unit, + whose results are collected into a list which is returned by the + function. + + Parameters + ========== + + filename : string + Path to the C file to be parsed + + flags: list + Arguments to be passed to Clang while parsing the C code + + Returns + ======= + + py_nodes: list + A list of SymPy AST nodes + + """ + filepath = os.path.abspath(filename) + self.tu = self.index.parse( + filepath, + args=flags, + options=cin.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD + ) + for child in self.tu.cursor.get_children(): + if child.kind == cin.CursorKind.VAR_DECL or child.kind == cin.CursorKind.FUNCTION_DECL: + self._py_nodes.append(self.transform(child)) + return self._py_nodes + + def parse_str(self, source, flags): + """Function to parse a string with C source code + + It takes the source code as an attribute, stores it in a temporary + file and creates a Clang AST Translation Unit parsing the file. + Then the transformation function is called on the translation unit, + whose results are collected into a list which is returned by the + function. + + Parameters + ========== + + source : string + A string containing the C source code to be parsed + + flags: list + Arguments to be passed to Clang while parsing the C code + + Returns + ======= + + py_nodes: list + A list of SymPy AST nodes + + """ + file = tempfile.NamedTemporaryFile(mode = 'w+', suffix = '.cpp') + file.write(source) + file.flush() + file.seek(0) + self.tu = self.index.parse( + file.name, + args=flags, + options=cin.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD + ) + file.close() + for child in self.tu.cursor.get_children(): + if child.kind == cin.CursorKind.VAR_DECL or child.kind == cin.CursorKind.FUNCTION_DECL: + self._py_nodes.append(self.transform(child)) + return self._py_nodes + + def transform(self, node): + """Transformation Function for Clang AST nodes + + It determines the kind of node and calls the respective + transformation function for that node. + + Raises + ====== + + NotImplementedError : if the transformation for the provided node + is not implemented + + """ + handler = getattr(self, 'transform_%s' % node.kind.name.lower(), None) + + if handler is None: + print( + "Ignoring node of type %s (%s)" % ( + node.kind, + ' '.join( + t.spelling for t in node.get_tokens()) + ), + file=sys.stderr + ) + + return handler(node) + + def transform_var_decl(self, node): + """Transformation Function for Variable Declaration + + Used to create nodes for variable declarations and assignments with + values or function call for the respective nodes in the clang AST + + Returns + ======= + + A variable node as Declaration, with the initial value if given + + Raises + ====== + + NotImplementedError : if called for data types not currently + implemented + + Notes + ===== + + The function currently supports following data types: + + Boolean: + bool, _Bool + + Integer: + 8-bit: signed char and unsigned char + 16-bit: short, short int, signed short, + signed short int, unsigned short, unsigned short int + 32-bit: int, signed int, unsigned int + 64-bit: long, long int, signed long, + signed long int, unsigned long, unsigned long int + + Floating point: + Single Precision: float + Double Precision: double + Extended Precision: long double + + """ + if node.type.kind in self._data_types["int"]: + type = self._data_types["int"][node.type.kind] + elif node.type.kind in self._data_types["float"]: + type = self._data_types["float"][node.type.kind] + elif node.type.kind in self._data_types["bool"]: + type = self._data_types["bool"][node.type.kind] + else: + raise NotImplementedError("Only bool, int " + "and float are supported") + try: + children = node.get_children() + child = next(children) + + #ignoring namespace and type details for the variable + while child.kind == cin.CursorKind.NAMESPACE_REF or child.kind == cin.CursorKind.TYPE_REF: + child = next(children) + + val = self.transform(child) + + supported_rhs = [ + cin.CursorKind.INTEGER_LITERAL, + cin.CursorKind.FLOATING_LITERAL, + cin.CursorKind.UNEXPOSED_EXPR, + cin.CursorKind.BINARY_OPERATOR, + cin.CursorKind.PAREN_EXPR, + cin.CursorKind.UNARY_OPERATOR, + cin.CursorKind.CXX_BOOL_LITERAL_EXPR + ] + + if child.kind in supported_rhs: + if isinstance(val, str): + value = Symbol(val) + elif isinstance(val, bool): + if node.type.kind in self._data_types["int"]: + value = Integer(0) if val == False else Integer(1) + elif node.type.kind in self._data_types["float"]: + value = Float(0.0) if val == False else Float(1.0) + elif node.type.kind in self._data_types["bool"]: + value = sympify(val) + elif isinstance(val, (Integer, int, Float, float)): + if node.type.kind in self._data_types["int"]: + value = Integer(val) + elif node.type.kind in self._data_types["float"]: + value = Float(val) + elif node.type.kind in self._data_types["bool"]: + value = sympify(bool(val)) + else: + value = val + + return Variable( + node.spelling + ).as_Declaration( + type = type, + value = value + ) + + elif child.kind == cin.CursorKind.CALL_EXPR: + return Variable( + node.spelling + ).as_Declaration( + value = val + ) + + else: + raise NotImplementedError("Given " + "variable declaration \"{}\" " + "is not possible to parse yet!" + .format(" ".join( + t.spelling for t in node.get_tokens() + ) + )) + + except StopIteration: + return Variable( + node.spelling + ).as_Declaration( + type = type + ) + + def transform_function_decl(self, node): + """Transformation Function For Function Declaration + + Used to create nodes for function declarations and definitions for + the respective nodes in the clang AST + + Returns + ======= + + function : Codegen AST node + - FunctionPrototype node if function body is not present + - FunctionDefinition node if the function body is present + + + """ + + if node.result_type.kind in self._data_types["int"]: + ret_type = self._data_types["int"][node.result_type.kind] + elif node.result_type.kind in self._data_types["float"]: + ret_type = self._data_types["float"][node.result_type.kind] + elif node.result_type.kind in self._data_types["bool"]: + ret_type = self._data_types["bool"][node.result_type.kind] + elif node.result_type.kind in self._data_types["void"]: + ret_type = self._data_types["void"][node.result_type.kind] + else: + raise NotImplementedError("Only void, bool, int " + "and float are supported") + body = [] + param = [] + + # Subsequent nodes will be the parameters for the function. + for child in node.get_children(): + decl = self.transform(child) + if child.kind == cin.CursorKind.PARM_DECL: + param.append(decl) + elif child.kind == cin.CursorKind.COMPOUND_STMT: + for val in decl: + body.append(val) + else: + body.append(decl) + + if body == []: + function = FunctionPrototype( + return_type = ret_type, + name = node.spelling, + parameters = param + ) + else: + function = FunctionDefinition( + return_type = ret_type, + name = node.spelling, + parameters = param, + body = body + ) + return function + + def transform_parm_decl(self, node): + """Transformation function for Parameter Declaration + + Used to create parameter nodes for the required functions for the + respective nodes in the clang AST + + Returns + ======= + + param : Codegen AST Node + Variable node with the value and type of the variable + + Raises + ====== + + ValueError if multiple children encountered in the parameter node + + """ + if node.type.kind in self._data_types["int"]: + type = self._data_types["int"][node.type.kind] + elif node.type.kind in self._data_types["float"]: + type = self._data_types["float"][node.type.kind] + elif node.type.kind in self._data_types["bool"]: + type = self._data_types["bool"][node.type.kind] + else: + raise NotImplementedError("Only bool, int " + "and float are supported") + try: + children = node.get_children() + child = next(children) + + # Any namespace nodes can be ignored + while child.kind in [cin.CursorKind.NAMESPACE_REF, + cin.CursorKind.TYPE_REF, + cin.CursorKind.TEMPLATE_REF]: + child = next(children) + + # If there is a child, it is the default value of the parameter. + lit = self.transform(child) + if node.type.kind in self._data_types["int"]: + val = Integer(lit) + elif node.type.kind in self._data_types["float"]: + val = Float(lit) + elif node.type.kind in self._data_types["bool"]: + val = sympify(bool(lit)) + else: + raise NotImplementedError("Only bool, int " + "and float are supported") + + param = Variable( + node.spelling + ).as_Declaration( + type = type, + value = val + ) + except StopIteration: + param = Variable( + node.spelling + ).as_Declaration( + type = type + ) + + try: + self.transform(next(children)) + raise ValueError("Can't handle multiple children on parameter") + except StopIteration: + pass + + return param + + def transform_integer_literal(self, node): + """Transformation function for integer literal + + Used to get the value and type of the given integer literal. + + Returns + ======= + + val : list + List with two arguments type and Value + type contains the type of the integer + value contains the value stored in the variable + + Notes + ===== + + Only Base Integer type supported for now + + """ + try: + value = next(node.get_tokens()).spelling + except StopIteration: + # No tokens + value = node.literal + return int(value) + + def transform_floating_literal(self, node): + """Transformation function for floating literal + + Used to get the value and type of the given floating literal. + + Returns + ======= + + val : list + List with two arguments type and Value + type contains the type of float + value contains the value stored in the variable + + Notes + ===== + + Only Base Float type supported for now + + """ + try: + value = next(node.get_tokens()).spelling + except (StopIteration, ValueError): + # No tokens + value = node.literal + return float(value) + + def transform_string_literal(self, node): + #TODO: No string type in AST + #type = + #try: + # value = next(node.get_tokens()).spelling + #except (StopIteration, ValueError): + # No tokens + # value = node.literal + #val = [type, value] + #return val + pass + + def transform_character_literal(self, node): + """Transformation function for character literal + + Used to get the value of the given character literal. + + Returns + ======= + + val : int + val contains the ascii value of the character literal + + Notes + ===== + + Only for cases where character is assigned to a integer value, + since character literal is not in SymPy AST + + """ + try: + value = next(node.get_tokens()).spelling + except (StopIteration, ValueError): + # No tokens + value = node.literal + return ord(str(value[1])) + + def transform_cxx_bool_literal_expr(self, node): + """Transformation function for boolean literal + + Used to get the value of the given boolean literal. + + Returns + ======= + + value : bool + value contains the boolean value of the variable + + """ + try: + value = next(node.get_tokens()).spelling + except (StopIteration, ValueError): + value = node.literal + return True if value == 'true' else False + + def transform_unexposed_decl(self,node): + """Transformation function for unexposed declarations""" + pass + + def transform_unexposed_expr(self, node): + """Transformation function for unexposed expression + + Unexposed expressions are used to wrap float, double literals and + expressions + + Returns + ======= + + expr : Codegen AST Node + the result from the wrapped expression + + None : NoneType + No children are found for the node + + Raises + ====== + + ValueError if the expression contains multiple children + + """ + # Ignore unexposed nodes; pass whatever is the first + # (and should be only) child unaltered. + try: + children = node.get_children() + expr = self.transform(next(children)) + except StopIteration: + return None + + try: + next(children) + raise ValueError("Unexposed expression has > 1 children.") + except StopIteration: + pass + + return expr + + def transform_decl_ref_expr(self, node): + """Returns the name of the declaration reference""" + return node.spelling + + def transform_call_expr(self, node): + """Transformation function for a call expression + + Used to create function call nodes for the function calls present + in the C code + + Returns + ======= + + FunctionCall : Codegen AST Node + FunctionCall node with parameters if any parameters are present + + """ + param = [] + children = node.get_children() + child = next(children) + + while child.kind == cin.CursorKind.NAMESPACE_REF: + child = next(children) + while child.kind == cin.CursorKind.TYPE_REF: + child = next(children) + + first_child = self.transform(child) + try: + for child in children: + arg = self.transform(child) + if child.kind == cin.CursorKind.INTEGER_LITERAL: + param.append(Integer(arg)) + elif child.kind == cin.CursorKind.FLOATING_LITERAL: + param.append(Float(arg)) + else: + param.append(arg) + return FunctionCall(first_child, param) + + except StopIteration: + return FunctionCall(first_child) + + def transform_return_stmt(self, node): + """Returns the Return Node for a return statement""" + return Return(next(node.get_children()).spelling) + + def transform_compound_stmt(self, node): + """Transformation function for compound statements + + Returns + ======= + + expr : list + list of Nodes for the expressions present in the statement + + None : NoneType + if the compound statement is empty + + """ + expr = [] + children = node.get_children() + + for child in children: + expr.append(self.transform(child)) + return expr + + def transform_decl_stmt(self, node): + """Transformation function for declaration statements + + These statements are used to wrap different kinds of declararions + like variable or function declaration + The function calls the transformer function for the child of the + given node + + Returns + ======= + + statement : Codegen AST Node + contains the node returned by the children node for the type of + declaration + + Raises + ====== + + ValueError if multiple children present + + """ + try: + children = node.get_children() + statement = self.transform(next(children)) + except StopIteration: + pass + + try: + self.transform(next(children)) + raise ValueError("Don't know how to handle multiple statements") + except StopIteration: + pass + + return statement + + def transform_paren_expr(self, node): + """Transformation function for Parenthesized expressions + + Returns the result from its children nodes + + """ + return self.transform(next(node.get_children())) + + def transform_compound_assignment_operator(self, node): + """Transformation function for handling shorthand operators + + Returns + ======= + + augmented_assignment_expression: Codegen AST node + shorthand assignment expression represented as Codegen AST + + Raises + ====== + + NotImplementedError + If the shorthand operator for bitwise operators + (~=, ^=, &=, |=, <<=, >>=) is encountered + + """ + return self.transform_binary_operator(node) + + def transform_unary_operator(self, node): + """Transformation function for handling unary operators + + Returns + ======= + + unary_expression: Codegen AST node + simplified unary expression represented as Codegen AST + + Raises + ====== + + NotImplementedError + If dereferencing operator(*), address operator(&) or + bitwise NOT operator(~) is encountered + + """ + # supported operators list + operators_list = ['+', '-', '++', '--', '!'] + tokens = list(node.get_tokens()) + + # it can be either pre increment/decrement or any other operator from the list + if tokens[0].spelling in operators_list: + child = self.transform(next(node.get_children())) + # (decl_ref) e.g.; int a = ++b; or simply ++b; + if isinstance(child, str): + if tokens[0].spelling == '+': + return Symbol(child) + if tokens[0].spelling == '-': + return Mul(Symbol(child), -1) + if tokens[0].spelling == '++': + return PreIncrement(Symbol(child)) + if tokens[0].spelling == '--': + return PreDecrement(Symbol(child)) + if tokens[0].spelling == '!': + return Not(Symbol(child)) + # e.g.; int a = -1; or int b = -(1 + 2); + else: + if tokens[0].spelling == '+': + return child + if tokens[0].spelling == '-': + return Mul(child, -1) + if tokens[0].spelling == '!': + return Not(sympify(bool(child))) + + # it can be either post increment/decrement + # since variable name is obtained in token[0].spelling + elif tokens[1].spelling in ['++', '--']: + child = self.transform(next(node.get_children())) + if tokens[1].spelling == '++': + return PostIncrement(Symbol(child)) + if tokens[1].spelling == '--': + return PostDecrement(Symbol(child)) + else: + raise NotImplementedError("Dereferencing operator, " + "Address operator and bitwise NOT operator " + "have not been implemented yet!") + + def transform_binary_operator(self, node): + """Transformation function for handling binary operators + + Returns + ======= + + binary_expression: Codegen AST node + simplified binary expression represented as Codegen AST + + Raises + ====== + + NotImplementedError + If a bitwise operator or + unary operator(which is a child of any binary + operator in Clang AST) is encountered + + """ + # get all the tokens of assignment + # and store it in the tokens list + tokens = list(node.get_tokens()) + + # supported operators list + operators_list = ['+', '-', '*', '/', '%','=', + '>', '>=', '<', '<=', '==', '!=', '&&', '||', '+=', '-=', + '*=', '/=', '%='] + + # this stack will contain variable content + # and type of variable in the rhs + combined_variables_stack = [] + + # this stack will contain operators + # to be processed in the rhs + operators_stack = [] + + # iterate through every token + for token in tokens: + # token is either '(', ')' or + # any of the supported operators from the operator list + if token.kind == cin.TokenKind.PUNCTUATION: + + # push '(' to the operators stack + if token.spelling == '(': + operators_stack.append('(') + + elif token.spelling == ')': + # keep adding the expression to the + # combined variables stack unless + # '(' is found + while (operators_stack + and operators_stack[-1] != '('): + if len(combined_variables_stack) < 2: + raise NotImplementedError( + "Unary operators as a part of " + "binary operators is not " + "supported yet!") + rhs = combined_variables_stack.pop() + lhs = combined_variables_stack.pop() + operator = operators_stack.pop() + combined_variables_stack.append( + self.perform_operation( + lhs, rhs, operator)) + + # pop '(' + operators_stack.pop() + + # token is an operator (supported) + elif token.spelling in operators_list: + while (operators_stack + and self.priority_of(token.spelling) + <= self.priority_of( + operators_stack[-1])): + if len(combined_variables_stack) < 2: + raise NotImplementedError( + "Unary operators as a part of " + "binary operators is not " + "supported yet!") + rhs = combined_variables_stack.pop() + lhs = combined_variables_stack.pop() + operator = operators_stack.pop() + combined_variables_stack.append( + self.perform_operation( + lhs, rhs, operator)) + + # push current operator + operators_stack.append(token.spelling) + + # token is a bitwise operator + elif token.spelling in ['&', '|', '^', '<<', '>>']: + raise NotImplementedError( + "Bitwise operator has not been " + "implemented yet!") + + # token is a shorthand bitwise operator + elif token.spelling in ['&=', '|=', '^=', '<<=', + '>>=']: + raise NotImplementedError( + "Shorthand bitwise operator has not been " + "implemented yet!") + else: + raise NotImplementedError( + "Given token {} is not implemented yet!" + .format(token.spelling)) + + # token is an identifier(variable) + elif token.kind == cin.TokenKind.IDENTIFIER: + combined_variables_stack.append( + [token.spelling, 'identifier']) + + # token is a literal + elif token.kind == cin.TokenKind.LITERAL: + combined_variables_stack.append( + [token.spelling, 'literal']) + + # token is a keyword, either true or false + elif (token.kind == cin.TokenKind.KEYWORD + and token.spelling in ['true', 'false']): + combined_variables_stack.append( + [token.spelling, 'boolean']) + else: + raise NotImplementedError( + "Given token {} is not implemented yet!" + .format(token.spelling)) + + # process remaining operators + while operators_stack: + if len(combined_variables_stack) < 2: + raise NotImplementedError( + "Unary operators as a part of " + "binary operators is not " + "supported yet!") + rhs = combined_variables_stack.pop() + lhs = combined_variables_stack.pop() + operator = operators_stack.pop() + combined_variables_stack.append( + self.perform_operation(lhs, rhs, operator)) + + return combined_variables_stack[-1][0] + + def priority_of(self, op): + """To get the priority of given operator""" + if op in ['=', '+=', '-=', '*=', '/=', '%=']: + return 1 + if op in ['&&', '||']: + return 2 + if op in ['<', '<=', '>', '>=', '==', '!=']: + return 3 + if op in ['+', '-']: + return 4 + if op in ['*', '/', '%']: + return 5 + return 0 + + def perform_operation(self, lhs, rhs, op): + """Performs operation supported by the SymPy core + + Returns + ======= + + combined_variable: list + contains variable content and type of variable + + """ + lhs_value = self.get_expr_for_operand(lhs) + rhs_value = self.get_expr_for_operand(rhs) + if op == '+': + return [Add(lhs_value, rhs_value), 'expr'] + if op == '-': + return [Add(lhs_value, -rhs_value), 'expr'] + if op == '*': + return [Mul(lhs_value, rhs_value), 'expr'] + if op == '/': + return [Mul(lhs_value, Pow(rhs_value, Integer(-1))), 'expr'] + if op == '%': + return [Mod(lhs_value, rhs_value), 'expr'] + if op in ['<', '<=', '>', '>=', '==', '!=']: + return [Rel(lhs_value, rhs_value, op), 'expr'] + if op == '&&': + return [And(as_Boolean(lhs_value), as_Boolean(rhs_value)), 'expr'] + if op == '||': + return [Or(as_Boolean(lhs_value), as_Boolean(rhs_value)), 'expr'] + if op == '=': + return [Assignment(Variable(lhs_value), rhs_value), 'expr'] + if op in ['+=', '-=', '*=', '/=', '%=']: + return [aug_assign(Variable(lhs_value), op[0], rhs_value), 'expr'] + + def get_expr_for_operand(self, combined_variable): + """Gives out SymPy Codegen AST node + + AST node returned is corresponding to + combined variable passed.Combined variable contains + variable content and type of variable + + """ + if combined_variable[1] == 'identifier': + return Symbol(combined_variable[0]) + if combined_variable[1] == 'literal': + if '.' in combined_variable[0]: + return Float(float(combined_variable[0])) + else: + return Integer(int(combined_variable[0])) + if combined_variable[1] == 'expr': + return combined_variable[0] + if combined_variable[1] == 'boolean': + return true if combined_variable[0] == 'true' else false + + def transform_null_stmt(self, node): + """Handles Null Statement and returns None""" + return none + + def transform_while_stmt(self, node): + """Transformation function for handling while statement + + Returns + ======= + + while statement : Codegen AST Node + contains the while statement node having condition and + statement block + + """ + children = node.get_children() + + condition = self.transform(next(children)) + statements = self.transform(next(children)) + + if isinstance(statements, list): + statement_block = CodeBlock(*statements) + else: + statement_block = CodeBlock(statements) + + return While(condition, statement_block) + + + +else: + class CCodeConverter(): # type: ignore + def __init__(self, *args, **kwargs): + raise ImportError("Module not Installed") + + +def parse_c(source): + """Function for converting a C source code + + The function reads the source code present in the given file and parses it + to give out SymPy Expressions + + Returns + ======= + + src : list + List of Python expression strings + + """ + converter = CCodeConverter() + if os.path.exists(source): + src = converter.parse(source, flags = []) + else: + src = converter.parse_str(source, flags = []) + return src diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/fortran/__init__.py b/.venv/lib/python3.13/site-packages/sympy/parsing/fortran/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c65e37cf3de2dddbcee0fa5c7eeac2fdc9f685db --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/fortran/__init__.py @@ -0,0 +1 @@ +"""Used for translating Fortran source code into a SymPy expression. """ diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/fortran/fortran_parser.py b/.venv/lib/python3.13/site-packages/sympy/parsing/fortran/fortran_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..504249f6119a59a90d91c5e989f893cffe20e643 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/fortran/fortran_parser.py @@ -0,0 +1,347 @@ +from sympy.external import import_module + +lfortran = import_module('lfortran') + +if lfortran: + from sympy.codegen.ast import (Variable, IntBaseType, FloatBaseType, String, + Return, FunctionDefinition, Assignment) + from sympy.core import Add, Mul, Integer, Float + from sympy.core.symbol import Symbol + + asr_mod = lfortran.asr + asr = lfortran.asr.asr + src_to_ast = lfortran.ast.src_to_ast + ast_to_asr = lfortran.semantic.ast_to_asr.ast_to_asr + + """ + This module contains all the necessary Classes and Function used to Parse + Fortran code into SymPy expression + + The module and its API are currently under development and experimental. + It is also dependent on LFortran for the ASR that is converted to SymPy syntax + which is also under development. + The module only supports the features currently supported by the LFortran ASR + which will be updated as the development of LFortran and this module progresses + + You might find unexpected bugs and exceptions while using the module, feel free + to report them to the SymPy Issue Tracker + + The API for the module might also change while in development if better and + more effective ways are discovered for the process + + Features Supported + ================== + + - Variable Declarations (integers and reals) + - Function Definitions + - Assignments and Basic Binary Operations + + + Notes + ===== + + The module depends on an external dependency + + LFortran : Required to parse Fortran source code into ASR + + + References + ========== + + .. [1] https://github.com/sympy/sympy/issues + .. [2] https://gitlab.com/lfortran/lfortran + .. [3] https://docs.lfortran.org/ + + """ + + + class ASR2PyVisitor(asr.ASTVisitor): # type: ignore + """ + Visitor Class for LFortran ASR + + It is a Visitor class derived from asr.ASRVisitor which visits all the + nodes of the LFortran ASR and creates corresponding AST node for each + ASR node + + """ + + def __init__(self): + """Initialize the Parser""" + self._py_ast = [] + + def visit_TranslationUnit(self, node): + """ + Function to visit all the elements of the Translation Unit + created by LFortran ASR + """ + for s in node.global_scope.symbols: + sym = node.global_scope.symbols[s] + self.visit(sym) + for item in node.items: + self.visit(item) + + def visit_Assignment(self, node): + """Visitor Function for Assignment + + Visits each Assignment is the LFortran ASR and creates corresponding + assignment for SymPy. + + Notes + ===== + + The function currently only supports variable assignment and binary + operation assignments of varying multitudes. Any type of numberS or + array is not supported. + + Raises + ====== + + NotImplementedError() when called for Numeric assignments or Arrays + + """ + # TODO: Arithmetic Assignment + if isinstance(node.target, asr.Variable): + target = node.target + value = node.value + if isinstance(value, asr.Variable): + new_node = Assignment( + Variable( + target.name + ), + Variable( + value.name + ) + ) + elif (type(value) == asr.BinOp): + exp_ast = call_visitor(value) + for expr in exp_ast: + new_node = Assignment( + Variable(target.name), + expr + ) + else: + raise NotImplementedError("Numeric assignments not supported") + else: + raise NotImplementedError("Arrays not supported") + self._py_ast.append(new_node) + + def visit_BinOp(self, node): + """Visitor Function for Binary Operations + + Visits each binary operation present in the LFortran ASR like addition, + subtraction, multiplication, division and creates the corresponding + operation node in SymPy's AST + + In case of more than one binary operations, the function calls the + call_visitor() function on the child nodes of the binary operations + recursively until all the operations have been processed. + + Notes + ===== + + The function currently only supports binary operations with Variables + or other binary operations. Numerics are not supported as of yet. + + Raises + ====== + + NotImplementedError() when called for Numeric assignments + + """ + # TODO: Integer Binary Operations + op = node.op + lhs = node.left + rhs = node.right + + if (type(lhs) == asr.Variable): + left_value = Symbol(lhs.name) + elif(type(lhs) == asr.BinOp): + l_exp_ast = call_visitor(lhs) + for exp in l_exp_ast: + left_value = exp + else: + raise NotImplementedError("Numbers Currently not supported") + + if (type(rhs) == asr.Variable): + right_value = Symbol(rhs.name) + elif(type(rhs) == asr.BinOp): + r_exp_ast = call_visitor(rhs) + for exp in r_exp_ast: + right_value = exp + else: + raise NotImplementedError("Numbers Currently not supported") + + if isinstance(op, asr.Add): + new_node = Add(left_value, right_value) + elif isinstance(op, asr.Sub): + new_node = Add(left_value, -right_value) + elif isinstance(op, asr.Div): + new_node = Mul(left_value, 1/right_value) + elif isinstance(op, asr.Mul): + new_node = Mul(left_value, right_value) + + self._py_ast.append(new_node) + + def visit_Variable(self, node): + """Visitor Function for Variable Declaration + + Visits each variable declaration present in the ASR and creates a + Symbol declaration for each variable + + Notes + ===== + + The functions currently only support declaration of integer and + real variables. Other data types are still under development. + + Raises + ====== + + NotImplementedError() when called for unsupported data types + + """ + if isinstance(node.type, asr.Integer): + var_type = IntBaseType(String('integer')) + value = Integer(0) + elif isinstance(node.type, asr.Real): + var_type = FloatBaseType(String('real')) + value = Float(0.0) + else: + raise NotImplementedError("Data type not supported") + + if not (node.intent == 'in'): + new_node = Variable( + node.name + ).as_Declaration( + type = var_type, + value = value + ) + self._py_ast.append(new_node) + + def visit_Sequence(self, seq): + """Visitor Function for code sequence + + Visits a code sequence/ block and calls the visitor function on all the + children of the code block to create corresponding code in python + + """ + if seq is not None: + for node in seq: + self._py_ast.append(call_visitor(node)) + + def visit_Num(self, node): + """Visitor Function for Numbers in ASR + + This function is currently under development and will be updated + with improvements in the LFortran ASR + + """ + # TODO:Numbers when the LFortran ASR is updated + # self._py_ast.append(Integer(node.n)) + pass + + def visit_Function(self, node): + """Visitor Function for function Definitions + + Visits each function definition present in the ASR and creates a + function definition node in the Python AST with all the elements of the + given function + + The functions declare all the variables required as SymPy symbols in + the function before the function definition + + This function also the call_visior_function to parse the contents of + the function body + + """ + # TODO: Return statement, variable declaration + fn_args = [Variable(arg_iter.name) for arg_iter in node.args] + fn_body = [] + fn_name = node.name + for i in node.body: + fn_ast = call_visitor(i) + try: + fn_body_expr = fn_ast + except UnboundLocalError: + fn_body_expr = [] + for sym in node.symtab.symbols: + decl = call_visitor(node.symtab.symbols[sym]) + for symbols in decl: + fn_body.append(symbols) + for elem in fn_body_expr: + fn_body.append(elem) + fn_body.append( + Return( + Variable( + node.return_var.name + ) + ) + ) + if isinstance(node.return_var.type, asr.Integer): + ret_type = IntBaseType(String('integer')) + elif isinstance(node.return_var.type, asr.Real): + ret_type = FloatBaseType(String('real')) + else: + raise NotImplementedError("Data type not supported") + new_node = FunctionDefinition( + return_type = ret_type, + name = fn_name, + parameters = fn_args, + body = fn_body + ) + self._py_ast.append(new_node) + + def ret_ast(self): + """Returns the AST nodes""" + return self._py_ast +else: + class ASR2PyVisitor(): # type: ignore + def __init__(self, *args, **kwargs): + raise ImportError('lfortran not available') + +def call_visitor(fort_node): + """Calls the AST Visitor on the Module + + This function is used to call the AST visitor for a program or module + It imports all the required modules and calls the visit() function + on the given node + + Parameters + ========== + + fort_node : LFortran ASR object + Node for the operation for which the NodeVisitor is called + + Returns + ======= + + res_ast : list + list of SymPy AST Nodes + + """ + v = ASR2PyVisitor() + v.visit(fort_node) + res_ast = v.ret_ast() + return res_ast + + +def src_to_sympy(src): + """Wrapper function to convert the given Fortran source code to SymPy Expressions + + Parameters + ========== + + src : string + A string with the Fortran source code + + Returns + ======= + + py_src : string + A string with the Python source code compatible with SymPy + + """ + a_ast = src_to_ast(src, translation_unit=False) + a = ast_to_asr(a_ast) + py_src = call_visitor(a) + return py_src diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/latex/LICENSE.txt b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..6bbfda911b2afada41a568218e31a6502dc68f44 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/LICENSE.txt @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright 2016, latex2sympy + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/latex/LaTeX.g4 b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/LaTeX.g4 new file mode 100644 index 0000000000000000000000000000000000000000..fc2c30f9817931e2060b549a39f98a6a4f9cb1f7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/LaTeX.g4 @@ -0,0 +1,312 @@ +/* + ANTLR4 LaTeX Math Grammar + + Ported from latex2sympy by @augustt198 https://github.com/augustt198/latex2sympy See license in + LICENSE.txt + */ + +/* + After changing this file, it is necessary to run `python setup.py antlr` in the root directory of + the repository. This will regenerate the code in `sympy/parsing/latex/_antlr/*.py`. + */ + +grammar LaTeX; + +options { + language = Python3; +} + +WS: [ \t\r\n]+ -> skip; +THINSPACE: ('\\,' | '\\thinspace') -> skip; +MEDSPACE: ('\\:' | '\\medspace') -> skip; +THICKSPACE: ('\\;' | '\\thickspace') -> skip; +QUAD: '\\quad' -> skip; +QQUAD: '\\qquad' -> skip; +NEGTHINSPACE: ('\\!' | '\\negthinspace') -> skip; +NEGMEDSPACE: '\\negmedspace' -> skip; +NEGTHICKSPACE: '\\negthickspace' -> skip; +CMD_LEFT: '\\left' -> skip; +CMD_RIGHT: '\\right' -> skip; + +IGNORE: + ( + '\\vrule' + | '\\vcenter' + | '\\vbox' + | '\\vskip' + | '\\vspace' + | '\\hfil' + | '\\*' + | '\\-' + | '\\.' + | '\\/' + | '\\"' + | '\\(' + | '\\=' + ) -> skip; + +ADD: '+'; +SUB: '-'; +MUL: '*'; +DIV: '/'; + +L_PAREN: '('; +R_PAREN: ')'; +L_BRACE: '{'; +R_BRACE: '}'; +L_BRACE_LITERAL: '\\{'; +R_BRACE_LITERAL: '\\}'; +L_BRACKET: '['; +R_BRACKET: ']'; + +BAR: '|'; + +R_BAR: '\\right|'; +L_BAR: '\\left|'; + +L_ANGLE: '\\langle'; +R_ANGLE: '\\rangle'; +FUNC_LIM: '\\lim'; +LIM_APPROACH_SYM: + '\\to' + | '\\rightarrow' + | '\\Rightarrow' + | '\\longrightarrow' + | '\\Longrightarrow'; +FUNC_INT: + '\\int' + | '\\int\\limits'; +FUNC_SUM: '\\sum'; +FUNC_PROD: '\\prod'; + +FUNC_EXP: '\\exp'; +FUNC_LOG: '\\log'; +FUNC_LG: '\\lg'; +FUNC_LN: '\\ln'; +FUNC_SIN: '\\sin'; +FUNC_COS: '\\cos'; +FUNC_TAN: '\\tan'; +FUNC_CSC: '\\csc'; +FUNC_SEC: '\\sec'; +FUNC_COT: '\\cot'; + +FUNC_ARCSIN: '\\arcsin'; +FUNC_ARCCOS: '\\arccos'; +FUNC_ARCTAN: '\\arctan'; +FUNC_ARCCSC: '\\arccsc'; +FUNC_ARCSEC: '\\arcsec'; +FUNC_ARCCOT: '\\arccot'; + +FUNC_SINH: '\\sinh'; +FUNC_COSH: '\\cosh'; +FUNC_TANH: '\\tanh'; +FUNC_ARSINH: '\\arsinh'; +FUNC_ARCOSH: '\\arcosh'; +FUNC_ARTANH: '\\artanh'; + +L_FLOOR: '\\lfloor'; +R_FLOOR: '\\rfloor'; +L_CEIL: '\\lceil'; +R_CEIL: '\\rceil'; + +FUNC_SQRT: '\\sqrt'; +FUNC_OVERLINE: '\\overline'; + +CMD_TIMES: '\\times'; +CMD_CDOT: '\\cdot'; +CMD_DIV: '\\div'; +CMD_FRAC: + '\\frac' + | '\\dfrac' + | '\\tfrac'; +CMD_BINOM: '\\binom'; +CMD_DBINOM: '\\dbinom'; +CMD_TBINOM: '\\tbinom'; + +CMD_MATHIT: '\\mathit'; + +UNDERSCORE: '_'; +CARET: '^'; +COLON: ':'; + +fragment WS_CHAR: [ \t\r\n]; +DIFFERENTIAL: 'd' WS_CHAR*? ([a-zA-Z] | '\\' [a-zA-Z]+); + +LETTER: [a-zA-Z]; +DIGIT: [0-9]; + +EQUAL: (('&' WS_CHAR*?)? '=') | ('=' (WS_CHAR*? '&')?); +NEQ: '\\neq'; + +LT: '<'; +LTE: ('\\leq' | '\\le' | LTE_Q | LTE_S); +LTE_Q: '\\leqq'; +LTE_S: '\\leqslant'; + +GT: '>'; +GTE: ('\\geq' | '\\ge' | GTE_Q | GTE_S); +GTE_Q: '\\geqq'; +GTE_S: '\\geqslant'; + +BANG: '!'; + +SINGLE_QUOTES: '\''+; + +SYMBOL: '\\' [a-zA-Z]+; + +math: relation; + +relation: + relation (EQUAL | LT | LTE | GT | GTE | NEQ) relation + | expr; + +equality: expr EQUAL expr; + +expr: additive; + +additive: additive (ADD | SUB) additive | mp; + +// mult part +mp: + mp (MUL | CMD_TIMES | CMD_CDOT | DIV | CMD_DIV | COLON) mp + | unary; + +mp_nofunc: + mp_nofunc ( + MUL + | CMD_TIMES + | CMD_CDOT + | DIV + | CMD_DIV + | COLON + ) mp_nofunc + | unary_nofunc; + +unary: (ADD | SUB) unary | postfix+; + +unary_nofunc: + (ADD | SUB) unary_nofunc + | postfix postfix_nofunc*; + +postfix: exp postfix_op*; +postfix_nofunc: exp_nofunc postfix_op*; +postfix_op: BANG | eval_at; + +eval_at: + BAR (eval_at_sup | eval_at_sub | eval_at_sup eval_at_sub); + +eval_at_sub: UNDERSCORE L_BRACE (expr | equality) R_BRACE; + +eval_at_sup: CARET L_BRACE (expr | equality) R_BRACE; + +exp: exp CARET (atom | L_BRACE expr R_BRACE) subexpr? | comp; + +exp_nofunc: + exp_nofunc CARET (atom | L_BRACE expr R_BRACE) subexpr? + | comp_nofunc; + +comp: + group + | abs_group + | func + | atom + | floor + | ceil; + +comp_nofunc: + group + | abs_group + | atom + | floor + | ceil; + +group: + L_PAREN expr R_PAREN + | L_BRACKET expr R_BRACKET + | L_BRACE expr R_BRACE + | L_BRACE_LITERAL expr R_BRACE_LITERAL; + +abs_group: BAR expr BAR; + +number: DIGIT+ (',' DIGIT DIGIT DIGIT)* ('.' DIGIT+)?; + +atom: (LETTER | SYMBOL) (subexpr? SINGLE_QUOTES? | SINGLE_QUOTES? subexpr?) + | number + | DIFFERENTIAL + | mathit + | frac + | binom + | bra + | ket; + +bra: L_ANGLE expr (R_BAR | BAR); +ket: (L_BAR | BAR) expr R_ANGLE; + +mathit: CMD_MATHIT L_BRACE mathit_text R_BRACE; +mathit_text: LETTER*; + +frac: CMD_FRAC (upperd = DIGIT | L_BRACE upper = expr R_BRACE) + (lowerd = DIGIT | L_BRACE lower = expr R_BRACE); + +binom: + (CMD_BINOM | CMD_DBINOM | CMD_TBINOM) L_BRACE n = expr R_BRACE L_BRACE k = expr R_BRACE; + +floor: L_FLOOR val = expr R_FLOOR; +ceil: L_CEIL val = expr R_CEIL; + +func_normal: + FUNC_EXP + | FUNC_LOG + | FUNC_LG + | FUNC_LN + | FUNC_SIN + | FUNC_COS + | FUNC_TAN + | FUNC_CSC + | FUNC_SEC + | FUNC_COT + | FUNC_ARCSIN + | FUNC_ARCCOS + | FUNC_ARCTAN + | FUNC_ARCCSC + | FUNC_ARCSEC + | FUNC_ARCCOT + | FUNC_SINH + | FUNC_COSH + | FUNC_TANH + | FUNC_ARSINH + | FUNC_ARCOSH + | FUNC_ARTANH; + +func: + func_normal (subexpr? supexpr? | supexpr? subexpr?) ( + L_PAREN func_arg R_PAREN + | func_arg_noparens + ) + | (LETTER | SYMBOL) (subexpr? SINGLE_QUOTES? | SINGLE_QUOTES? subexpr?) // e.g. f(x), f_1'(x) + L_PAREN args R_PAREN + | FUNC_INT (subexpr supexpr | supexpr subexpr)? ( + additive? DIFFERENTIAL + | frac + | additive + ) + | FUNC_SQRT (L_BRACKET root = expr R_BRACKET)? L_BRACE base = expr R_BRACE + | FUNC_OVERLINE L_BRACE base = expr R_BRACE + | (FUNC_SUM | FUNC_PROD) (subeq supexpr | supexpr subeq) mp + | FUNC_LIM limit_sub mp; + +args: (expr ',' args) | expr; + +limit_sub: + UNDERSCORE L_BRACE (LETTER | SYMBOL) LIM_APPROACH_SYM expr ( + CARET ((L_BRACE (ADD | SUB) R_BRACE) | ADD | SUB) + )? R_BRACE; + +func_arg: expr | (expr ',' func_arg); +func_arg_noparens: mp_nofunc; + +subexpr: UNDERSCORE (atom | L_BRACE expr R_BRACE); +supexpr: CARET (atom | L_BRACE expr R_BRACE); + +subeq: UNDERSCORE L_BRACE equality R_BRACE; +supeq: UNDERSCORE L_BRACE equality R_BRACE; diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/latex/__init__.py b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9466d37b8b06f1f292c73f975e44d21c96da10d1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/__init__.py @@ -0,0 +1,204 @@ +from sympy.external import import_module +from sympy.utilities.decorator import doctest_depends_on +from re import compile as rcompile + +from sympy.parsing.latex.lark import LarkLaTeXParser, TransformToSymPyExpr, parse_latex_lark # noqa + +from .errors import LaTeXParsingError # noqa + + +IGNORE_L = r"\s*[{]*\s*" +IGNORE_R = r"\s*[}]*\s*" +NO_LEFT = r"(? len(latex_str): + e = len(latex_str) + eellipsis = "" + + if x[3] in END_DELIM_REPR: + err = (f"Extra '{x[2]}' at index {x[0]} or " + "missing corresponding " + f"'{BEGIN_DELIM_REPR[MATRIX_DELIMS_INV[x[3]]]}' " + f"in LaTeX string: {sellipsis}{latex_str[s:e]}" + f"{eellipsis}") + raise LaTeXParsingError(err) + + if x[7] is None: + err = (f"Extra '{x[2]}' at index {x[0]} or " + "missing corresponding " + f"'{END_DELIM_REPR[MATRIX_DELIMS[x[3]]]}' " + f"in LaTeX string: {sellipsis}{latex_str[s:e]}" + f"{eellipsis}") + raise LaTeXParsingError(err) + + correct_end_regex = MATRIX_DELIMS[x[3]] + sellipsis = "..." if x[0] > 0 else "" + eellipsis = "..." if x[5] < len(latex_str) else "" + if x[7] != correct_end_regex: + err = ("Expected " + f"'{END_DELIM_REPR[correct_end_regex]}' " + f"to close the '{x[2]}' at index {x[0]} but " + f"found '{x[6]}' at index {x[4]} of LaTeX " + f"string instead: {sellipsis}{latex_str[x[0]:x[5]]}" + f"{eellipsis}") + raise LaTeXParsingError(err) + +__doctest_requires__ = {('parse_latex',): ['antlr4', 'lark']} + + +@doctest_depends_on(modules=('antlr4', 'lark')) +def parse_latex(s, strict=False, backend="antlr"): + r"""Converts the input LaTeX string ``s`` to a SymPy ``Expr``. + + Parameters + ========== + + s : str + The LaTeX string to parse. In Python source containing LaTeX, + *raw strings* (denoted with ``r"``, like this one) are preferred, + as LaTeX makes liberal use of the ``\`` character, which would + trigger escaping in normal Python strings. + backend : str, optional + Currently, there are two backends supported: ANTLR, and Lark. + The default setting is to use the ANTLR backend, which can be + changed to Lark if preferred. + + Use ``backend="antlr"`` for the ANTLR-based parser, and + ``backend="lark"`` for the Lark-based parser. + + The ``backend`` option is case-sensitive, and must be in + all lowercase. + strict : bool, optional + This option is only available with the ANTLR backend. + + If True, raise an exception if the string cannot be parsed as + valid LaTeX. If False, try to recover gracefully from common + mistakes. + + Examples + ======== + + >>> from sympy.parsing.latex import parse_latex + >>> expr = parse_latex(r"\frac {1 + \sqrt {\a}} {\b}") + >>> expr + (sqrt(a) + 1)/b + >>> expr.evalf(4, subs=dict(a=5, b=2)) + 1.618 + >>> func = parse_latex(r"\int_1^\alpha \dfrac{\mathrm{d}t}{t}", backend="lark") + >>> func.evalf(subs={"alpha": 2}) + 0.693147180559945 + """ + + check_matrix_delimiters(s) + + if backend == "antlr": + _latex = import_module( + 'sympy.parsing.latex._parse_latex_antlr', + import_kwargs={'fromlist': ['X']}) + + if _latex is not None: + return _latex.parse_latex(s, strict) + elif backend == "lark": + return parse_latex_lark(s) + else: + raise NotImplementedError(f"Using the '{backend}' backend in the LaTeX" \ + " parser is not supported.") diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/latex/_antlr/__init__.py b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/_antlr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2d690e1eb8631ee7731fc1875769d3a4704a1743 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/_antlr/__init__.py @@ -0,0 +1,9 @@ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated from ../LaTeX.g4, derived from latex2sympy +# latex2sympy is licensed under the MIT license +# https://github.com/augustt198/latex2sympy/blob/master/LICENSE.txt +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/latex/_antlr/latexlexer.py b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/_antlr/latexlexer.py new file mode 100644 index 0000000000000000000000000000000000000000..46ca959736c967782eef360b9b3268ccd0be0979 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/_antlr/latexlexer.py @@ -0,0 +1,512 @@ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated from ../LaTeX.g4, derived from latex2sympy +# latex2sympy is licensed under the MIT license +# https://github.com/augustt198/latex2sympy/blob/master/LICENSE.txt +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt +from antlr4 import * +from io import StringIO +import sys +if sys.version_info[1] > 5: + from typing import TextIO +else: + from typing.io import TextIO + + +def serializedATN(): + return [ + 4,0,91,911,6,-1,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5, + 2,6,7,6,2,7,7,7,2,8,7,8,2,9,7,9,2,10,7,10,2,11,7,11,2,12,7,12,2, + 13,7,13,2,14,7,14,2,15,7,15,2,16,7,16,2,17,7,17,2,18,7,18,2,19,7, + 19,2,20,7,20,2,21,7,21,2,22,7,22,2,23,7,23,2,24,7,24,2,25,7,25,2, + 26,7,26,2,27,7,27,2,28,7,28,2,29,7,29,2,30,7,30,2,31,7,31,2,32,7, + 32,2,33,7,33,2,34,7,34,2,35,7,35,2,36,7,36,2,37,7,37,2,38,7,38,2, + 39,7,39,2,40,7,40,2,41,7,41,2,42,7,42,2,43,7,43,2,44,7,44,2,45,7, + 45,2,46,7,46,2,47,7,47,2,48,7,48,2,49,7,49,2,50,7,50,2,51,7,51,2, + 52,7,52,2,53,7,53,2,54,7,54,2,55,7,55,2,56,7,56,2,57,7,57,2,58,7, + 58,2,59,7,59,2,60,7,60,2,61,7,61,2,62,7,62,2,63,7,63,2,64,7,64,2, + 65,7,65,2,66,7,66,2,67,7,67,2,68,7,68,2,69,7,69,2,70,7,70,2,71,7, + 71,2,72,7,72,2,73,7,73,2,74,7,74,2,75,7,75,2,76,7,76,2,77,7,77,2, + 78,7,78,2,79,7,79,2,80,7,80,2,81,7,81,2,82,7,82,2,83,7,83,2,84,7, + 84,2,85,7,85,2,86,7,86,2,87,7,87,2,88,7,88,2,89,7,89,2,90,7,90,2, + 91,7,91,1,0,1,0,1,1,1,1,1,2,4,2,191,8,2,11,2,12,2,192,1,2,1,2,1, + 3,1,3,1,3,1,3,1,3,1,3,1,3,1,3,1,3,1,3,1,3,1,3,3,3,209,8,3,1,3,1, + 3,1,4,1,4,1,4,1,4,1,4,1,4,1,4,1,4,1,4,1,4,1,4,3,4,224,8,4,1,4,1, + 4,1,5,1,5,1,5,1,5,1,5,1,5,1,5,1,5,1,5,1,5,1,5,1,5,1,5,3,5,241,8, + 5,1,5,1,5,1,6,1,6,1,6,1,6,1,6,1,6,1,6,1,6,1,7,1,7,1,7,1,7,1,7,1, + 7,1,7,1,7,1,7,1,8,1,8,1,8,1,8,1,8,1,8,1,8,1,8,1,8,1,8,1,8,1,8,1, + 8,1,8,1,8,3,8,277,8,8,1,8,1,8,1,9,1,9,1,9,1,9,1,9,1,9,1,9,1,9,1, + 9,1,9,1,9,1,9,1,9,1,9,1,9,1,10,1,10,1,10,1,10,1,10,1,10,1,10,1,10, + 1,10,1,10,1,10,1,10,1,10,1,10,1,10,1,10,1,10,1,11,1,11,1,11,1,11, + 1,11,1,11,1,11,1,11,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12, + 1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13, + 1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13, + 1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13, + 1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,3,13, + 381,8,13,1,13,1,13,1,14,1,14,1,15,1,15,1,16,1,16,1,17,1,17,1,18, + 1,18,1,19,1,19,1,20,1,20,1,21,1,21,1,22,1,22,1,22,1,23,1,23,1,23, + 1,24,1,24,1,25,1,25,1,26,1,26,1,27,1,27,1,27,1,27,1,27,1,27,1,27, + 1,27,1,28,1,28,1,28,1,28,1,28,1,28,1,28,1,29,1,29,1,29,1,29,1,29, + 1,29,1,29,1,29,1,30,1,30,1,30,1,30,1,30,1,30,1,30,1,30,1,31,1,31, + 1,31,1,31,1,31,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32, + 1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32, + 1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32, + 1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32, + 1,32,1,32,1,32,1,32,1,32,1,32,3,32,504,8,32,1,33,1,33,1,33,1,33, + 1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,3,33,521, + 8,33,1,34,1,34,1,34,1,34,1,34,1,35,1,35,1,35,1,35,1,35,1,35,1,36, + 1,36,1,36,1,36,1,36,1,37,1,37,1,37,1,37,1,37,1,38,1,38,1,38,1,38, + 1,39,1,39,1,39,1,39,1,40,1,40,1,40,1,40,1,40,1,41,1,41,1,41,1,41, + 1,41,1,42,1,42,1,42,1,42,1,42,1,43,1,43,1,43,1,43,1,43,1,44,1,44, + 1,44,1,44,1,44,1,45,1,45,1,45,1,45,1,45,1,46,1,46,1,46,1,46,1,46, + 1,46,1,46,1,46,1,47,1,47,1,47,1,47,1,47,1,47,1,47,1,47,1,48,1,48, + 1,48,1,48,1,48,1,48,1,48,1,48,1,49,1,49,1,49,1,49,1,49,1,49,1,49, + 1,49,1,50,1,50,1,50,1,50,1,50,1,50,1,50,1,50,1,51,1,51,1,51,1,51, + 1,51,1,51,1,51,1,51,1,52,1,52,1,52,1,52,1,52,1,52,1,53,1,53,1,53, + 1,53,1,53,1,53,1,54,1,54,1,54,1,54,1,54,1,54,1,55,1,55,1,55,1,55, + 1,55,1,55,1,55,1,55,1,56,1,56,1,56,1,56,1,56,1,56,1,56,1,56,1,57, + 1,57,1,57,1,57,1,57,1,57,1,57,1,57,1,58,1,58,1,58,1,58,1,58,1,58, + 1,58,1,58,1,59,1,59,1,59,1,59,1,59,1,59,1,59,1,59,1,60,1,60,1,60, + 1,60,1,60,1,60,1,60,1,61,1,61,1,61,1,61,1,61,1,61,1,61,1,62,1,62, + 1,62,1,62,1,62,1,62,1,63,1,63,1,63,1,63,1,63,1,63,1,63,1,63,1,63, + 1,63,1,64,1,64,1,64,1,64,1,64,1,64,1,64,1,65,1,65,1,65,1,65,1,65, + 1,65,1,66,1,66,1,66,1,66,1,66,1,67,1,67,1,67,1,67,1,67,1,67,1,67, + 1,67,1,67,1,67,1,67,1,67,1,67,1,67,1,67,1,67,1,67,3,67,753,8,67, + 1,68,1,68,1,68,1,68,1,68,1,68,1,68,1,69,1,69,1,69,1,69,1,69,1,69, + 1,69,1,69,1,70,1,70,1,70,1,70,1,70,1,70,1,70,1,70,1,71,1,71,1,71, + 1,71,1,71,1,71,1,71,1,71,1,72,1,72,1,73,1,73,1,74,1,74,1,75,1,75, + 1,76,1,76,5,76,796,8,76,10,76,12,76,799,9,76,1,76,1,76,1,76,4,76, + 804,8,76,11,76,12,76,805,3,76,808,8,76,1,77,1,77,1,78,1,78,1,79, + 1,79,5,79,816,8,79,10,79,12,79,819,9,79,3,79,821,8,79,1,79,1,79, + 1,79,5,79,826,8,79,10,79,12,79,829,9,79,1,79,3,79,832,8,79,3,79, + 834,8,79,1,80,1,80,1,80,1,80,1,80,1,81,1,81,1,82,1,82,1,82,1,82, + 1,82,1,82,1,82,1,82,1,82,3,82,852,8,82,1,83,1,83,1,83,1,83,1,83, + 1,83,1,84,1,84,1,84,1,84,1,84,1,84,1,84,1,84,1,84,1,84,1,85,1,85, + 1,86,1,86,1,86,1,86,1,86,1,86,1,86,1,86,1,86,3,86,881,8,86,1,87, + 1,87,1,87,1,87,1,87,1,87,1,88,1,88,1,88,1,88,1,88,1,88,1,88,1,88, + 1,88,1,88,1,89,1,89,1,90,4,90,902,8,90,11,90,12,90,903,1,91,1,91, + 4,91,908,8,91,11,91,12,91,909,3,797,817,827,0,92,1,1,3,2,5,3,7,4, + 9,5,11,6,13,7,15,8,17,9,19,10,21,11,23,12,25,13,27,14,29,15,31,16, + 33,17,35,18,37,19,39,20,41,21,43,22,45,23,47,24,49,25,51,26,53,27, + 55,28,57,29,59,30,61,31,63,32,65,33,67,34,69,35,71,36,73,37,75,38, + 77,39,79,40,81,41,83,42,85,43,87,44,89,45,91,46,93,47,95,48,97,49, + 99,50,101,51,103,52,105,53,107,54,109,55,111,56,113,57,115,58,117, + 59,119,60,121,61,123,62,125,63,127,64,129,65,131,66,133,67,135,68, + 137,69,139,70,141,71,143,72,145,73,147,74,149,75,151,0,153,76,155, + 77,157,78,159,79,161,80,163,81,165,82,167,83,169,84,171,85,173,86, + 175,87,177,88,179,89,181,90,183,91,1,0,3,3,0,9,10,13,13,32,32,2, + 0,65,90,97,122,1,0,48,57,949,0,1,1,0,0,0,0,3,1,0,0,0,0,5,1,0,0,0, + 0,7,1,0,0,0,0,9,1,0,0,0,0,11,1,0,0,0,0,13,1,0,0,0,0,15,1,0,0,0,0, + 17,1,0,0,0,0,19,1,0,0,0,0,21,1,0,0,0,0,23,1,0,0,0,0,25,1,0,0,0,0, + 27,1,0,0,0,0,29,1,0,0,0,0,31,1,0,0,0,0,33,1,0,0,0,0,35,1,0,0,0,0, + 37,1,0,0,0,0,39,1,0,0,0,0,41,1,0,0,0,0,43,1,0,0,0,0,45,1,0,0,0,0, + 47,1,0,0,0,0,49,1,0,0,0,0,51,1,0,0,0,0,53,1,0,0,0,0,55,1,0,0,0,0, + 57,1,0,0,0,0,59,1,0,0,0,0,61,1,0,0,0,0,63,1,0,0,0,0,65,1,0,0,0,0, + 67,1,0,0,0,0,69,1,0,0,0,0,71,1,0,0,0,0,73,1,0,0,0,0,75,1,0,0,0,0, + 77,1,0,0,0,0,79,1,0,0,0,0,81,1,0,0,0,0,83,1,0,0,0,0,85,1,0,0,0,0, + 87,1,0,0,0,0,89,1,0,0,0,0,91,1,0,0,0,0,93,1,0,0,0,0,95,1,0,0,0,0, + 97,1,0,0,0,0,99,1,0,0,0,0,101,1,0,0,0,0,103,1,0,0,0,0,105,1,0,0, + 0,0,107,1,0,0,0,0,109,1,0,0,0,0,111,1,0,0,0,0,113,1,0,0,0,0,115, + 1,0,0,0,0,117,1,0,0,0,0,119,1,0,0,0,0,121,1,0,0,0,0,123,1,0,0,0, + 0,125,1,0,0,0,0,127,1,0,0,0,0,129,1,0,0,0,0,131,1,0,0,0,0,133,1, + 0,0,0,0,135,1,0,0,0,0,137,1,0,0,0,0,139,1,0,0,0,0,141,1,0,0,0,0, + 143,1,0,0,0,0,145,1,0,0,0,0,147,1,0,0,0,0,149,1,0,0,0,0,153,1,0, + 0,0,0,155,1,0,0,0,0,157,1,0,0,0,0,159,1,0,0,0,0,161,1,0,0,0,0,163, + 1,0,0,0,0,165,1,0,0,0,0,167,1,0,0,0,0,169,1,0,0,0,0,171,1,0,0,0, + 0,173,1,0,0,0,0,175,1,0,0,0,0,177,1,0,0,0,0,179,1,0,0,0,0,181,1, + 0,0,0,0,183,1,0,0,0,1,185,1,0,0,0,3,187,1,0,0,0,5,190,1,0,0,0,7, + 208,1,0,0,0,9,223,1,0,0,0,11,240,1,0,0,0,13,244,1,0,0,0,15,252,1, + 0,0,0,17,276,1,0,0,0,19,280,1,0,0,0,21,295,1,0,0,0,23,312,1,0,0, + 0,25,320,1,0,0,0,27,380,1,0,0,0,29,384,1,0,0,0,31,386,1,0,0,0,33, + 388,1,0,0,0,35,390,1,0,0,0,37,392,1,0,0,0,39,394,1,0,0,0,41,396, + 1,0,0,0,43,398,1,0,0,0,45,400,1,0,0,0,47,403,1,0,0,0,49,406,1,0, + 0,0,51,408,1,0,0,0,53,410,1,0,0,0,55,412,1,0,0,0,57,420,1,0,0,0, + 59,427,1,0,0,0,61,435,1,0,0,0,63,443,1,0,0,0,65,503,1,0,0,0,67,520, + 1,0,0,0,69,522,1,0,0,0,71,527,1,0,0,0,73,533,1,0,0,0,75,538,1,0, + 0,0,77,543,1,0,0,0,79,547,1,0,0,0,81,551,1,0,0,0,83,556,1,0,0,0, + 85,561,1,0,0,0,87,566,1,0,0,0,89,571,1,0,0,0,91,576,1,0,0,0,93,581, + 1,0,0,0,95,589,1,0,0,0,97,597,1,0,0,0,99,605,1,0,0,0,101,613,1,0, + 0,0,103,621,1,0,0,0,105,629,1,0,0,0,107,635,1,0,0,0,109,641,1,0, + 0,0,111,647,1,0,0,0,113,655,1,0,0,0,115,663,1,0,0,0,117,671,1,0, + 0,0,119,679,1,0,0,0,121,687,1,0,0,0,123,694,1,0,0,0,125,701,1,0, + 0,0,127,707,1,0,0,0,129,717,1,0,0,0,131,724,1,0,0,0,133,730,1,0, + 0,0,135,752,1,0,0,0,137,754,1,0,0,0,139,761,1,0,0,0,141,769,1,0, + 0,0,143,777,1,0,0,0,145,785,1,0,0,0,147,787,1,0,0,0,149,789,1,0, + 0,0,151,791,1,0,0,0,153,793,1,0,0,0,155,809,1,0,0,0,157,811,1,0, + 0,0,159,833,1,0,0,0,161,835,1,0,0,0,163,840,1,0,0,0,165,851,1,0, + 0,0,167,853,1,0,0,0,169,859,1,0,0,0,171,869,1,0,0,0,173,880,1,0, + 0,0,175,882,1,0,0,0,177,888,1,0,0,0,179,898,1,0,0,0,181,901,1,0, + 0,0,183,905,1,0,0,0,185,186,5,44,0,0,186,2,1,0,0,0,187,188,5,46, + 0,0,188,4,1,0,0,0,189,191,7,0,0,0,190,189,1,0,0,0,191,192,1,0,0, + 0,192,190,1,0,0,0,192,193,1,0,0,0,193,194,1,0,0,0,194,195,6,2,0, + 0,195,6,1,0,0,0,196,197,5,92,0,0,197,209,5,44,0,0,198,199,5,92,0, + 0,199,200,5,116,0,0,200,201,5,104,0,0,201,202,5,105,0,0,202,203, + 5,110,0,0,203,204,5,115,0,0,204,205,5,112,0,0,205,206,5,97,0,0,206, + 207,5,99,0,0,207,209,5,101,0,0,208,196,1,0,0,0,208,198,1,0,0,0,209, + 210,1,0,0,0,210,211,6,3,0,0,211,8,1,0,0,0,212,213,5,92,0,0,213,224, + 5,58,0,0,214,215,5,92,0,0,215,216,5,109,0,0,216,217,5,101,0,0,217, + 218,5,100,0,0,218,219,5,115,0,0,219,220,5,112,0,0,220,221,5,97,0, + 0,221,222,5,99,0,0,222,224,5,101,0,0,223,212,1,0,0,0,223,214,1,0, + 0,0,224,225,1,0,0,0,225,226,6,4,0,0,226,10,1,0,0,0,227,228,5,92, + 0,0,228,241,5,59,0,0,229,230,5,92,0,0,230,231,5,116,0,0,231,232, + 5,104,0,0,232,233,5,105,0,0,233,234,5,99,0,0,234,235,5,107,0,0,235, + 236,5,115,0,0,236,237,5,112,0,0,237,238,5,97,0,0,238,239,5,99,0, + 0,239,241,5,101,0,0,240,227,1,0,0,0,240,229,1,0,0,0,241,242,1,0, + 0,0,242,243,6,5,0,0,243,12,1,0,0,0,244,245,5,92,0,0,245,246,5,113, + 0,0,246,247,5,117,0,0,247,248,5,97,0,0,248,249,5,100,0,0,249,250, + 1,0,0,0,250,251,6,6,0,0,251,14,1,0,0,0,252,253,5,92,0,0,253,254, + 5,113,0,0,254,255,5,113,0,0,255,256,5,117,0,0,256,257,5,97,0,0,257, + 258,5,100,0,0,258,259,1,0,0,0,259,260,6,7,0,0,260,16,1,0,0,0,261, + 262,5,92,0,0,262,277,5,33,0,0,263,264,5,92,0,0,264,265,5,110,0,0, + 265,266,5,101,0,0,266,267,5,103,0,0,267,268,5,116,0,0,268,269,5, + 104,0,0,269,270,5,105,0,0,270,271,5,110,0,0,271,272,5,115,0,0,272, + 273,5,112,0,0,273,274,5,97,0,0,274,275,5,99,0,0,275,277,5,101,0, + 0,276,261,1,0,0,0,276,263,1,0,0,0,277,278,1,0,0,0,278,279,6,8,0, + 0,279,18,1,0,0,0,280,281,5,92,0,0,281,282,5,110,0,0,282,283,5,101, + 0,0,283,284,5,103,0,0,284,285,5,109,0,0,285,286,5,101,0,0,286,287, + 5,100,0,0,287,288,5,115,0,0,288,289,5,112,0,0,289,290,5,97,0,0,290, + 291,5,99,0,0,291,292,5,101,0,0,292,293,1,0,0,0,293,294,6,9,0,0,294, + 20,1,0,0,0,295,296,5,92,0,0,296,297,5,110,0,0,297,298,5,101,0,0, + 298,299,5,103,0,0,299,300,5,116,0,0,300,301,5,104,0,0,301,302,5, + 105,0,0,302,303,5,99,0,0,303,304,5,107,0,0,304,305,5,115,0,0,305, + 306,5,112,0,0,306,307,5,97,0,0,307,308,5,99,0,0,308,309,5,101,0, + 0,309,310,1,0,0,0,310,311,6,10,0,0,311,22,1,0,0,0,312,313,5,92,0, + 0,313,314,5,108,0,0,314,315,5,101,0,0,315,316,5,102,0,0,316,317, + 5,116,0,0,317,318,1,0,0,0,318,319,6,11,0,0,319,24,1,0,0,0,320,321, + 5,92,0,0,321,322,5,114,0,0,322,323,5,105,0,0,323,324,5,103,0,0,324, + 325,5,104,0,0,325,326,5,116,0,0,326,327,1,0,0,0,327,328,6,12,0,0, + 328,26,1,0,0,0,329,330,5,92,0,0,330,331,5,118,0,0,331,332,5,114, + 0,0,332,333,5,117,0,0,333,334,5,108,0,0,334,381,5,101,0,0,335,336, + 5,92,0,0,336,337,5,118,0,0,337,338,5,99,0,0,338,339,5,101,0,0,339, + 340,5,110,0,0,340,341,5,116,0,0,341,342,5,101,0,0,342,381,5,114, + 0,0,343,344,5,92,0,0,344,345,5,118,0,0,345,346,5,98,0,0,346,347, + 5,111,0,0,347,381,5,120,0,0,348,349,5,92,0,0,349,350,5,118,0,0,350, + 351,5,115,0,0,351,352,5,107,0,0,352,353,5,105,0,0,353,381,5,112, + 0,0,354,355,5,92,0,0,355,356,5,118,0,0,356,357,5,115,0,0,357,358, + 5,112,0,0,358,359,5,97,0,0,359,360,5,99,0,0,360,381,5,101,0,0,361, + 362,5,92,0,0,362,363,5,104,0,0,363,364,5,102,0,0,364,365,5,105,0, + 0,365,381,5,108,0,0,366,367,5,92,0,0,367,381,5,42,0,0,368,369,5, + 92,0,0,369,381,5,45,0,0,370,371,5,92,0,0,371,381,5,46,0,0,372,373, + 5,92,0,0,373,381,5,47,0,0,374,375,5,92,0,0,375,381,5,34,0,0,376, + 377,5,92,0,0,377,381,5,40,0,0,378,379,5,92,0,0,379,381,5,61,0,0, + 380,329,1,0,0,0,380,335,1,0,0,0,380,343,1,0,0,0,380,348,1,0,0,0, + 380,354,1,0,0,0,380,361,1,0,0,0,380,366,1,0,0,0,380,368,1,0,0,0, + 380,370,1,0,0,0,380,372,1,0,0,0,380,374,1,0,0,0,380,376,1,0,0,0, + 380,378,1,0,0,0,381,382,1,0,0,0,382,383,6,13,0,0,383,28,1,0,0,0, + 384,385,5,43,0,0,385,30,1,0,0,0,386,387,5,45,0,0,387,32,1,0,0,0, + 388,389,5,42,0,0,389,34,1,0,0,0,390,391,5,47,0,0,391,36,1,0,0,0, + 392,393,5,40,0,0,393,38,1,0,0,0,394,395,5,41,0,0,395,40,1,0,0,0, + 396,397,5,123,0,0,397,42,1,0,0,0,398,399,5,125,0,0,399,44,1,0,0, + 0,400,401,5,92,0,0,401,402,5,123,0,0,402,46,1,0,0,0,403,404,5,92, + 0,0,404,405,5,125,0,0,405,48,1,0,0,0,406,407,5,91,0,0,407,50,1,0, + 0,0,408,409,5,93,0,0,409,52,1,0,0,0,410,411,5,124,0,0,411,54,1,0, + 0,0,412,413,5,92,0,0,413,414,5,114,0,0,414,415,5,105,0,0,415,416, + 5,103,0,0,416,417,5,104,0,0,417,418,5,116,0,0,418,419,5,124,0,0, + 419,56,1,0,0,0,420,421,5,92,0,0,421,422,5,108,0,0,422,423,5,101, + 0,0,423,424,5,102,0,0,424,425,5,116,0,0,425,426,5,124,0,0,426,58, + 1,0,0,0,427,428,5,92,0,0,428,429,5,108,0,0,429,430,5,97,0,0,430, + 431,5,110,0,0,431,432,5,103,0,0,432,433,5,108,0,0,433,434,5,101, + 0,0,434,60,1,0,0,0,435,436,5,92,0,0,436,437,5,114,0,0,437,438,5, + 97,0,0,438,439,5,110,0,0,439,440,5,103,0,0,440,441,5,108,0,0,441, + 442,5,101,0,0,442,62,1,0,0,0,443,444,5,92,0,0,444,445,5,108,0,0, + 445,446,5,105,0,0,446,447,5,109,0,0,447,64,1,0,0,0,448,449,5,92, + 0,0,449,450,5,116,0,0,450,504,5,111,0,0,451,452,5,92,0,0,452,453, + 5,114,0,0,453,454,5,105,0,0,454,455,5,103,0,0,455,456,5,104,0,0, + 456,457,5,116,0,0,457,458,5,97,0,0,458,459,5,114,0,0,459,460,5,114, + 0,0,460,461,5,111,0,0,461,504,5,119,0,0,462,463,5,92,0,0,463,464, + 5,82,0,0,464,465,5,105,0,0,465,466,5,103,0,0,466,467,5,104,0,0,467, + 468,5,116,0,0,468,469,5,97,0,0,469,470,5,114,0,0,470,471,5,114,0, + 0,471,472,5,111,0,0,472,504,5,119,0,0,473,474,5,92,0,0,474,475,5, + 108,0,0,475,476,5,111,0,0,476,477,5,110,0,0,477,478,5,103,0,0,478, + 479,5,114,0,0,479,480,5,105,0,0,480,481,5,103,0,0,481,482,5,104, + 0,0,482,483,5,116,0,0,483,484,5,97,0,0,484,485,5,114,0,0,485,486, + 5,114,0,0,486,487,5,111,0,0,487,504,5,119,0,0,488,489,5,92,0,0,489, + 490,5,76,0,0,490,491,5,111,0,0,491,492,5,110,0,0,492,493,5,103,0, + 0,493,494,5,114,0,0,494,495,5,105,0,0,495,496,5,103,0,0,496,497, + 5,104,0,0,497,498,5,116,0,0,498,499,5,97,0,0,499,500,5,114,0,0,500, + 501,5,114,0,0,501,502,5,111,0,0,502,504,5,119,0,0,503,448,1,0,0, + 0,503,451,1,0,0,0,503,462,1,0,0,0,503,473,1,0,0,0,503,488,1,0,0, + 0,504,66,1,0,0,0,505,506,5,92,0,0,506,507,5,105,0,0,507,508,5,110, + 0,0,508,521,5,116,0,0,509,510,5,92,0,0,510,511,5,105,0,0,511,512, + 5,110,0,0,512,513,5,116,0,0,513,514,5,92,0,0,514,515,5,108,0,0,515, + 516,5,105,0,0,516,517,5,109,0,0,517,518,5,105,0,0,518,519,5,116, + 0,0,519,521,5,115,0,0,520,505,1,0,0,0,520,509,1,0,0,0,521,68,1,0, + 0,0,522,523,5,92,0,0,523,524,5,115,0,0,524,525,5,117,0,0,525,526, + 5,109,0,0,526,70,1,0,0,0,527,528,5,92,0,0,528,529,5,112,0,0,529, + 530,5,114,0,0,530,531,5,111,0,0,531,532,5,100,0,0,532,72,1,0,0,0, + 533,534,5,92,0,0,534,535,5,101,0,0,535,536,5,120,0,0,536,537,5,112, + 0,0,537,74,1,0,0,0,538,539,5,92,0,0,539,540,5,108,0,0,540,541,5, + 111,0,0,541,542,5,103,0,0,542,76,1,0,0,0,543,544,5,92,0,0,544,545, + 5,108,0,0,545,546,5,103,0,0,546,78,1,0,0,0,547,548,5,92,0,0,548, + 549,5,108,0,0,549,550,5,110,0,0,550,80,1,0,0,0,551,552,5,92,0,0, + 552,553,5,115,0,0,553,554,5,105,0,0,554,555,5,110,0,0,555,82,1,0, + 0,0,556,557,5,92,0,0,557,558,5,99,0,0,558,559,5,111,0,0,559,560, + 5,115,0,0,560,84,1,0,0,0,561,562,5,92,0,0,562,563,5,116,0,0,563, + 564,5,97,0,0,564,565,5,110,0,0,565,86,1,0,0,0,566,567,5,92,0,0,567, + 568,5,99,0,0,568,569,5,115,0,0,569,570,5,99,0,0,570,88,1,0,0,0,571, + 572,5,92,0,0,572,573,5,115,0,0,573,574,5,101,0,0,574,575,5,99,0, + 0,575,90,1,0,0,0,576,577,5,92,0,0,577,578,5,99,0,0,578,579,5,111, + 0,0,579,580,5,116,0,0,580,92,1,0,0,0,581,582,5,92,0,0,582,583,5, + 97,0,0,583,584,5,114,0,0,584,585,5,99,0,0,585,586,5,115,0,0,586, + 587,5,105,0,0,587,588,5,110,0,0,588,94,1,0,0,0,589,590,5,92,0,0, + 590,591,5,97,0,0,591,592,5,114,0,0,592,593,5,99,0,0,593,594,5,99, + 0,0,594,595,5,111,0,0,595,596,5,115,0,0,596,96,1,0,0,0,597,598,5, + 92,0,0,598,599,5,97,0,0,599,600,5,114,0,0,600,601,5,99,0,0,601,602, + 5,116,0,0,602,603,5,97,0,0,603,604,5,110,0,0,604,98,1,0,0,0,605, + 606,5,92,0,0,606,607,5,97,0,0,607,608,5,114,0,0,608,609,5,99,0,0, + 609,610,5,99,0,0,610,611,5,115,0,0,611,612,5,99,0,0,612,100,1,0, + 0,0,613,614,5,92,0,0,614,615,5,97,0,0,615,616,5,114,0,0,616,617, + 5,99,0,0,617,618,5,115,0,0,618,619,5,101,0,0,619,620,5,99,0,0,620, + 102,1,0,0,0,621,622,5,92,0,0,622,623,5,97,0,0,623,624,5,114,0,0, + 624,625,5,99,0,0,625,626,5,99,0,0,626,627,5,111,0,0,627,628,5,116, + 0,0,628,104,1,0,0,0,629,630,5,92,0,0,630,631,5,115,0,0,631,632,5, + 105,0,0,632,633,5,110,0,0,633,634,5,104,0,0,634,106,1,0,0,0,635, + 636,5,92,0,0,636,637,5,99,0,0,637,638,5,111,0,0,638,639,5,115,0, + 0,639,640,5,104,0,0,640,108,1,0,0,0,641,642,5,92,0,0,642,643,5,116, + 0,0,643,644,5,97,0,0,644,645,5,110,0,0,645,646,5,104,0,0,646,110, + 1,0,0,0,647,648,5,92,0,0,648,649,5,97,0,0,649,650,5,114,0,0,650, + 651,5,115,0,0,651,652,5,105,0,0,652,653,5,110,0,0,653,654,5,104, + 0,0,654,112,1,0,0,0,655,656,5,92,0,0,656,657,5,97,0,0,657,658,5, + 114,0,0,658,659,5,99,0,0,659,660,5,111,0,0,660,661,5,115,0,0,661, + 662,5,104,0,0,662,114,1,0,0,0,663,664,5,92,0,0,664,665,5,97,0,0, + 665,666,5,114,0,0,666,667,5,116,0,0,667,668,5,97,0,0,668,669,5,110, + 0,0,669,670,5,104,0,0,670,116,1,0,0,0,671,672,5,92,0,0,672,673,5, + 108,0,0,673,674,5,102,0,0,674,675,5,108,0,0,675,676,5,111,0,0,676, + 677,5,111,0,0,677,678,5,114,0,0,678,118,1,0,0,0,679,680,5,92,0,0, + 680,681,5,114,0,0,681,682,5,102,0,0,682,683,5,108,0,0,683,684,5, + 111,0,0,684,685,5,111,0,0,685,686,5,114,0,0,686,120,1,0,0,0,687, + 688,5,92,0,0,688,689,5,108,0,0,689,690,5,99,0,0,690,691,5,101,0, + 0,691,692,5,105,0,0,692,693,5,108,0,0,693,122,1,0,0,0,694,695,5, + 92,0,0,695,696,5,114,0,0,696,697,5,99,0,0,697,698,5,101,0,0,698, + 699,5,105,0,0,699,700,5,108,0,0,700,124,1,0,0,0,701,702,5,92,0,0, + 702,703,5,115,0,0,703,704,5,113,0,0,704,705,5,114,0,0,705,706,5, + 116,0,0,706,126,1,0,0,0,707,708,5,92,0,0,708,709,5,111,0,0,709,710, + 5,118,0,0,710,711,5,101,0,0,711,712,5,114,0,0,712,713,5,108,0,0, + 713,714,5,105,0,0,714,715,5,110,0,0,715,716,5,101,0,0,716,128,1, + 0,0,0,717,718,5,92,0,0,718,719,5,116,0,0,719,720,5,105,0,0,720,721, + 5,109,0,0,721,722,5,101,0,0,722,723,5,115,0,0,723,130,1,0,0,0,724, + 725,5,92,0,0,725,726,5,99,0,0,726,727,5,100,0,0,727,728,5,111,0, + 0,728,729,5,116,0,0,729,132,1,0,0,0,730,731,5,92,0,0,731,732,5,100, + 0,0,732,733,5,105,0,0,733,734,5,118,0,0,734,134,1,0,0,0,735,736, + 5,92,0,0,736,737,5,102,0,0,737,738,5,114,0,0,738,739,5,97,0,0,739, + 753,5,99,0,0,740,741,5,92,0,0,741,742,5,100,0,0,742,743,5,102,0, + 0,743,744,5,114,0,0,744,745,5,97,0,0,745,753,5,99,0,0,746,747,5, + 92,0,0,747,748,5,116,0,0,748,749,5,102,0,0,749,750,5,114,0,0,750, + 751,5,97,0,0,751,753,5,99,0,0,752,735,1,0,0,0,752,740,1,0,0,0,752, + 746,1,0,0,0,753,136,1,0,0,0,754,755,5,92,0,0,755,756,5,98,0,0,756, + 757,5,105,0,0,757,758,5,110,0,0,758,759,5,111,0,0,759,760,5,109, + 0,0,760,138,1,0,0,0,761,762,5,92,0,0,762,763,5,100,0,0,763,764,5, + 98,0,0,764,765,5,105,0,0,765,766,5,110,0,0,766,767,5,111,0,0,767, + 768,5,109,0,0,768,140,1,0,0,0,769,770,5,92,0,0,770,771,5,116,0,0, + 771,772,5,98,0,0,772,773,5,105,0,0,773,774,5,110,0,0,774,775,5,111, + 0,0,775,776,5,109,0,0,776,142,1,0,0,0,777,778,5,92,0,0,778,779,5, + 109,0,0,779,780,5,97,0,0,780,781,5,116,0,0,781,782,5,104,0,0,782, + 783,5,105,0,0,783,784,5,116,0,0,784,144,1,0,0,0,785,786,5,95,0,0, + 786,146,1,0,0,0,787,788,5,94,0,0,788,148,1,0,0,0,789,790,5,58,0, + 0,790,150,1,0,0,0,791,792,7,0,0,0,792,152,1,0,0,0,793,797,5,100, + 0,0,794,796,3,151,75,0,795,794,1,0,0,0,796,799,1,0,0,0,797,798,1, + 0,0,0,797,795,1,0,0,0,798,807,1,0,0,0,799,797,1,0,0,0,800,808,7, + 1,0,0,801,803,5,92,0,0,802,804,7,1,0,0,803,802,1,0,0,0,804,805,1, + 0,0,0,805,803,1,0,0,0,805,806,1,0,0,0,806,808,1,0,0,0,807,800,1, + 0,0,0,807,801,1,0,0,0,808,154,1,0,0,0,809,810,7,1,0,0,810,156,1, + 0,0,0,811,812,7,2,0,0,812,158,1,0,0,0,813,817,5,38,0,0,814,816,3, + 151,75,0,815,814,1,0,0,0,816,819,1,0,0,0,817,818,1,0,0,0,817,815, + 1,0,0,0,818,821,1,0,0,0,819,817,1,0,0,0,820,813,1,0,0,0,820,821, + 1,0,0,0,821,822,1,0,0,0,822,834,5,61,0,0,823,831,5,61,0,0,824,826, + 3,151,75,0,825,824,1,0,0,0,826,829,1,0,0,0,827,828,1,0,0,0,827,825, + 1,0,0,0,828,830,1,0,0,0,829,827,1,0,0,0,830,832,5,38,0,0,831,827, + 1,0,0,0,831,832,1,0,0,0,832,834,1,0,0,0,833,820,1,0,0,0,833,823, + 1,0,0,0,834,160,1,0,0,0,835,836,5,92,0,0,836,837,5,110,0,0,837,838, + 5,101,0,0,838,839,5,113,0,0,839,162,1,0,0,0,840,841,5,60,0,0,841, + 164,1,0,0,0,842,843,5,92,0,0,843,844,5,108,0,0,844,845,5,101,0,0, + 845,852,5,113,0,0,846,847,5,92,0,0,847,848,5,108,0,0,848,852,5,101, + 0,0,849,852,3,167,83,0,850,852,3,169,84,0,851,842,1,0,0,0,851,846, + 1,0,0,0,851,849,1,0,0,0,851,850,1,0,0,0,852,166,1,0,0,0,853,854, + 5,92,0,0,854,855,5,108,0,0,855,856,5,101,0,0,856,857,5,113,0,0,857, + 858,5,113,0,0,858,168,1,0,0,0,859,860,5,92,0,0,860,861,5,108,0,0, + 861,862,5,101,0,0,862,863,5,113,0,0,863,864,5,115,0,0,864,865,5, + 108,0,0,865,866,5,97,0,0,866,867,5,110,0,0,867,868,5,116,0,0,868, + 170,1,0,0,0,869,870,5,62,0,0,870,172,1,0,0,0,871,872,5,92,0,0,872, + 873,5,103,0,0,873,874,5,101,0,0,874,881,5,113,0,0,875,876,5,92,0, + 0,876,877,5,103,0,0,877,881,5,101,0,0,878,881,3,175,87,0,879,881, + 3,177,88,0,880,871,1,0,0,0,880,875,1,0,0,0,880,878,1,0,0,0,880,879, + 1,0,0,0,881,174,1,0,0,0,882,883,5,92,0,0,883,884,5,103,0,0,884,885, + 5,101,0,0,885,886,5,113,0,0,886,887,5,113,0,0,887,176,1,0,0,0,888, + 889,5,92,0,0,889,890,5,103,0,0,890,891,5,101,0,0,891,892,5,113,0, + 0,892,893,5,115,0,0,893,894,5,108,0,0,894,895,5,97,0,0,895,896,5, + 110,0,0,896,897,5,116,0,0,897,178,1,0,0,0,898,899,5,33,0,0,899,180, + 1,0,0,0,900,902,5,39,0,0,901,900,1,0,0,0,902,903,1,0,0,0,903,901, + 1,0,0,0,903,904,1,0,0,0,904,182,1,0,0,0,905,907,5,92,0,0,906,908, + 7,1,0,0,907,906,1,0,0,0,908,909,1,0,0,0,909,907,1,0,0,0,909,910, + 1,0,0,0,910,184,1,0,0,0,22,0,192,208,223,240,276,380,503,520,752, + 797,805,807,817,820,827,831,833,851,880,903,909,1,6,0,0 + ] + +class LaTeXLexer(Lexer): + + atn = ATNDeserializer().deserialize(serializedATN()) + + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] + + T__0 = 1 + T__1 = 2 + WS = 3 + THINSPACE = 4 + MEDSPACE = 5 + THICKSPACE = 6 + QUAD = 7 + QQUAD = 8 + NEGTHINSPACE = 9 + NEGMEDSPACE = 10 + NEGTHICKSPACE = 11 + CMD_LEFT = 12 + CMD_RIGHT = 13 + IGNORE = 14 + ADD = 15 + SUB = 16 + MUL = 17 + DIV = 18 + L_PAREN = 19 + R_PAREN = 20 + L_BRACE = 21 + R_BRACE = 22 + L_BRACE_LITERAL = 23 + R_BRACE_LITERAL = 24 + L_BRACKET = 25 + R_BRACKET = 26 + BAR = 27 + R_BAR = 28 + L_BAR = 29 + L_ANGLE = 30 + R_ANGLE = 31 + FUNC_LIM = 32 + LIM_APPROACH_SYM = 33 + FUNC_INT = 34 + FUNC_SUM = 35 + FUNC_PROD = 36 + FUNC_EXP = 37 + FUNC_LOG = 38 + FUNC_LG = 39 + FUNC_LN = 40 + FUNC_SIN = 41 + FUNC_COS = 42 + FUNC_TAN = 43 + FUNC_CSC = 44 + FUNC_SEC = 45 + FUNC_COT = 46 + FUNC_ARCSIN = 47 + FUNC_ARCCOS = 48 + FUNC_ARCTAN = 49 + FUNC_ARCCSC = 50 + FUNC_ARCSEC = 51 + FUNC_ARCCOT = 52 + FUNC_SINH = 53 + FUNC_COSH = 54 + FUNC_TANH = 55 + FUNC_ARSINH = 56 + FUNC_ARCOSH = 57 + FUNC_ARTANH = 58 + L_FLOOR = 59 + R_FLOOR = 60 + L_CEIL = 61 + R_CEIL = 62 + FUNC_SQRT = 63 + FUNC_OVERLINE = 64 + CMD_TIMES = 65 + CMD_CDOT = 66 + CMD_DIV = 67 + CMD_FRAC = 68 + CMD_BINOM = 69 + CMD_DBINOM = 70 + CMD_TBINOM = 71 + CMD_MATHIT = 72 + UNDERSCORE = 73 + CARET = 74 + COLON = 75 + DIFFERENTIAL = 76 + LETTER = 77 + DIGIT = 78 + EQUAL = 79 + NEQ = 80 + LT = 81 + LTE = 82 + LTE_Q = 83 + LTE_S = 84 + GT = 85 + GTE = 86 + GTE_Q = 87 + GTE_S = 88 + BANG = 89 + SINGLE_QUOTES = 90 + SYMBOL = 91 + + channelNames = [ u"DEFAULT_TOKEN_CHANNEL", u"HIDDEN" ] + + modeNames = [ "DEFAULT_MODE" ] + + literalNames = [ "", + "','", "'.'", "'\\quad'", "'\\qquad'", "'\\negmedspace'", "'\\negthickspace'", + "'\\left'", "'\\right'", "'+'", "'-'", "'*'", "'/'", "'('", + "')'", "'{'", "'}'", "'\\{'", "'\\}'", "'['", "']'", "'|'", + "'\\right|'", "'\\left|'", "'\\langle'", "'\\rangle'", "'\\lim'", + "'\\sum'", "'\\prod'", "'\\exp'", "'\\log'", "'\\lg'", "'\\ln'", + "'\\sin'", "'\\cos'", "'\\tan'", "'\\csc'", "'\\sec'", "'\\cot'", + "'\\arcsin'", "'\\arccos'", "'\\arctan'", "'\\arccsc'", "'\\arcsec'", + "'\\arccot'", "'\\sinh'", "'\\cosh'", "'\\tanh'", "'\\arsinh'", + "'\\arcosh'", "'\\artanh'", "'\\lfloor'", "'\\rfloor'", "'\\lceil'", + "'\\rceil'", "'\\sqrt'", "'\\overline'", "'\\times'", "'\\cdot'", + "'\\div'", "'\\binom'", "'\\dbinom'", "'\\tbinom'", "'\\mathit'", + "'_'", "'^'", "':'", "'\\neq'", "'<'", "'\\leqq'", "'\\leqslant'", + "'>'", "'\\geqq'", "'\\geqslant'", "'!'" ] + + symbolicNames = [ "", + "WS", "THINSPACE", "MEDSPACE", "THICKSPACE", "QUAD", "QQUAD", + "NEGTHINSPACE", "NEGMEDSPACE", "NEGTHICKSPACE", "CMD_LEFT", + "CMD_RIGHT", "IGNORE", "ADD", "SUB", "MUL", "DIV", "L_PAREN", + "R_PAREN", "L_BRACE", "R_BRACE", "L_BRACE_LITERAL", "R_BRACE_LITERAL", + "L_BRACKET", "R_BRACKET", "BAR", "R_BAR", "L_BAR", "L_ANGLE", + "R_ANGLE", "FUNC_LIM", "LIM_APPROACH_SYM", "FUNC_INT", "FUNC_SUM", + "FUNC_PROD", "FUNC_EXP", "FUNC_LOG", "FUNC_LG", "FUNC_LN", "FUNC_SIN", + "FUNC_COS", "FUNC_TAN", "FUNC_CSC", "FUNC_SEC", "FUNC_COT", + "FUNC_ARCSIN", "FUNC_ARCCOS", "FUNC_ARCTAN", "FUNC_ARCCSC", + "FUNC_ARCSEC", "FUNC_ARCCOT", "FUNC_SINH", "FUNC_COSH", "FUNC_TANH", + "FUNC_ARSINH", "FUNC_ARCOSH", "FUNC_ARTANH", "L_FLOOR", "R_FLOOR", + "L_CEIL", "R_CEIL", "FUNC_SQRT", "FUNC_OVERLINE", "CMD_TIMES", + "CMD_CDOT", "CMD_DIV", "CMD_FRAC", "CMD_BINOM", "CMD_DBINOM", + "CMD_TBINOM", "CMD_MATHIT", "UNDERSCORE", "CARET", "COLON", + "DIFFERENTIAL", "LETTER", "DIGIT", "EQUAL", "NEQ", "LT", "LTE", + "LTE_Q", "LTE_S", "GT", "GTE", "GTE_Q", "GTE_S", "BANG", "SINGLE_QUOTES", + "SYMBOL" ] + + ruleNames = [ "T__0", "T__1", "WS", "THINSPACE", "MEDSPACE", "THICKSPACE", + "QUAD", "QQUAD", "NEGTHINSPACE", "NEGMEDSPACE", "NEGTHICKSPACE", + "CMD_LEFT", "CMD_RIGHT", "IGNORE", "ADD", "SUB", "MUL", + "DIV", "L_PAREN", "R_PAREN", "L_BRACE", "R_BRACE", "L_BRACE_LITERAL", + "R_BRACE_LITERAL", "L_BRACKET", "R_BRACKET", "BAR", "R_BAR", + "L_BAR", "L_ANGLE", "R_ANGLE", "FUNC_LIM", "LIM_APPROACH_SYM", + "FUNC_INT", "FUNC_SUM", "FUNC_PROD", "FUNC_EXP", "FUNC_LOG", + "FUNC_LG", "FUNC_LN", "FUNC_SIN", "FUNC_COS", "FUNC_TAN", + "FUNC_CSC", "FUNC_SEC", "FUNC_COT", "FUNC_ARCSIN", "FUNC_ARCCOS", + "FUNC_ARCTAN", "FUNC_ARCCSC", "FUNC_ARCSEC", "FUNC_ARCCOT", + "FUNC_SINH", "FUNC_COSH", "FUNC_TANH", "FUNC_ARSINH", + "FUNC_ARCOSH", "FUNC_ARTANH", "L_FLOOR", "R_FLOOR", "L_CEIL", + "R_CEIL", "FUNC_SQRT", "FUNC_OVERLINE", "CMD_TIMES", "CMD_CDOT", + "CMD_DIV", "CMD_FRAC", "CMD_BINOM", "CMD_DBINOM", "CMD_TBINOM", + "CMD_MATHIT", "UNDERSCORE", "CARET", "COLON", "WS_CHAR", + "DIFFERENTIAL", "LETTER", "DIGIT", "EQUAL", "NEQ", "LT", + "LTE", "LTE_Q", "LTE_S", "GT", "GTE", "GTE_Q", "GTE_S", + "BANG", "SINGLE_QUOTES", "SYMBOL" ] + + grammarFileName = "LaTeX.g4" + + def __init__(self, input=None, output:TextIO = sys.stdout): + super().__init__(input, output) + self.checkVersion("4.11.1") + self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache()) + self._actions = None + self._predicates = None + + diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/latex/_antlr/latexparser.py b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/_antlr/latexparser.py new file mode 100644 index 0000000000000000000000000000000000000000..f6f58119055ded8f77380bbef52c77ddd6a01cfe --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/_antlr/latexparser.py @@ -0,0 +1,3652 @@ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated from ../LaTeX.g4, derived from latex2sympy +# latex2sympy is licensed under the MIT license +# https://github.com/augustt198/latex2sympy/blob/master/LICENSE.txt +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt +from antlr4 import * +from io import StringIO +import sys +if sys.version_info[1] > 5: + from typing import TextIO +else: + from typing.io import TextIO + +def serializedATN(): + return [ + 4,1,91,522,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5,2,6,7, + 6,2,7,7,7,2,8,7,8,2,9,7,9,2,10,7,10,2,11,7,11,2,12,7,12,2,13,7,13, + 2,14,7,14,2,15,7,15,2,16,7,16,2,17,7,17,2,18,7,18,2,19,7,19,2,20, + 7,20,2,21,7,21,2,22,7,22,2,23,7,23,2,24,7,24,2,25,7,25,2,26,7,26, + 2,27,7,27,2,28,7,28,2,29,7,29,2,30,7,30,2,31,7,31,2,32,7,32,2,33, + 7,33,2,34,7,34,2,35,7,35,2,36,7,36,2,37,7,37,2,38,7,38,2,39,7,39, + 2,40,7,40,1,0,1,0,1,1,1,1,1,1,1,1,1,1,1,1,5,1,91,8,1,10,1,12,1,94, + 9,1,1,2,1,2,1,2,1,2,1,3,1,3,1,4,1,4,1,4,1,4,1,4,1,4,5,4,108,8,4, + 10,4,12,4,111,9,4,1,5,1,5,1,5,1,5,1,5,1,5,5,5,119,8,5,10,5,12,5, + 122,9,5,1,6,1,6,1,6,1,6,1,6,1,6,5,6,130,8,6,10,6,12,6,133,9,6,1, + 7,1,7,1,7,4,7,138,8,7,11,7,12,7,139,3,7,142,8,7,1,8,1,8,1,8,1,8, + 5,8,148,8,8,10,8,12,8,151,9,8,3,8,153,8,8,1,9,1,9,5,9,157,8,9,10, + 9,12,9,160,9,9,1,10,1,10,5,10,164,8,10,10,10,12,10,167,9,10,1,11, + 1,11,3,11,171,8,11,1,12,1,12,1,12,1,12,1,12,1,12,3,12,179,8,12,1, + 13,1,13,1,13,1,13,3,13,185,8,13,1,13,1,13,1,14,1,14,1,14,1,14,3, + 14,193,8,14,1,14,1,14,1,15,1,15,1,15,1,15,1,15,1,15,1,15,1,15,1, + 15,1,15,3,15,207,8,15,1,15,3,15,210,8,15,5,15,212,8,15,10,15,12, + 15,215,9,15,1,16,1,16,1,16,1,16,1,16,1,16,1,16,1,16,1,16,1,16,3, + 16,227,8,16,1,16,3,16,230,8,16,5,16,232,8,16,10,16,12,16,235,9,16, + 1,17,1,17,1,17,1,17,1,17,1,17,3,17,243,8,17,1,18,1,18,1,18,1,18, + 1,18,3,18,250,8,18,1,19,1,19,1,19,1,19,1,19,1,19,1,19,1,19,1,19, + 1,19,1,19,1,19,1,19,1,19,1,19,1,19,3,19,268,8,19,1,20,1,20,1,20, + 1,20,1,21,4,21,275,8,21,11,21,12,21,276,1,21,1,21,1,21,1,21,5,21, + 283,8,21,10,21,12,21,286,9,21,1,21,1,21,4,21,290,8,21,11,21,12,21, + 291,3,21,294,8,21,1,22,1,22,3,22,298,8,22,1,22,3,22,301,8,22,1,22, + 3,22,304,8,22,1,22,3,22,307,8,22,3,22,309,8,22,1,22,1,22,1,22,1, + 22,1,22,1,22,1,22,3,22,318,8,22,1,23,1,23,1,23,1,23,1,24,1,24,1, + 24,1,24,1,25,1,25,1,25,1,25,1,25,1,26,5,26,334,8,26,10,26,12,26, + 337,9,26,1,27,1,27,1,27,1,27,1,27,1,27,3,27,345,8,27,1,27,1,27,1, + 27,1,27,1,27,3,27,352,8,27,1,28,1,28,1,28,1,28,1,28,1,28,1,28,1, + 28,1,29,1,29,1,29,1,29,1,30,1,30,1,30,1,30,1,31,1,31,1,32,1,32,3, + 32,374,8,32,1,32,3,32,377,8,32,1,32,3,32,380,8,32,1,32,3,32,383, + 8,32,3,32,385,8,32,1,32,1,32,1,32,1,32,1,32,3,32,392,8,32,1,32,1, + 32,3,32,396,8,32,1,32,3,32,399,8,32,1,32,3,32,402,8,32,1,32,3,32, + 405,8,32,3,32,407,8,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1, + 32,1,32,1,32,3,32,420,8,32,1,32,3,32,423,8,32,1,32,1,32,1,32,3,32, + 428,8,32,1,32,1,32,1,32,1,32,1,32,3,32,435,8,32,1,32,1,32,1,32,1, + 32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,3, + 32,453,8,32,1,32,1,32,1,32,1,32,1,32,1,32,3,32,461,8,32,1,33,1,33, + 1,33,1,33,1,33,3,33,468,8,33,1,34,1,34,1,34,1,34,1,34,1,34,1,34, + 1,34,1,34,1,34,1,34,3,34,481,8,34,3,34,483,8,34,1,34,1,34,1,35,1, + 35,1,35,1,35,1,35,3,35,492,8,35,1,36,1,36,1,37,1,37,1,37,1,37,1, + 37,1,37,3,37,502,8,37,1,38,1,38,1,38,1,38,1,38,1,38,3,38,510,8,38, + 1,39,1,39,1,39,1,39,1,39,1,40,1,40,1,40,1,40,1,40,1,40,0,6,2,8,10, + 12,30,32,41,0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36, + 38,40,42,44,46,48,50,52,54,56,58,60,62,64,66,68,70,72,74,76,78,80, + 0,9,2,0,79,82,85,86,1,0,15,16,3,0,17,18,65,67,75,75,2,0,77,77,91, + 91,1,0,27,28,2,0,27,27,29,29,1,0,69,71,1,0,37,58,1,0,35,36,563,0, + 82,1,0,0,0,2,84,1,0,0,0,4,95,1,0,0,0,6,99,1,0,0,0,8,101,1,0,0,0, + 10,112,1,0,0,0,12,123,1,0,0,0,14,141,1,0,0,0,16,152,1,0,0,0,18,154, + 1,0,0,0,20,161,1,0,0,0,22,170,1,0,0,0,24,172,1,0,0,0,26,180,1,0, + 0,0,28,188,1,0,0,0,30,196,1,0,0,0,32,216,1,0,0,0,34,242,1,0,0,0, + 36,249,1,0,0,0,38,267,1,0,0,0,40,269,1,0,0,0,42,274,1,0,0,0,44,317, + 1,0,0,0,46,319,1,0,0,0,48,323,1,0,0,0,50,327,1,0,0,0,52,335,1,0, + 0,0,54,338,1,0,0,0,56,353,1,0,0,0,58,361,1,0,0,0,60,365,1,0,0,0, + 62,369,1,0,0,0,64,460,1,0,0,0,66,467,1,0,0,0,68,469,1,0,0,0,70,491, + 1,0,0,0,72,493,1,0,0,0,74,495,1,0,0,0,76,503,1,0,0,0,78,511,1,0, + 0,0,80,516,1,0,0,0,82,83,3,2,1,0,83,1,1,0,0,0,84,85,6,1,-1,0,85, + 86,3,6,3,0,86,92,1,0,0,0,87,88,10,2,0,0,88,89,7,0,0,0,89,91,3,2, + 1,3,90,87,1,0,0,0,91,94,1,0,0,0,92,90,1,0,0,0,92,93,1,0,0,0,93,3, + 1,0,0,0,94,92,1,0,0,0,95,96,3,6,3,0,96,97,5,79,0,0,97,98,3,6,3,0, + 98,5,1,0,0,0,99,100,3,8,4,0,100,7,1,0,0,0,101,102,6,4,-1,0,102,103, + 3,10,5,0,103,109,1,0,0,0,104,105,10,2,0,0,105,106,7,1,0,0,106,108, + 3,8,4,3,107,104,1,0,0,0,108,111,1,0,0,0,109,107,1,0,0,0,109,110, + 1,0,0,0,110,9,1,0,0,0,111,109,1,0,0,0,112,113,6,5,-1,0,113,114,3, + 14,7,0,114,120,1,0,0,0,115,116,10,2,0,0,116,117,7,2,0,0,117,119, + 3,10,5,3,118,115,1,0,0,0,119,122,1,0,0,0,120,118,1,0,0,0,120,121, + 1,0,0,0,121,11,1,0,0,0,122,120,1,0,0,0,123,124,6,6,-1,0,124,125, + 3,16,8,0,125,131,1,0,0,0,126,127,10,2,0,0,127,128,7,2,0,0,128,130, + 3,12,6,3,129,126,1,0,0,0,130,133,1,0,0,0,131,129,1,0,0,0,131,132, + 1,0,0,0,132,13,1,0,0,0,133,131,1,0,0,0,134,135,7,1,0,0,135,142,3, + 14,7,0,136,138,3,18,9,0,137,136,1,0,0,0,138,139,1,0,0,0,139,137, + 1,0,0,0,139,140,1,0,0,0,140,142,1,0,0,0,141,134,1,0,0,0,141,137, + 1,0,0,0,142,15,1,0,0,0,143,144,7,1,0,0,144,153,3,16,8,0,145,149, + 3,18,9,0,146,148,3,20,10,0,147,146,1,0,0,0,148,151,1,0,0,0,149,147, + 1,0,0,0,149,150,1,0,0,0,150,153,1,0,0,0,151,149,1,0,0,0,152,143, + 1,0,0,0,152,145,1,0,0,0,153,17,1,0,0,0,154,158,3,30,15,0,155,157, + 3,22,11,0,156,155,1,0,0,0,157,160,1,0,0,0,158,156,1,0,0,0,158,159, + 1,0,0,0,159,19,1,0,0,0,160,158,1,0,0,0,161,165,3,32,16,0,162,164, + 3,22,11,0,163,162,1,0,0,0,164,167,1,0,0,0,165,163,1,0,0,0,165,166, + 1,0,0,0,166,21,1,0,0,0,167,165,1,0,0,0,168,171,5,89,0,0,169,171, + 3,24,12,0,170,168,1,0,0,0,170,169,1,0,0,0,171,23,1,0,0,0,172,178, + 5,27,0,0,173,179,3,28,14,0,174,179,3,26,13,0,175,176,3,28,14,0,176, + 177,3,26,13,0,177,179,1,0,0,0,178,173,1,0,0,0,178,174,1,0,0,0,178, + 175,1,0,0,0,179,25,1,0,0,0,180,181,5,73,0,0,181,184,5,21,0,0,182, + 185,3,6,3,0,183,185,3,4,2,0,184,182,1,0,0,0,184,183,1,0,0,0,185, + 186,1,0,0,0,186,187,5,22,0,0,187,27,1,0,0,0,188,189,5,74,0,0,189, + 192,5,21,0,0,190,193,3,6,3,0,191,193,3,4,2,0,192,190,1,0,0,0,192, + 191,1,0,0,0,193,194,1,0,0,0,194,195,5,22,0,0,195,29,1,0,0,0,196, + 197,6,15,-1,0,197,198,3,34,17,0,198,213,1,0,0,0,199,200,10,2,0,0, + 200,206,5,74,0,0,201,207,3,44,22,0,202,203,5,21,0,0,203,204,3,6, + 3,0,204,205,5,22,0,0,205,207,1,0,0,0,206,201,1,0,0,0,206,202,1,0, + 0,0,207,209,1,0,0,0,208,210,3,74,37,0,209,208,1,0,0,0,209,210,1, + 0,0,0,210,212,1,0,0,0,211,199,1,0,0,0,212,215,1,0,0,0,213,211,1, + 0,0,0,213,214,1,0,0,0,214,31,1,0,0,0,215,213,1,0,0,0,216,217,6,16, + -1,0,217,218,3,36,18,0,218,233,1,0,0,0,219,220,10,2,0,0,220,226, + 5,74,0,0,221,227,3,44,22,0,222,223,5,21,0,0,223,224,3,6,3,0,224, + 225,5,22,0,0,225,227,1,0,0,0,226,221,1,0,0,0,226,222,1,0,0,0,227, + 229,1,0,0,0,228,230,3,74,37,0,229,228,1,0,0,0,229,230,1,0,0,0,230, + 232,1,0,0,0,231,219,1,0,0,0,232,235,1,0,0,0,233,231,1,0,0,0,233, + 234,1,0,0,0,234,33,1,0,0,0,235,233,1,0,0,0,236,243,3,38,19,0,237, + 243,3,40,20,0,238,243,3,64,32,0,239,243,3,44,22,0,240,243,3,58,29, + 0,241,243,3,60,30,0,242,236,1,0,0,0,242,237,1,0,0,0,242,238,1,0, + 0,0,242,239,1,0,0,0,242,240,1,0,0,0,242,241,1,0,0,0,243,35,1,0,0, + 0,244,250,3,38,19,0,245,250,3,40,20,0,246,250,3,44,22,0,247,250, + 3,58,29,0,248,250,3,60,30,0,249,244,1,0,0,0,249,245,1,0,0,0,249, + 246,1,0,0,0,249,247,1,0,0,0,249,248,1,0,0,0,250,37,1,0,0,0,251,252, + 5,19,0,0,252,253,3,6,3,0,253,254,5,20,0,0,254,268,1,0,0,0,255,256, + 5,25,0,0,256,257,3,6,3,0,257,258,5,26,0,0,258,268,1,0,0,0,259,260, + 5,21,0,0,260,261,3,6,3,0,261,262,5,22,0,0,262,268,1,0,0,0,263,264, + 5,23,0,0,264,265,3,6,3,0,265,266,5,24,0,0,266,268,1,0,0,0,267,251, + 1,0,0,0,267,255,1,0,0,0,267,259,1,0,0,0,267,263,1,0,0,0,268,39,1, + 0,0,0,269,270,5,27,0,0,270,271,3,6,3,0,271,272,5,27,0,0,272,41,1, + 0,0,0,273,275,5,78,0,0,274,273,1,0,0,0,275,276,1,0,0,0,276,274,1, + 0,0,0,276,277,1,0,0,0,277,284,1,0,0,0,278,279,5,1,0,0,279,280,5, + 78,0,0,280,281,5,78,0,0,281,283,5,78,0,0,282,278,1,0,0,0,283,286, + 1,0,0,0,284,282,1,0,0,0,284,285,1,0,0,0,285,293,1,0,0,0,286,284, + 1,0,0,0,287,289,5,2,0,0,288,290,5,78,0,0,289,288,1,0,0,0,290,291, + 1,0,0,0,291,289,1,0,0,0,291,292,1,0,0,0,292,294,1,0,0,0,293,287, + 1,0,0,0,293,294,1,0,0,0,294,43,1,0,0,0,295,308,7,3,0,0,296,298,3, + 74,37,0,297,296,1,0,0,0,297,298,1,0,0,0,298,300,1,0,0,0,299,301, + 5,90,0,0,300,299,1,0,0,0,300,301,1,0,0,0,301,309,1,0,0,0,302,304, + 5,90,0,0,303,302,1,0,0,0,303,304,1,0,0,0,304,306,1,0,0,0,305,307, + 3,74,37,0,306,305,1,0,0,0,306,307,1,0,0,0,307,309,1,0,0,0,308,297, + 1,0,0,0,308,303,1,0,0,0,309,318,1,0,0,0,310,318,3,42,21,0,311,318, + 5,76,0,0,312,318,3,50,25,0,313,318,3,54,27,0,314,318,3,56,28,0,315, + 318,3,46,23,0,316,318,3,48,24,0,317,295,1,0,0,0,317,310,1,0,0,0, + 317,311,1,0,0,0,317,312,1,0,0,0,317,313,1,0,0,0,317,314,1,0,0,0, + 317,315,1,0,0,0,317,316,1,0,0,0,318,45,1,0,0,0,319,320,5,30,0,0, + 320,321,3,6,3,0,321,322,7,4,0,0,322,47,1,0,0,0,323,324,7,5,0,0,324, + 325,3,6,3,0,325,326,5,31,0,0,326,49,1,0,0,0,327,328,5,72,0,0,328, + 329,5,21,0,0,329,330,3,52,26,0,330,331,5,22,0,0,331,51,1,0,0,0,332, + 334,5,77,0,0,333,332,1,0,0,0,334,337,1,0,0,0,335,333,1,0,0,0,335, + 336,1,0,0,0,336,53,1,0,0,0,337,335,1,0,0,0,338,344,5,68,0,0,339, + 345,5,78,0,0,340,341,5,21,0,0,341,342,3,6,3,0,342,343,5,22,0,0,343, + 345,1,0,0,0,344,339,1,0,0,0,344,340,1,0,0,0,345,351,1,0,0,0,346, + 352,5,78,0,0,347,348,5,21,0,0,348,349,3,6,3,0,349,350,5,22,0,0,350, + 352,1,0,0,0,351,346,1,0,0,0,351,347,1,0,0,0,352,55,1,0,0,0,353,354, + 7,6,0,0,354,355,5,21,0,0,355,356,3,6,3,0,356,357,5,22,0,0,357,358, + 5,21,0,0,358,359,3,6,3,0,359,360,5,22,0,0,360,57,1,0,0,0,361,362, + 5,59,0,0,362,363,3,6,3,0,363,364,5,60,0,0,364,59,1,0,0,0,365,366, + 5,61,0,0,366,367,3,6,3,0,367,368,5,62,0,0,368,61,1,0,0,0,369,370, + 7,7,0,0,370,63,1,0,0,0,371,384,3,62,31,0,372,374,3,74,37,0,373,372, + 1,0,0,0,373,374,1,0,0,0,374,376,1,0,0,0,375,377,3,76,38,0,376,375, + 1,0,0,0,376,377,1,0,0,0,377,385,1,0,0,0,378,380,3,76,38,0,379,378, + 1,0,0,0,379,380,1,0,0,0,380,382,1,0,0,0,381,383,3,74,37,0,382,381, + 1,0,0,0,382,383,1,0,0,0,383,385,1,0,0,0,384,373,1,0,0,0,384,379, + 1,0,0,0,385,391,1,0,0,0,386,387,5,19,0,0,387,388,3,70,35,0,388,389, + 5,20,0,0,389,392,1,0,0,0,390,392,3,72,36,0,391,386,1,0,0,0,391,390, + 1,0,0,0,392,461,1,0,0,0,393,406,7,3,0,0,394,396,3,74,37,0,395,394, + 1,0,0,0,395,396,1,0,0,0,396,398,1,0,0,0,397,399,5,90,0,0,398,397, + 1,0,0,0,398,399,1,0,0,0,399,407,1,0,0,0,400,402,5,90,0,0,401,400, + 1,0,0,0,401,402,1,0,0,0,402,404,1,0,0,0,403,405,3,74,37,0,404,403, + 1,0,0,0,404,405,1,0,0,0,405,407,1,0,0,0,406,395,1,0,0,0,406,401, + 1,0,0,0,407,408,1,0,0,0,408,409,5,19,0,0,409,410,3,66,33,0,410,411, + 5,20,0,0,411,461,1,0,0,0,412,419,5,34,0,0,413,414,3,74,37,0,414, + 415,3,76,38,0,415,420,1,0,0,0,416,417,3,76,38,0,417,418,3,74,37, + 0,418,420,1,0,0,0,419,413,1,0,0,0,419,416,1,0,0,0,419,420,1,0,0, + 0,420,427,1,0,0,0,421,423,3,8,4,0,422,421,1,0,0,0,422,423,1,0,0, + 0,423,424,1,0,0,0,424,428,5,76,0,0,425,428,3,54,27,0,426,428,3,8, + 4,0,427,422,1,0,0,0,427,425,1,0,0,0,427,426,1,0,0,0,428,461,1,0, + 0,0,429,434,5,63,0,0,430,431,5,25,0,0,431,432,3,6,3,0,432,433,5, + 26,0,0,433,435,1,0,0,0,434,430,1,0,0,0,434,435,1,0,0,0,435,436,1, + 0,0,0,436,437,5,21,0,0,437,438,3,6,3,0,438,439,5,22,0,0,439,461, + 1,0,0,0,440,441,5,64,0,0,441,442,5,21,0,0,442,443,3,6,3,0,443,444, + 5,22,0,0,444,461,1,0,0,0,445,452,7,8,0,0,446,447,3,78,39,0,447,448, + 3,76,38,0,448,453,1,0,0,0,449,450,3,76,38,0,450,451,3,78,39,0,451, + 453,1,0,0,0,452,446,1,0,0,0,452,449,1,0,0,0,453,454,1,0,0,0,454, + 455,3,10,5,0,455,461,1,0,0,0,456,457,5,32,0,0,457,458,3,68,34,0, + 458,459,3,10,5,0,459,461,1,0,0,0,460,371,1,0,0,0,460,393,1,0,0,0, + 460,412,1,0,0,0,460,429,1,0,0,0,460,440,1,0,0,0,460,445,1,0,0,0, + 460,456,1,0,0,0,461,65,1,0,0,0,462,463,3,6,3,0,463,464,5,1,0,0,464, + 465,3,66,33,0,465,468,1,0,0,0,466,468,3,6,3,0,467,462,1,0,0,0,467, + 466,1,0,0,0,468,67,1,0,0,0,469,470,5,73,0,0,470,471,5,21,0,0,471, + 472,7,3,0,0,472,473,5,33,0,0,473,482,3,6,3,0,474,480,5,74,0,0,475, + 476,5,21,0,0,476,477,7,1,0,0,477,481,5,22,0,0,478,481,5,15,0,0,479, + 481,5,16,0,0,480,475,1,0,0,0,480,478,1,0,0,0,480,479,1,0,0,0,481, + 483,1,0,0,0,482,474,1,0,0,0,482,483,1,0,0,0,483,484,1,0,0,0,484, + 485,5,22,0,0,485,69,1,0,0,0,486,492,3,6,3,0,487,488,3,6,3,0,488, + 489,5,1,0,0,489,490,3,70,35,0,490,492,1,0,0,0,491,486,1,0,0,0,491, + 487,1,0,0,0,492,71,1,0,0,0,493,494,3,12,6,0,494,73,1,0,0,0,495,501, + 5,73,0,0,496,502,3,44,22,0,497,498,5,21,0,0,498,499,3,6,3,0,499, + 500,5,22,0,0,500,502,1,0,0,0,501,496,1,0,0,0,501,497,1,0,0,0,502, + 75,1,0,0,0,503,509,5,74,0,0,504,510,3,44,22,0,505,506,5,21,0,0,506, + 507,3,6,3,0,507,508,5,22,0,0,508,510,1,0,0,0,509,504,1,0,0,0,509, + 505,1,0,0,0,510,77,1,0,0,0,511,512,5,73,0,0,512,513,5,21,0,0,513, + 514,3,4,2,0,514,515,5,22,0,0,515,79,1,0,0,0,516,517,5,73,0,0,517, + 518,5,21,0,0,518,519,3,4,2,0,519,520,5,22,0,0,520,81,1,0,0,0,59, + 92,109,120,131,139,141,149,152,158,165,170,178,184,192,206,209,213, + 226,229,233,242,249,267,276,284,291,293,297,300,303,306,308,317, + 335,344,351,373,376,379,382,384,391,395,398,401,404,406,419,422, + 427,434,452,460,467,480,482,491,501,509 + ] + +class LaTeXParser ( Parser ): + + grammarFileName = "LaTeX.g4" + + atn = ATNDeserializer().deserialize(serializedATN()) + + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] + + sharedContextCache = PredictionContextCache() + + literalNames = [ "", "','", "'.'", "", "", + "", "", "'\\quad'", "'\\qquad'", + "", "'\\negmedspace'", "'\\negthickspace'", + "'\\left'", "'\\right'", "", "'+'", "'-'", + "'*'", "'/'", "'('", "')'", "'{'", "'}'", "'\\{'", + "'\\}'", "'['", "']'", "'|'", "'\\right|'", "'\\left|'", + "'\\langle'", "'\\rangle'", "'\\lim'", "", + "", "'\\sum'", "'\\prod'", "'\\exp'", "'\\log'", + "'\\lg'", "'\\ln'", "'\\sin'", "'\\cos'", "'\\tan'", + "'\\csc'", "'\\sec'", "'\\cot'", "'\\arcsin'", "'\\arccos'", + "'\\arctan'", "'\\arccsc'", "'\\arcsec'", "'\\arccot'", + "'\\sinh'", "'\\cosh'", "'\\tanh'", "'\\arsinh'", "'\\arcosh'", + "'\\artanh'", "'\\lfloor'", "'\\rfloor'", "'\\lceil'", + "'\\rceil'", "'\\sqrt'", "'\\overline'", "'\\times'", + "'\\cdot'", "'\\div'", "", "'\\binom'", "'\\dbinom'", + "'\\tbinom'", "'\\mathit'", "'_'", "'^'", "':'", "", + "", "", "", "'\\neq'", "'<'", + "", "'\\leqq'", "'\\leqslant'", "'>'", "", + "'\\geqq'", "'\\geqslant'", "'!'" ] + + symbolicNames = [ "", "", "", "WS", "THINSPACE", + "MEDSPACE", "THICKSPACE", "QUAD", "QQUAD", "NEGTHINSPACE", + "NEGMEDSPACE", "NEGTHICKSPACE", "CMD_LEFT", "CMD_RIGHT", + "IGNORE", "ADD", "SUB", "MUL", "DIV", "L_PAREN", "R_PAREN", + "L_BRACE", "R_BRACE", "L_BRACE_LITERAL", "R_BRACE_LITERAL", + "L_BRACKET", "R_BRACKET", "BAR", "R_BAR", "L_BAR", + "L_ANGLE", "R_ANGLE", "FUNC_LIM", "LIM_APPROACH_SYM", + "FUNC_INT", "FUNC_SUM", "FUNC_PROD", "FUNC_EXP", "FUNC_LOG", + "FUNC_LG", "FUNC_LN", "FUNC_SIN", "FUNC_COS", "FUNC_TAN", + "FUNC_CSC", "FUNC_SEC", "FUNC_COT", "FUNC_ARCSIN", + "FUNC_ARCCOS", "FUNC_ARCTAN", "FUNC_ARCCSC", "FUNC_ARCSEC", + "FUNC_ARCCOT", "FUNC_SINH", "FUNC_COSH", "FUNC_TANH", + "FUNC_ARSINH", "FUNC_ARCOSH", "FUNC_ARTANH", "L_FLOOR", + "R_FLOOR", "L_CEIL", "R_CEIL", "FUNC_SQRT", "FUNC_OVERLINE", + "CMD_TIMES", "CMD_CDOT", "CMD_DIV", "CMD_FRAC", "CMD_BINOM", + "CMD_DBINOM", "CMD_TBINOM", "CMD_MATHIT", "UNDERSCORE", + "CARET", "COLON", "DIFFERENTIAL", "LETTER", "DIGIT", + "EQUAL", "NEQ", "LT", "LTE", "LTE_Q", "LTE_S", "GT", + "GTE", "GTE_Q", "GTE_S", "BANG", "SINGLE_QUOTES", + "SYMBOL" ] + + RULE_math = 0 + RULE_relation = 1 + RULE_equality = 2 + RULE_expr = 3 + RULE_additive = 4 + RULE_mp = 5 + RULE_mp_nofunc = 6 + RULE_unary = 7 + RULE_unary_nofunc = 8 + RULE_postfix = 9 + RULE_postfix_nofunc = 10 + RULE_postfix_op = 11 + RULE_eval_at = 12 + RULE_eval_at_sub = 13 + RULE_eval_at_sup = 14 + RULE_exp = 15 + RULE_exp_nofunc = 16 + RULE_comp = 17 + RULE_comp_nofunc = 18 + RULE_group = 19 + RULE_abs_group = 20 + RULE_number = 21 + RULE_atom = 22 + RULE_bra = 23 + RULE_ket = 24 + RULE_mathit = 25 + RULE_mathit_text = 26 + RULE_frac = 27 + RULE_binom = 28 + RULE_floor = 29 + RULE_ceil = 30 + RULE_func_normal = 31 + RULE_func = 32 + RULE_args = 33 + RULE_limit_sub = 34 + RULE_func_arg = 35 + RULE_func_arg_noparens = 36 + RULE_subexpr = 37 + RULE_supexpr = 38 + RULE_subeq = 39 + RULE_supeq = 40 + + ruleNames = [ "math", "relation", "equality", "expr", "additive", "mp", + "mp_nofunc", "unary", "unary_nofunc", "postfix", "postfix_nofunc", + "postfix_op", "eval_at", "eval_at_sub", "eval_at_sup", + "exp", "exp_nofunc", "comp", "comp_nofunc", "group", + "abs_group", "number", "atom", "bra", "ket", "mathit", + "mathit_text", "frac", "binom", "floor", "ceil", "func_normal", + "func", "args", "limit_sub", "func_arg", "func_arg_noparens", + "subexpr", "supexpr", "subeq", "supeq" ] + + EOF = Token.EOF + T__0=1 + T__1=2 + WS=3 + THINSPACE=4 + MEDSPACE=5 + THICKSPACE=6 + QUAD=7 + QQUAD=8 + NEGTHINSPACE=9 + NEGMEDSPACE=10 + NEGTHICKSPACE=11 + CMD_LEFT=12 + CMD_RIGHT=13 + IGNORE=14 + ADD=15 + SUB=16 + MUL=17 + DIV=18 + L_PAREN=19 + R_PAREN=20 + L_BRACE=21 + R_BRACE=22 + L_BRACE_LITERAL=23 + R_BRACE_LITERAL=24 + L_BRACKET=25 + R_BRACKET=26 + BAR=27 + R_BAR=28 + L_BAR=29 + L_ANGLE=30 + R_ANGLE=31 + FUNC_LIM=32 + LIM_APPROACH_SYM=33 + FUNC_INT=34 + FUNC_SUM=35 + FUNC_PROD=36 + FUNC_EXP=37 + FUNC_LOG=38 + FUNC_LG=39 + FUNC_LN=40 + FUNC_SIN=41 + FUNC_COS=42 + FUNC_TAN=43 + FUNC_CSC=44 + FUNC_SEC=45 + FUNC_COT=46 + FUNC_ARCSIN=47 + FUNC_ARCCOS=48 + FUNC_ARCTAN=49 + FUNC_ARCCSC=50 + FUNC_ARCSEC=51 + FUNC_ARCCOT=52 + FUNC_SINH=53 + FUNC_COSH=54 + FUNC_TANH=55 + FUNC_ARSINH=56 + FUNC_ARCOSH=57 + FUNC_ARTANH=58 + L_FLOOR=59 + R_FLOOR=60 + L_CEIL=61 + R_CEIL=62 + FUNC_SQRT=63 + FUNC_OVERLINE=64 + CMD_TIMES=65 + CMD_CDOT=66 + CMD_DIV=67 + CMD_FRAC=68 + CMD_BINOM=69 + CMD_DBINOM=70 + CMD_TBINOM=71 + CMD_MATHIT=72 + UNDERSCORE=73 + CARET=74 + COLON=75 + DIFFERENTIAL=76 + LETTER=77 + DIGIT=78 + EQUAL=79 + NEQ=80 + LT=81 + LTE=82 + LTE_Q=83 + LTE_S=84 + GT=85 + GTE=86 + GTE_Q=87 + GTE_S=88 + BANG=89 + SINGLE_QUOTES=90 + SYMBOL=91 + + def __init__(self, input:TokenStream, output:TextIO = sys.stdout): + super().__init__(input, output) + self.checkVersion("4.11.1") + self._interp = ParserATNSimulator(self, self.atn, self.decisionsToDFA, self.sharedContextCache) + self._predicates = None + + + + + class MathContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def relation(self): + return self.getTypedRuleContext(LaTeXParser.RelationContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_math + + + + + def math(self): + + localctx = LaTeXParser.MathContext(self, self._ctx, self.state) + self.enterRule(localctx, 0, self.RULE_math) + try: + self.enterOuterAlt(localctx, 1) + self.state = 82 + self.relation(0) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class RelationContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def relation(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.RelationContext) + else: + return self.getTypedRuleContext(LaTeXParser.RelationContext,i) + + + def EQUAL(self): + return self.getToken(LaTeXParser.EQUAL, 0) + + def LT(self): + return self.getToken(LaTeXParser.LT, 0) + + def LTE(self): + return self.getToken(LaTeXParser.LTE, 0) + + def GT(self): + return self.getToken(LaTeXParser.GT, 0) + + def GTE(self): + return self.getToken(LaTeXParser.GTE, 0) + + def NEQ(self): + return self.getToken(LaTeXParser.NEQ, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_relation + + + + def relation(self, _p:int=0): + _parentctx = self._ctx + _parentState = self.state + localctx = LaTeXParser.RelationContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 2 + self.enterRecursionRule(localctx, 2, self.RULE_relation, _p) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 85 + self.expr() + self._ctx.stop = self._input.LT(-1) + self.state = 92 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,0,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + localctx = LaTeXParser.RelationContext(self, _parentctx, _parentState) + self.pushNewRecursionContext(localctx, _startState, self.RULE_relation) + self.state = 87 + if not self.precpred(self._ctx, 2): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 2)") + self.state = 88 + _la = self._input.LA(1) + if not((((_la - 79)) & ~0x3f) == 0 and ((1 << (_la - 79)) & 207) != 0): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 89 + self.relation(3) + self.state = 94 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,0,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + class EqualityContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.ExprContext) + else: + return self.getTypedRuleContext(LaTeXParser.ExprContext,i) + + + def EQUAL(self): + return self.getToken(LaTeXParser.EQUAL, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_equality + + + + + def equality(self): + + localctx = LaTeXParser.EqualityContext(self, self._ctx, self.state) + self.enterRule(localctx, 4, self.RULE_equality) + try: + self.enterOuterAlt(localctx, 1) + self.state = 95 + self.expr() + self.state = 96 + self.match(LaTeXParser.EQUAL) + self.state = 97 + self.expr() + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ExprContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def additive(self): + return self.getTypedRuleContext(LaTeXParser.AdditiveContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_expr + + + + + def expr(self): + + localctx = LaTeXParser.ExprContext(self, self._ctx, self.state) + self.enterRule(localctx, 6, self.RULE_expr) + try: + self.enterOuterAlt(localctx, 1) + self.state = 99 + self.additive(0) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class AdditiveContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def mp(self): + return self.getTypedRuleContext(LaTeXParser.MpContext,0) + + + def additive(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.AdditiveContext) + else: + return self.getTypedRuleContext(LaTeXParser.AdditiveContext,i) + + + def ADD(self): + return self.getToken(LaTeXParser.ADD, 0) + + def SUB(self): + return self.getToken(LaTeXParser.SUB, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_additive + + + + def additive(self, _p:int=0): + _parentctx = self._ctx + _parentState = self.state + localctx = LaTeXParser.AdditiveContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 8 + self.enterRecursionRule(localctx, 8, self.RULE_additive, _p) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 102 + self.mp(0) + self._ctx.stop = self._input.LT(-1) + self.state = 109 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,1,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + localctx = LaTeXParser.AdditiveContext(self, _parentctx, _parentState) + self.pushNewRecursionContext(localctx, _startState, self.RULE_additive) + self.state = 104 + if not self.precpred(self._ctx, 2): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 2)") + self.state = 105 + _la = self._input.LA(1) + if not(_la==15 or _la==16): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 106 + self.additive(3) + self.state = 111 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,1,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + class MpContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def unary(self): + return self.getTypedRuleContext(LaTeXParser.UnaryContext,0) + + + def mp(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.MpContext) + else: + return self.getTypedRuleContext(LaTeXParser.MpContext,i) + + + def MUL(self): + return self.getToken(LaTeXParser.MUL, 0) + + def CMD_TIMES(self): + return self.getToken(LaTeXParser.CMD_TIMES, 0) + + def CMD_CDOT(self): + return self.getToken(LaTeXParser.CMD_CDOT, 0) + + def DIV(self): + return self.getToken(LaTeXParser.DIV, 0) + + def CMD_DIV(self): + return self.getToken(LaTeXParser.CMD_DIV, 0) + + def COLON(self): + return self.getToken(LaTeXParser.COLON, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_mp + + + + def mp(self, _p:int=0): + _parentctx = self._ctx + _parentState = self.state + localctx = LaTeXParser.MpContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 10 + self.enterRecursionRule(localctx, 10, self.RULE_mp, _p) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 113 + self.unary() + self._ctx.stop = self._input.LT(-1) + self.state = 120 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,2,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + localctx = LaTeXParser.MpContext(self, _parentctx, _parentState) + self.pushNewRecursionContext(localctx, _startState, self.RULE_mp) + self.state = 115 + if not self.precpred(self._ctx, 2): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 2)") + self.state = 116 + _la = self._input.LA(1) + if not((((_la - 17)) & ~0x3f) == 0 and ((1 << (_la - 17)) & 290200700988686339) != 0): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 117 + self.mp(3) + self.state = 122 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,2,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + class Mp_nofuncContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def unary_nofunc(self): + return self.getTypedRuleContext(LaTeXParser.Unary_nofuncContext,0) + + + def mp_nofunc(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.Mp_nofuncContext) + else: + return self.getTypedRuleContext(LaTeXParser.Mp_nofuncContext,i) + + + def MUL(self): + return self.getToken(LaTeXParser.MUL, 0) + + def CMD_TIMES(self): + return self.getToken(LaTeXParser.CMD_TIMES, 0) + + def CMD_CDOT(self): + return self.getToken(LaTeXParser.CMD_CDOT, 0) + + def DIV(self): + return self.getToken(LaTeXParser.DIV, 0) + + def CMD_DIV(self): + return self.getToken(LaTeXParser.CMD_DIV, 0) + + def COLON(self): + return self.getToken(LaTeXParser.COLON, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_mp_nofunc + + + + def mp_nofunc(self, _p:int=0): + _parentctx = self._ctx + _parentState = self.state + localctx = LaTeXParser.Mp_nofuncContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 12 + self.enterRecursionRule(localctx, 12, self.RULE_mp_nofunc, _p) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 124 + self.unary_nofunc() + self._ctx.stop = self._input.LT(-1) + self.state = 131 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,3,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + localctx = LaTeXParser.Mp_nofuncContext(self, _parentctx, _parentState) + self.pushNewRecursionContext(localctx, _startState, self.RULE_mp_nofunc) + self.state = 126 + if not self.precpred(self._ctx, 2): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 2)") + self.state = 127 + _la = self._input.LA(1) + if not((((_la - 17)) & ~0x3f) == 0 and ((1 << (_la - 17)) & 290200700988686339) != 0): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 128 + self.mp_nofunc(3) + self.state = 133 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,3,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + class UnaryContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def unary(self): + return self.getTypedRuleContext(LaTeXParser.UnaryContext,0) + + + def ADD(self): + return self.getToken(LaTeXParser.ADD, 0) + + def SUB(self): + return self.getToken(LaTeXParser.SUB, 0) + + def postfix(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.PostfixContext) + else: + return self.getTypedRuleContext(LaTeXParser.PostfixContext,i) + + + def getRuleIndex(self): + return LaTeXParser.RULE_unary + + + + + def unary(self): + + localctx = LaTeXParser.UnaryContext(self, self._ctx, self.state) + self.enterRule(localctx, 14, self.RULE_unary) + self._la = 0 # Token type + try: + self.state = 141 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [15, 16]: + self.enterOuterAlt(localctx, 1) + self.state = 134 + _la = self._input.LA(1) + if not(_la==15 or _la==16): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 135 + self.unary() + pass + elif token in [19, 21, 23, 25, 27, 29, 30, 32, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61, 63, 64, 68, 69, 70, 71, 72, 76, 77, 78, 91]: + self.enterOuterAlt(localctx, 2) + self.state = 137 + self._errHandler.sync(self) + _alt = 1 + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt == 1: + self.state = 136 + self.postfix() + + else: + raise NoViableAltException(self) + self.state = 139 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,4,self._ctx) + + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Unary_nofuncContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def unary_nofunc(self): + return self.getTypedRuleContext(LaTeXParser.Unary_nofuncContext,0) + + + def ADD(self): + return self.getToken(LaTeXParser.ADD, 0) + + def SUB(self): + return self.getToken(LaTeXParser.SUB, 0) + + def postfix(self): + return self.getTypedRuleContext(LaTeXParser.PostfixContext,0) + + + def postfix_nofunc(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.Postfix_nofuncContext) + else: + return self.getTypedRuleContext(LaTeXParser.Postfix_nofuncContext,i) + + + def getRuleIndex(self): + return LaTeXParser.RULE_unary_nofunc + + + + + def unary_nofunc(self): + + localctx = LaTeXParser.Unary_nofuncContext(self, self._ctx, self.state) + self.enterRule(localctx, 16, self.RULE_unary_nofunc) + self._la = 0 # Token type + try: + self.state = 152 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [15, 16]: + self.enterOuterAlt(localctx, 1) + self.state = 143 + _la = self._input.LA(1) + if not(_la==15 or _la==16): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 144 + self.unary_nofunc() + pass + elif token in [19, 21, 23, 25, 27, 29, 30, 32, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61, 63, 64, 68, 69, 70, 71, 72, 76, 77, 78, 91]: + self.enterOuterAlt(localctx, 2) + self.state = 145 + self.postfix() + self.state = 149 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,6,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + self.state = 146 + self.postfix_nofunc() + self.state = 151 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,6,self._ctx) + + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class PostfixContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def exp(self): + return self.getTypedRuleContext(LaTeXParser.ExpContext,0) + + + def postfix_op(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.Postfix_opContext) + else: + return self.getTypedRuleContext(LaTeXParser.Postfix_opContext,i) + + + def getRuleIndex(self): + return LaTeXParser.RULE_postfix + + + + + def postfix(self): + + localctx = LaTeXParser.PostfixContext(self, self._ctx, self.state) + self.enterRule(localctx, 18, self.RULE_postfix) + try: + self.enterOuterAlt(localctx, 1) + self.state = 154 + self.exp(0) + self.state = 158 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,8,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + self.state = 155 + self.postfix_op() + self.state = 160 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,8,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Postfix_nofuncContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def exp_nofunc(self): + return self.getTypedRuleContext(LaTeXParser.Exp_nofuncContext,0) + + + def postfix_op(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.Postfix_opContext) + else: + return self.getTypedRuleContext(LaTeXParser.Postfix_opContext,i) + + + def getRuleIndex(self): + return LaTeXParser.RULE_postfix_nofunc + + + + + def postfix_nofunc(self): + + localctx = LaTeXParser.Postfix_nofuncContext(self, self._ctx, self.state) + self.enterRule(localctx, 20, self.RULE_postfix_nofunc) + try: + self.enterOuterAlt(localctx, 1) + self.state = 161 + self.exp_nofunc(0) + self.state = 165 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,9,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + self.state = 162 + self.postfix_op() + self.state = 167 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,9,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Postfix_opContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def BANG(self): + return self.getToken(LaTeXParser.BANG, 0) + + def eval_at(self): + return self.getTypedRuleContext(LaTeXParser.Eval_atContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_postfix_op + + + + + def postfix_op(self): + + localctx = LaTeXParser.Postfix_opContext(self, self._ctx, self.state) + self.enterRule(localctx, 22, self.RULE_postfix_op) + try: + self.state = 170 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [89]: + self.enterOuterAlt(localctx, 1) + self.state = 168 + self.match(LaTeXParser.BANG) + pass + elif token in [27]: + self.enterOuterAlt(localctx, 2) + self.state = 169 + self.eval_at() + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Eval_atContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def BAR(self): + return self.getToken(LaTeXParser.BAR, 0) + + def eval_at_sup(self): + return self.getTypedRuleContext(LaTeXParser.Eval_at_supContext,0) + + + def eval_at_sub(self): + return self.getTypedRuleContext(LaTeXParser.Eval_at_subContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_eval_at + + + + + def eval_at(self): + + localctx = LaTeXParser.Eval_atContext(self, self._ctx, self.state) + self.enterRule(localctx, 24, self.RULE_eval_at) + try: + self.enterOuterAlt(localctx, 1) + self.state = 172 + self.match(LaTeXParser.BAR) + self.state = 178 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,11,self._ctx) + if la_ == 1: + self.state = 173 + self.eval_at_sup() + pass + + elif la_ == 2: + self.state = 174 + self.eval_at_sub() + pass + + elif la_ == 3: + self.state = 175 + self.eval_at_sup() + self.state = 176 + self.eval_at_sub() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Eval_at_subContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def UNDERSCORE(self): + return self.getToken(LaTeXParser.UNDERSCORE, 0) + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def equality(self): + return self.getTypedRuleContext(LaTeXParser.EqualityContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_eval_at_sub + + + + + def eval_at_sub(self): + + localctx = LaTeXParser.Eval_at_subContext(self, self._ctx, self.state) + self.enterRule(localctx, 26, self.RULE_eval_at_sub) + try: + self.enterOuterAlt(localctx, 1) + self.state = 180 + self.match(LaTeXParser.UNDERSCORE) + self.state = 181 + self.match(LaTeXParser.L_BRACE) + self.state = 184 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,12,self._ctx) + if la_ == 1: + self.state = 182 + self.expr() + pass + + elif la_ == 2: + self.state = 183 + self.equality() + pass + + + self.state = 186 + self.match(LaTeXParser.R_BRACE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Eval_at_supContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def CARET(self): + return self.getToken(LaTeXParser.CARET, 0) + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def equality(self): + return self.getTypedRuleContext(LaTeXParser.EqualityContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_eval_at_sup + + + + + def eval_at_sup(self): + + localctx = LaTeXParser.Eval_at_supContext(self, self._ctx, self.state) + self.enterRule(localctx, 28, self.RULE_eval_at_sup) + try: + self.enterOuterAlt(localctx, 1) + self.state = 188 + self.match(LaTeXParser.CARET) + self.state = 189 + self.match(LaTeXParser.L_BRACE) + self.state = 192 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,13,self._ctx) + if la_ == 1: + self.state = 190 + self.expr() + pass + + elif la_ == 2: + self.state = 191 + self.equality() + pass + + + self.state = 194 + self.match(LaTeXParser.R_BRACE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ExpContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def comp(self): + return self.getTypedRuleContext(LaTeXParser.CompContext,0) + + + def exp(self): + return self.getTypedRuleContext(LaTeXParser.ExpContext,0) + + + def CARET(self): + return self.getToken(LaTeXParser.CARET, 0) + + def atom(self): + return self.getTypedRuleContext(LaTeXParser.AtomContext,0) + + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def subexpr(self): + return self.getTypedRuleContext(LaTeXParser.SubexprContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_exp + + + + def exp(self, _p:int=0): + _parentctx = self._ctx + _parentState = self.state + localctx = LaTeXParser.ExpContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 30 + self.enterRecursionRule(localctx, 30, self.RULE_exp, _p) + try: + self.enterOuterAlt(localctx, 1) + self.state = 197 + self.comp() + self._ctx.stop = self._input.LT(-1) + self.state = 213 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,16,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + localctx = LaTeXParser.ExpContext(self, _parentctx, _parentState) + self.pushNewRecursionContext(localctx, _startState, self.RULE_exp) + self.state = 199 + if not self.precpred(self._ctx, 2): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 2)") + self.state = 200 + self.match(LaTeXParser.CARET) + self.state = 206 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [27, 29, 30, 68, 69, 70, 71, 72, 76, 77, 78, 91]: + self.state = 201 + self.atom() + pass + elif token in [21]: + self.state = 202 + self.match(LaTeXParser.L_BRACE) + self.state = 203 + self.expr() + self.state = 204 + self.match(LaTeXParser.R_BRACE) + pass + else: + raise NoViableAltException(self) + + self.state = 209 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,15,self._ctx) + if la_ == 1: + self.state = 208 + self.subexpr() + + + self.state = 215 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,16,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + class Exp_nofuncContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def comp_nofunc(self): + return self.getTypedRuleContext(LaTeXParser.Comp_nofuncContext,0) + + + def exp_nofunc(self): + return self.getTypedRuleContext(LaTeXParser.Exp_nofuncContext,0) + + + def CARET(self): + return self.getToken(LaTeXParser.CARET, 0) + + def atom(self): + return self.getTypedRuleContext(LaTeXParser.AtomContext,0) + + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def subexpr(self): + return self.getTypedRuleContext(LaTeXParser.SubexprContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_exp_nofunc + + + + def exp_nofunc(self, _p:int=0): + _parentctx = self._ctx + _parentState = self.state + localctx = LaTeXParser.Exp_nofuncContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 32 + self.enterRecursionRule(localctx, 32, self.RULE_exp_nofunc, _p) + try: + self.enterOuterAlt(localctx, 1) + self.state = 217 + self.comp_nofunc() + self._ctx.stop = self._input.LT(-1) + self.state = 233 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,19,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + localctx = LaTeXParser.Exp_nofuncContext(self, _parentctx, _parentState) + self.pushNewRecursionContext(localctx, _startState, self.RULE_exp_nofunc) + self.state = 219 + if not self.precpred(self._ctx, 2): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 2)") + self.state = 220 + self.match(LaTeXParser.CARET) + self.state = 226 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [27, 29, 30, 68, 69, 70, 71, 72, 76, 77, 78, 91]: + self.state = 221 + self.atom() + pass + elif token in [21]: + self.state = 222 + self.match(LaTeXParser.L_BRACE) + self.state = 223 + self.expr() + self.state = 224 + self.match(LaTeXParser.R_BRACE) + pass + else: + raise NoViableAltException(self) + + self.state = 229 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,18,self._ctx) + if la_ == 1: + self.state = 228 + self.subexpr() + + + self.state = 235 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,19,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + class CompContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def group(self): + return self.getTypedRuleContext(LaTeXParser.GroupContext,0) + + + def abs_group(self): + return self.getTypedRuleContext(LaTeXParser.Abs_groupContext,0) + + + def func(self): + return self.getTypedRuleContext(LaTeXParser.FuncContext,0) + + + def atom(self): + return self.getTypedRuleContext(LaTeXParser.AtomContext,0) + + + def floor(self): + return self.getTypedRuleContext(LaTeXParser.FloorContext,0) + + + def ceil(self): + return self.getTypedRuleContext(LaTeXParser.CeilContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_comp + + + + + def comp(self): + + localctx = LaTeXParser.CompContext(self, self._ctx, self.state) + self.enterRule(localctx, 34, self.RULE_comp) + try: + self.state = 242 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,20,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 236 + self.group() + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 237 + self.abs_group() + pass + + elif la_ == 3: + self.enterOuterAlt(localctx, 3) + self.state = 238 + self.func() + pass + + elif la_ == 4: + self.enterOuterAlt(localctx, 4) + self.state = 239 + self.atom() + pass + + elif la_ == 5: + self.enterOuterAlt(localctx, 5) + self.state = 240 + self.floor() + pass + + elif la_ == 6: + self.enterOuterAlt(localctx, 6) + self.state = 241 + self.ceil() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Comp_nofuncContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def group(self): + return self.getTypedRuleContext(LaTeXParser.GroupContext,0) + + + def abs_group(self): + return self.getTypedRuleContext(LaTeXParser.Abs_groupContext,0) + + + def atom(self): + return self.getTypedRuleContext(LaTeXParser.AtomContext,0) + + + def floor(self): + return self.getTypedRuleContext(LaTeXParser.FloorContext,0) + + + def ceil(self): + return self.getTypedRuleContext(LaTeXParser.CeilContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_comp_nofunc + + + + + def comp_nofunc(self): + + localctx = LaTeXParser.Comp_nofuncContext(self, self._ctx, self.state) + self.enterRule(localctx, 36, self.RULE_comp_nofunc) + try: + self.state = 249 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,21,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 244 + self.group() + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 245 + self.abs_group() + pass + + elif la_ == 3: + self.enterOuterAlt(localctx, 3) + self.state = 246 + self.atom() + pass + + elif la_ == 4: + self.enterOuterAlt(localctx, 4) + self.state = 247 + self.floor() + pass + + elif la_ == 5: + self.enterOuterAlt(localctx, 5) + self.state = 248 + self.ceil() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class GroupContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def L_PAREN(self): + return self.getToken(LaTeXParser.L_PAREN, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_PAREN(self): + return self.getToken(LaTeXParser.R_PAREN, 0) + + def L_BRACKET(self): + return self.getToken(LaTeXParser.L_BRACKET, 0) + + def R_BRACKET(self): + return self.getToken(LaTeXParser.R_BRACKET, 0) + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def L_BRACE_LITERAL(self): + return self.getToken(LaTeXParser.L_BRACE_LITERAL, 0) + + def R_BRACE_LITERAL(self): + return self.getToken(LaTeXParser.R_BRACE_LITERAL, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_group + + + + + def group(self): + + localctx = LaTeXParser.GroupContext(self, self._ctx, self.state) + self.enterRule(localctx, 38, self.RULE_group) + try: + self.state = 267 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [19]: + self.enterOuterAlt(localctx, 1) + self.state = 251 + self.match(LaTeXParser.L_PAREN) + self.state = 252 + self.expr() + self.state = 253 + self.match(LaTeXParser.R_PAREN) + pass + elif token in [25]: + self.enterOuterAlt(localctx, 2) + self.state = 255 + self.match(LaTeXParser.L_BRACKET) + self.state = 256 + self.expr() + self.state = 257 + self.match(LaTeXParser.R_BRACKET) + pass + elif token in [21]: + self.enterOuterAlt(localctx, 3) + self.state = 259 + self.match(LaTeXParser.L_BRACE) + self.state = 260 + self.expr() + self.state = 261 + self.match(LaTeXParser.R_BRACE) + pass + elif token in [23]: + self.enterOuterAlt(localctx, 4) + self.state = 263 + self.match(LaTeXParser.L_BRACE_LITERAL) + self.state = 264 + self.expr() + self.state = 265 + self.match(LaTeXParser.R_BRACE_LITERAL) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Abs_groupContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def BAR(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.BAR) + else: + return self.getToken(LaTeXParser.BAR, i) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_abs_group + + + + + def abs_group(self): + + localctx = LaTeXParser.Abs_groupContext(self, self._ctx, self.state) + self.enterRule(localctx, 40, self.RULE_abs_group) + try: + self.enterOuterAlt(localctx, 1) + self.state = 269 + self.match(LaTeXParser.BAR) + self.state = 270 + self.expr() + self.state = 271 + self.match(LaTeXParser.BAR) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class NumberContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def DIGIT(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.DIGIT) + else: + return self.getToken(LaTeXParser.DIGIT, i) + + def getRuleIndex(self): + return LaTeXParser.RULE_number + + + + + def number(self): + + localctx = LaTeXParser.NumberContext(self, self._ctx, self.state) + self.enterRule(localctx, 42, self.RULE_number) + try: + self.enterOuterAlt(localctx, 1) + self.state = 274 + self._errHandler.sync(self) + _alt = 1 + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt == 1: + self.state = 273 + self.match(LaTeXParser.DIGIT) + + else: + raise NoViableAltException(self) + self.state = 276 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,23,self._ctx) + + self.state = 284 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,24,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + self.state = 278 + self.match(LaTeXParser.T__0) + self.state = 279 + self.match(LaTeXParser.DIGIT) + self.state = 280 + self.match(LaTeXParser.DIGIT) + self.state = 281 + self.match(LaTeXParser.DIGIT) + self.state = 286 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,24,self._ctx) + + self.state = 293 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,26,self._ctx) + if la_ == 1: + self.state = 287 + self.match(LaTeXParser.T__1) + self.state = 289 + self._errHandler.sync(self) + _alt = 1 + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt == 1: + self.state = 288 + self.match(LaTeXParser.DIGIT) + + else: + raise NoViableAltException(self) + self.state = 291 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,25,self._ctx) + + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class AtomContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def LETTER(self): + return self.getToken(LaTeXParser.LETTER, 0) + + def SYMBOL(self): + return self.getToken(LaTeXParser.SYMBOL, 0) + + def subexpr(self): + return self.getTypedRuleContext(LaTeXParser.SubexprContext,0) + + + def SINGLE_QUOTES(self): + return self.getToken(LaTeXParser.SINGLE_QUOTES, 0) + + def number(self): + return self.getTypedRuleContext(LaTeXParser.NumberContext,0) + + + def DIFFERENTIAL(self): + return self.getToken(LaTeXParser.DIFFERENTIAL, 0) + + def mathit(self): + return self.getTypedRuleContext(LaTeXParser.MathitContext,0) + + + def frac(self): + return self.getTypedRuleContext(LaTeXParser.FracContext,0) + + + def binom(self): + return self.getTypedRuleContext(LaTeXParser.BinomContext,0) + + + def bra(self): + return self.getTypedRuleContext(LaTeXParser.BraContext,0) + + + def ket(self): + return self.getTypedRuleContext(LaTeXParser.KetContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_atom + + + + + def atom(self): + + localctx = LaTeXParser.AtomContext(self, self._ctx, self.state) + self.enterRule(localctx, 44, self.RULE_atom) + self._la = 0 # Token type + try: + self.state = 317 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [77, 91]: + self.enterOuterAlt(localctx, 1) + self.state = 295 + _la = self._input.LA(1) + if not(_la==77 or _la==91): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 308 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,31,self._ctx) + if la_ == 1: + self.state = 297 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,27,self._ctx) + if la_ == 1: + self.state = 296 + self.subexpr() + + + self.state = 300 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,28,self._ctx) + if la_ == 1: + self.state = 299 + self.match(LaTeXParser.SINGLE_QUOTES) + + + pass + + elif la_ == 2: + self.state = 303 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,29,self._ctx) + if la_ == 1: + self.state = 302 + self.match(LaTeXParser.SINGLE_QUOTES) + + + self.state = 306 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,30,self._ctx) + if la_ == 1: + self.state = 305 + self.subexpr() + + + pass + + + pass + elif token in [78]: + self.enterOuterAlt(localctx, 2) + self.state = 310 + self.number() + pass + elif token in [76]: + self.enterOuterAlt(localctx, 3) + self.state = 311 + self.match(LaTeXParser.DIFFERENTIAL) + pass + elif token in [72]: + self.enterOuterAlt(localctx, 4) + self.state = 312 + self.mathit() + pass + elif token in [68]: + self.enterOuterAlt(localctx, 5) + self.state = 313 + self.frac() + pass + elif token in [69, 70, 71]: + self.enterOuterAlt(localctx, 6) + self.state = 314 + self.binom() + pass + elif token in [30]: + self.enterOuterAlt(localctx, 7) + self.state = 315 + self.bra() + pass + elif token in [27, 29]: + self.enterOuterAlt(localctx, 8) + self.state = 316 + self.ket() + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class BraContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def L_ANGLE(self): + return self.getToken(LaTeXParser.L_ANGLE, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_BAR(self): + return self.getToken(LaTeXParser.R_BAR, 0) + + def BAR(self): + return self.getToken(LaTeXParser.BAR, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_bra + + + + + def bra(self): + + localctx = LaTeXParser.BraContext(self, self._ctx, self.state) + self.enterRule(localctx, 46, self.RULE_bra) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 319 + self.match(LaTeXParser.L_ANGLE) + self.state = 320 + self.expr() + self.state = 321 + _la = self._input.LA(1) + if not(_la==27 or _la==28): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class KetContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_ANGLE(self): + return self.getToken(LaTeXParser.R_ANGLE, 0) + + def L_BAR(self): + return self.getToken(LaTeXParser.L_BAR, 0) + + def BAR(self): + return self.getToken(LaTeXParser.BAR, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_ket + + + + + def ket(self): + + localctx = LaTeXParser.KetContext(self, self._ctx, self.state) + self.enterRule(localctx, 48, self.RULE_ket) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 323 + _la = self._input.LA(1) + if not(_la==27 or _la==29): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 324 + self.expr() + self.state = 325 + self.match(LaTeXParser.R_ANGLE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class MathitContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def CMD_MATHIT(self): + return self.getToken(LaTeXParser.CMD_MATHIT, 0) + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def mathit_text(self): + return self.getTypedRuleContext(LaTeXParser.Mathit_textContext,0) + + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_mathit + + + + + def mathit(self): + + localctx = LaTeXParser.MathitContext(self, self._ctx, self.state) + self.enterRule(localctx, 50, self.RULE_mathit) + try: + self.enterOuterAlt(localctx, 1) + self.state = 327 + self.match(LaTeXParser.CMD_MATHIT) + self.state = 328 + self.match(LaTeXParser.L_BRACE) + self.state = 329 + self.mathit_text() + self.state = 330 + self.match(LaTeXParser.R_BRACE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Mathit_textContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def LETTER(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.LETTER) + else: + return self.getToken(LaTeXParser.LETTER, i) + + def getRuleIndex(self): + return LaTeXParser.RULE_mathit_text + + + + + def mathit_text(self): + + localctx = LaTeXParser.Mathit_textContext(self, self._ctx, self.state) + self.enterRule(localctx, 52, self.RULE_mathit_text) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 335 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==77: + self.state = 332 + self.match(LaTeXParser.LETTER) + self.state = 337 + self._errHandler.sync(self) + _la = self._input.LA(1) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class FracContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + self.upperd = None # Token + self.upper = None # ExprContext + self.lowerd = None # Token + self.lower = None # ExprContext + + def CMD_FRAC(self): + return self.getToken(LaTeXParser.CMD_FRAC, 0) + + def L_BRACE(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.L_BRACE) + else: + return self.getToken(LaTeXParser.L_BRACE, i) + + def R_BRACE(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.R_BRACE) + else: + return self.getToken(LaTeXParser.R_BRACE, i) + + def DIGIT(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.DIGIT) + else: + return self.getToken(LaTeXParser.DIGIT, i) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.ExprContext) + else: + return self.getTypedRuleContext(LaTeXParser.ExprContext,i) + + + def getRuleIndex(self): + return LaTeXParser.RULE_frac + + + + + def frac(self): + + localctx = LaTeXParser.FracContext(self, self._ctx, self.state) + self.enterRule(localctx, 54, self.RULE_frac) + try: + self.enterOuterAlt(localctx, 1) + self.state = 338 + self.match(LaTeXParser.CMD_FRAC) + self.state = 344 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [78]: + self.state = 339 + localctx.upperd = self.match(LaTeXParser.DIGIT) + pass + elif token in [21]: + self.state = 340 + self.match(LaTeXParser.L_BRACE) + self.state = 341 + localctx.upper = self.expr() + self.state = 342 + self.match(LaTeXParser.R_BRACE) + pass + else: + raise NoViableAltException(self) + + self.state = 351 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [78]: + self.state = 346 + localctx.lowerd = self.match(LaTeXParser.DIGIT) + pass + elif token in [21]: + self.state = 347 + self.match(LaTeXParser.L_BRACE) + self.state = 348 + localctx.lower = self.expr() + self.state = 349 + self.match(LaTeXParser.R_BRACE) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class BinomContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + self.n = None # ExprContext + self.k = None # ExprContext + + def L_BRACE(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.L_BRACE) + else: + return self.getToken(LaTeXParser.L_BRACE, i) + + def R_BRACE(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.R_BRACE) + else: + return self.getToken(LaTeXParser.R_BRACE, i) + + def CMD_BINOM(self): + return self.getToken(LaTeXParser.CMD_BINOM, 0) + + def CMD_DBINOM(self): + return self.getToken(LaTeXParser.CMD_DBINOM, 0) + + def CMD_TBINOM(self): + return self.getToken(LaTeXParser.CMD_TBINOM, 0) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.ExprContext) + else: + return self.getTypedRuleContext(LaTeXParser.ExprContext,i) + + + def getRuleIndex(self): + return LaTeXParser.RULE_binom + + + + + def binom(self): + + localctx = LaTeXParser.BinomContext(self, self._ctx, self.state) + self.enterRule(localctx, 56, self.RULE_binom) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 353 + _la = self._input.LA(1) + if not((((_la - 69)) & ~0x3f) == 0 and ((1 << (_la - 69)) & 7) != 0): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 354 + self.match(LaTeXParser.L_BRACE) + self.state = 355 + localctx.n = self.expr() + self.state = 356 + self.match(LaTeXParser.R_BRACE) + self.state = 357 + self.match(LaTeXParser.L_BRACE) + self.state = 358 + localctx.k = self.expr() + self.state = 359 + self.match(LaTeXParser.R_BRACE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class FloorContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + self.val = None # ExprContext + + def L_FLOOR(self): + return self.getToken(LaTeXParser.L_FLOOR, 0) + + def R_FLOOR(self): + return self.getToken(LaTeXParser.R_FLOOR, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_floor + + + + + def floor(self): + + localctx = LaTeXParser.FloorContext(self, self._ctx, self.state) + self.enterRule(localctx, 58, self.RULE_floor) + try: + self.enterOuterAlt(localctx, 1) + self.state = 361 + self.match(LaTeXParser.L_FLOOR) + self.state = 362 + localctx.val = self.expr() + self.state = 363 + self.match(LaTeXParser.R_FLOOR) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class CeilContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + self.val = None # ExprContext + + def L_CEIL(self): + return self.getToken(LaTeXParser.L_CEIL, 0) + + def R_CEIL(self): + return self.getToken(LaTeXParser.R_CEIL, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_ceil + + + + + def ceil(self): + + localctx = LaTeXParser.CeilContext(self, self._ctx, self.state) + self.enterRule(localctx, 60, self.RULE_ceil) + try: + self.enterOuterAlt(localctx, 1) + self.state = 365 + self.match(LaTeXParser.L_CEIL) + self.state = 366 + localctx.val = self.expr() + self.state = 367 + self.match(LaTeXParser.R_CEIL) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Func_normalContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def FUNC_EXP(self): + return self.getToken(LaTeXParser.FUNC_EXP, 0) + + def FUNC_LOG(self): + return self.getToken(LaTeXParser.FUNC_LOG, 0) + + def FUNC_LG(self): + return self.getToken(LaTeXParser.FUNC_LG, 0) + + def FUNC_LN(self): + return self.getToken(LaTeXParser.FUNC_LN, 0) + + def FUNC_SIN(self): + return self.getToken(LaTeXParser.FUNC_SIN, 0) + + def FUNC_COS(self): + return self.getToken(LaTeXParser.FUNC_COS, 0) + + def FUNC_TAN(self): + return self.getToken(LaTeXParser.FUNC_TAN, 0) + + def FUNC_CSC(self): + return self.getToken(LaTeXParser.FUNC_CSC, 0) + + def FUNC_SEC(self): + return self.getToken(LaTeXParser.FUNC_SEC, 0) + + def FUNC_COT(self): + return self.getToken(LaTeXParser.FUNC_COT, 0) + + def FUNC_ARCSIN(self): + return self.getToken(LaTeXParser.FUNC_ARCSIN, 0) + + def FUNC_ARCCOS(self): + return self.getToken(LaTeXParser.FUNC_ARCCOS, 0) + + def FUNC_ARCTAN(self): + return self.getToken(LaTeXParser.FUNC_ARCTAN, 0) + + def FUNC_ARCCSC(self): + return self.getToken(LaTeXParser.FUNC_ARCCSC, 0) + + def FUNC_ARCSEC(self): + return self.getToken(LaTeXParser.FUNC_ARCSEC, 0) + + def FUNC_ARCCOT(self): + return self.getToken(LaTeXParser.FUNC_ARCCOT, 0) + + def FUNC_SINH(self): + return self.getToken(LaTeXParser.FUNC_SINH, 0) + + def FUNC_COSH(self): + return self.getToken(LaTeXParser.FUNC_COSH, 0) + + def FUNC_TANH(self): + return self.getToken(LaTeXParser.FUNC_TANH, 0) + + def FUNC_ARSINH(self): + return self.getToken(LaTeXParser.FUNC_ARSINH, 0) + + def FUNC_ARCOSH(self): + return self.getToken(LaTeXParser.FUNC_ARCOSH, 0) + + def FUNC_ARTANH(self): + return self.getToken(LaTeXParser.FUNC_ARTANH, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_func_normal + + + + + def func_normal(self): + + localctx = LaTeXParser.Func_normalContext(self, self._ctx, self.state) + self.enterRule(localctx, 62, self.RULE_func_normal) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 369 + _la = self._input.LA(1) + if not(((_la) & ~0x3f) == 0 and ((1 << _la) & 576460614864470016) != 0): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class FuncContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + self.root = None # ExprContext + self.base = None # ExprContext + + def func_normal(self): + return self.getTypedRuleContext(LaTeXParser.Func_normalContext,0) + + + def L_PAREN(self): + return self.getToken(LaTeXParser.L_PAREN, 0) + + def func_arg(self): + return self.getTypedRuleContext(LaTeXParser.Func_argContext,0) + + + def R_PAREN(self): + return self.getToken(LaTeXParser.R_PAREN, 0) + + def func_arg_noparens(self): + return self.getTypedRuleContext(LaTeXParser.Func_arg_noparensContext,0) + + + def subexpr(self): + return self.getTypedRuleContext(LaTeXParser.SubexprContext,0) + + + def supexpr(self): + return self.getTypedRuleContext(LaTeXParser.SupexprContext,0) + + + def args(self): + return self.getTypedRuleContext(LaTeXParser.ArgsContext,0) + + + def LETTER(self): + return self.getToken(LaTeXParser.LETTER, 0) + + def SYMBOL(self): + return self.getToken(LaTeXParser.SYMBOL, 0) + + def SINGLE_QUOTES(self): + return self.getToken(LaTeXParser.SINGLE_QUOTES, 0) + + def FUNC_INT(self): + return self.getToken(LaTeXParser.FUNC_INT, 0) + + def DIFFERENTIAL(self): + return self.getToken(LaTeXParser.DIFFERENTIAL, 0) + + def frac(self): + return self.getTypedRuleContext(LaTeXParser.FracContext,0) + + + def additive(self): + return self.getTypedRuleContext(LaTeXParser.AdditiveContext,0) + + + def FUNC_SQRT(self): + return self.getToken(LaTeXParser.FUNC_SQRT, 0) + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.ExprContext) + else: + return self.getTypedRuleContext(LaTeXParser.ExprContext,i) + + + def L_BRACKET(self): + return self.getToken(LaTeXParser.L_BRACKET, 0) + + def R_BRACKET(self): + return self.getToken(LaTeXParser.R_BRACKET, 0) + + def FUNC_OVERLINE(self): + return self.getToken(LaTeXParser.FUNC_OVERLINE, 0) + + def mp(self): + return self.getTypedRuleContext(LaTeXParser.MpContext,0) + + + def FUNC_SUM(self): + return self.getToken(LaTeXParser.FUNC_SUM, 0) + + def FUNC_PROD(self): + return self.getToken(LaTeXParser.FUNC_PROD, 0) + + def subeq(self): + return self.getTypedRuleContext(LaTeXParser.SubeqContext,0) + + + def FUNC_LIM(self): + return self.getToken(LaTeXParser.FUNC_LIM, 0) + + def limit_sub(self): + return self.getTypedRuleContext(LaTeXParser.Limit_subContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_func + + + + + def func(self): + + localctx = LaTeXParser.FuncContext(self, self._ctx, self.state) + self.enterRule(localctx, 64, self.RULE_func) + self._la = 0 # Token type + try: + self.state = 460 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58]: + self.enterOuterAlt(localctx, 1) + self.state = 371 + self.func_normal() + self.state = 384 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,40,self._ctx) + if la_ == 1: + self.state = 373 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==73: + self.state = 372 + self.subexpr() + + + self.state = 376 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==74: + self.state = 375 + self.supexpr() + + + pass + + elif la_ == 2: + self.state = 379 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==74: + self.state = 378 + self.supexpr() + + + self.state = 382 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==73: + self.state = 381 + self.subexpr() + + + pass + + + self.state = 391 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,41,self._ctx) + if la_ == 1: + self.state = 386 + self.match(LaTeXParser.L_PAREN) + self.state = 387 + self.func_arg() + self.state = 388 + self.match(LaTeXParser.R_PAREN) + pass + + elif la_ == 2: + self.state = 390 + self.func_arg_noparens() + pass + + + pass + elif token in [77, 91]: + self.enterOuterAlt(localctx, 2) + self.state = 393 + _la = self._input.LA(1) + if not(_la==77 or _la==91): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 406 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,46,self._ctx) + if la_ == 1: + self.state = 395 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==73: + self.state = 394 + self.subexpr() + + + self.state = 398 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==90: + self.state = 397 + self.match(LaTeXParser.SINGLE_QUOTES) + + + pass + + elif la_ == 2: + self.state = 401 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==90: + self.state = 400 + self.match(LaTeXParser.SINGLE_QUOTES) + + + self.state = 404 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==73: + self.state = 403 + self.subexpr() + + + pass + + + self.state = 408 + self.match(LaTeXParser.L_PAREN) + self.state = 409 + self.args() + self.state = 410 + self.match(LaTeXParser.R_PAREN) + pass + elif token in [34]: + self.enterOuterAlt(localctx, 3) + self.state = 412 + self.match(LaTeXParser.FUNC_INT) + self.state = 419 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [73]: + self.state = 413 + self.subexpr() + self.state = 414 + self.supexpr() + pass + elif token in [74]: + self.state = 416 + self.supexpr() + self.state = 417 + self.subexpr() + pass + elif token in [15, 16, 19, 21, 23, 25, 27, 29, 30, 32, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61, 63, 64, 68, 69, 70, 71, 72, 76, 77, 78, 91]: + pass + else: + pass + self.state = 427 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,49,self._ctx) + if la_ == 1: + self.state = 422 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,48,self._ctx) + if la_ == 1: + self.state = 421 + self.additive(0) + + + self.state = 424 + self.match(LaTeXParser.DIFFERENTIAL) + pass + + elif la_ == 2: + self.state = 425 + self.frac() + pass + + elif la_ == 3: + self.state = 426 + self.additive(0) + pass + + + pass + elif token in [63]: + self.enterOuterAlt(localctx, 4) + self.state = 429 + self.match(LaTeXParser.FUNC_SQRT) + self.state = 434 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==25: + self.state = 430 + self.match(LaTeXParser.L_BRACKET) + self.state = 431 + localctx.root = self.expr() + self.state = 432 + self.match(LaTeXParser.R_BRACKET) + + + self.state = 436 + self.match(LaTeXParser.L_BRACE) + self.state = 437 + localctx.base = self.expr() + self.state = 438 + self.match(LaTeXParser.R_BRACE) + pass + elif token in [64]: + self.enterOuterAlt(localctx, 5) + self.state = 440 + self.match(LaTeXParser.FUNC_OVERLINE) + self.state = 441 + self.match(LaTeXParser.L_BRACE) + self.state = 442 + localctx.base = self.expr() + self.state = 443 + self.match(LaTeXParser.R_BRACE) + pass + elif token in [35, 36]: + self.enterOuterAlt(localctx, 6) + self.state = 445 + _la = self._input.LA(1) + if not(_la==35 or _la==36): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 452 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [73]: + self.state = 446 + self.subeq() + self.state = 447 + self.supexpr() + pass + elif token in [74]: + self.state = 449 + self.supexpr() + self.state = 450 + self.subeq() + pass + else: + raise NoViableAltException(self) + + self.state = 454 + self.mp(0) + pass + elif token in [32]: + self.enterOuterAlt(localctx, 7) + self.state = 456 + self.match(LaTeXParser.FUNC_LIM) + self.state = 457 + self.limit_sub() + self.state = 458 + self.mp(0) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ArgsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def args(self): + return self.getTypedRuleContext(LaTeXParser.ArgsContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_args + + + + + def args(self): + + localctx = LaTeXParser.ArgsContext(self, self._ctx, self.state) + self.enterRule(localctx, 66, self.RULE_args) + try: + self.state = 467 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,53,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 462 + self.expr() + self.state = 463 + self.match(LaTeXParser.T__0) + self.state = 464 + self.args() + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 466 + self.expr() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Limit_subContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def UNDERSCORE(self): + return self.getToken(LaTeXParser.UNDERSCORE, 0) + + def L_BRACE(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.L_BRACE) + else: + return self.getToken(LaTeXParser.L_BRACE, i) + + def LIM_APPROACH_SYM(self): + return self.getToken(LaTeXParser.LIM_APPROACH_SYM, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_BRACE(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.R_BRACE) + else: + return self.getToken(LaTeXParser.R_BRACE, i) + + def LETTER(self): + return self.getToken(LaTeXParser.LETTER, 0) + + def SYMBOL(self): + return self.getToken(LaTeXParser.SYMBOL, 0) + + def CARET(self): + return self.getToken(LaTeXParser.CARET, 0) + + def ADD(self): + return self.getToken(LaTeXParser.ADD, 0) + + def SUB(self): + return self.getToken(LaTeXParser.SUB, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_limit_sub + + + + + def limit_sub(self): + + localctx = LaTeXParser.Limit_subContext(self, self._ctx, self.state) + self.enterRule(localctx, 68, self.RULE_limit_sub) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 469 + self.match(LaTeXParser.UNDERSCORE) + self.state = 470 + self.match(LaTeXParser.L_BRACE) + self.state = 471 + _la = self._input.LA(1) + if not(_la==77 or _la==91): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 472 + self.match(LaTeXParser.LIM_APPROACH_SYM) + self.state = 473 + self.expr() + self.state = 482 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==74: + self.state = 474 + self.match(LaTeXParser.CARET) + self.state = 480 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [21]: + self.state = 475 + self.match(LaTeXParser.L_BRACE) + self.state = 476 + _la = self._input.LA(1) + if not(_la==15 or _la==16): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 477 + self.match(LaTeXParser.R_BRACE) + pass + elif token in [15]: + self.state = 478 + self.match(LaTeXParser.ADD) + pass + elif token in [16]: + self.state = 479 + self.match(LaTeXParser.SUB) + pass + else: + raise NoViableAltException(self) + + + + self.state = 484 + self.match(LaTeXParser.R_BRACE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Func_argContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def func_arg(self): + return self.getTypedRuleContext(LaTeXParser.Func_argContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_func_arg + + + + + def func_arg(self): + + localctx = LaTeXParser.Func_argContext(self, self._ctx, self.state) + self.enterRule(localctx, 70, self.RULE_func_arg) + try: + self.state = 491 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,56,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 486 + self.expr() + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 487 + self.expr() + self.state = 488 + self.match(LaTeXParser.T__0) + self.state = 489 + self.func_arg() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Func_arg_noparensContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def mp_nofunc(self): + return self.getTypedRuleContext(LaTeXParser.Mp_nofuncContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_func_arg_noparens + + + + + def func_arg_noparens(self): + + localctx = LaTeXParser.Func_arg_noparensContext(self, self._ctx, self.state) + self.enterRule(localctx, 72, self.RULE_func_arg_noparens) + try: + self.enterOuterAlt(localctx, 1) + self.state = 493 + self.mp_nofunc(0) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class SubexprContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def UNDERSCORE(self): + return self.getToken(LaTeXParser.UNDERSCORE, 0) + + def atom(self): + return self.getTypedRuleContext(LaTeXParser.AtomContext,0) + + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_subexpr + + + + + def subexpr(self): + + localctx = LaTeXParser.SubexprContext(self, self._ctx, self.state) + self.enterRule(localctx, 74, self.RULE_subexpr) + try: + self.enterOuterAlt(localctx, 1) + self.state = 495 + self.match(LaTeXParser.UNDERSCORE) + self.state = 501 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [27, 29, 30, 68, 69, 70, 71, 72, 76, 77, 78, 91]: + self.state = 496 + self.atom() + pass + elif token in [21]: + self.state = 497 + self.match(LaTeXParser.L_BRACE) + self.state = 498 + self.expr() + self.state = 499 + self.match(LaTeXParser.R_BRACE) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class SupexprContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def CARET(self): + return self.getToken(LaTeXParser.CARET, 0) + + def atom(self): + return self.getTypedRuleContext(LaTeXParser.AtomContext,0) + + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_supexpr + + + + + def supexpr(self): + + localctx = LaTeXParser.SupexprContext(self, self._ctx, self.state) + self.enterRule(localctx, 76, self.RULE_supexpr) + try: + self.enterOuterAlt(localctx, 1) + self.state = 503 + self.match(LaTeXParser.CARET) + self.state = 509 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [27, 29, 30, 68, 69, 70, 71, 72, 76, 77, 78, 91]: + self.state = 504 + self.atom() + pass + elif token in [21]: + self.state = 505 + self.match(LaTeXParser.L_BRACE) + self.state = 506 + self.expr() + self.state = 507 + self.match(LaTeXParser.R_BRACE) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class SubeqContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def UNDERSCORE(self): + return self.getToken(LaTeXParser.UNDERSCORE, 0) + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def equality(self): + return self.getTypedRuleContext(LaTeXParser.EqualityContext,0) + + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_subeq + + + + + def subeq(self): + + localctx = LaTeXParser.SubeqContext(self, self._ctx, self.state) + self.enterRule(localctx, 78, self.RULE_subeq) + try: + self.enterOuterAlt(localctx, 1) + self.state = 511 + self.match(LaTeXParser.UNDERSCORE) + self.state = 512 + self.match(LaTeXParser.L_BRACE) + self.state = 513 + self.equality() + self.state = 514 + self.match(LaTeXParser.R_BRACE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class SupeqContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def UNDERSCORE(self): + return self.getToken(LaTeXParser.UNDERSCORE, 0) + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def equality(self): + return self.getTypedRuleContext(LaTeXParser.EqualityContext,0) + + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_supeq + + + + + def supeq(self): + + localctx = LaTeXParser.SupeqContext(self, self._ctx, self.state) + self.enterRule(localctx, 80, self.RULE_supeq) + try: + self.enterOuterAlt(localctx, 1) + self.state = 516 + self.match(LaTeXParser.UNDERSCORE) + self.state = 517 + self.match(LaTeXParser.L_BRACE) + self.state = 518 + self.equality() + self.state = 519 + self.match(LaTeXParser.R_BRACE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + + def sempred(self, localctx:RuleContext, ruleIndex:int, predIndex:int): + if self._predicates == None: + self._predicates = dict() + self._predicates[1] = self.relation_sempred + self._predicates[4] = self.additive_sempred + self._predicates[5] = self.mp_sempred + self._predicates[6] = self.mp_nofunc_sempred + self._predicates[15] = self.exp_sempred + self._predicates[16] = self.exp_nofunc_sempred + pred = self._predicates.get(ruleIndex, None) + if pred is None: + raise Exception("No predicate with index:" + str(ruleIndex)) + else: + return pred(localctx, predIndex) + + def relation_sempred(self, localctx:RelationContext, predIndex:int): + if predIndex == 0: + return self.precpred(self._ctx, 2) + + + def additive_sempred(self, localctx:AdditiveContext, predIndex:int): + if predIndex == 1: + return self.precpred(self._ctx, 2) + + + def mp_sempred(self, localctx:MpContext, predIndex:int): + if predIndex == 2: + return self.precpred(self._ctx, 2) + + + def mp_nofunc_sempred(self, localctx:Mp_nofuncContext, predIndex:int): + if predIndex == 3: + return self.precpred(self._ctx, 2) + + + def exp_sempred(self, localctx:ExpContext, predIndex:int): + if predIndex == 4: + return self.precpred(self._ctx, 2) + + + def exp_nofunc_sempred(self, localctx:Exp_nofuncContext, predIndex:int): + if predIndex == 5: + return self.precpred(self._ctx, 2) + + + + + diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/latex/_build_latex_antlr.py b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/_build_latex_antlr.py new file mode 100644 index 0000000000000000000000000000000000000000..ee50da5b7861154823812c7773360b53dfd29ff6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/_build_latex_antlr.py @@ -0,0 +1,91 @@ +import os +import subprocess +import glob + +from sympy.utilities.misc import debug + +here = os.path.dirname(__file__) +grammar_file = os.path.abspath(os.path.join(here, "LaTeX.g4")) +dir_latex_antlr = os.path.join(here, "_antlr") + +header = '''\ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated from ../LaTeX.g4, derived from latex2sympy +# latex2sympy is licensed under the MIT license +# https://github.com/augustt198/latex2sympy/blob/master/LICENSE.txt +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt +''' + + +def check_antlr_version(): + debug("Checking antlr4 version...") + + try: + debug(subprocess.check_output(["antlr4"]) + .decode('utf-8').split("\n")[0]) + return True + except (subprocess.CalledProcessError, FileNotFoundError): + debug("The 'antlr4' command line tool is not installed, " + "or not on your PATH.\n" + "> Please refer to the README.md file for more information.") + return False + + +def build_parser(output_dir=dir_latex_antlr): + check_antlr_version() + + debug("Updating ANTLR-generated code in {}".format(output_dir)) + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + with open(os.path.join(output_dir, "__init__.py"), "w+") as fp: + fp.write(header) + + args = [ + "antlr4", + grammar_file, + "-o", output_dir, + # for now, not generating these as latex2sympy did not use them + "-no-visitor", + "-no-listener", + ] + + debug("Running code generation...\n\t$ {}".format(" ".join(args))) + subprocess.check_output(args, cwd=output_dir) + + debug("Applying headers, removing unnecessary files and renaming...") + # Handle case insensitive file systems. If the files are already + # generated, they will be written to latex* but LaTeX*.* won't match them. + for path in (glob.glob(os.path.join(output_dir, "LaTeX*.*")) or + glob.glob(os.path.join(output_dir, "latex*.*"))): + + # Remove files ending in .interp or .tokens as they are not needed. + if not path.endswith(".py"): + os.unlink(path) + continue + + new_path = os.path.join(output_dir, os.path.basename(path).lower()) + with open(path, 'r') as f: + lines = [line.rstrip() + '\n' for line in f] + + os.unlink(path) + + with open(new_path, "w") as out_file: + offset = 0 + while lines[offset].startswith('#'): + offset += 1 + out_file.write(header) + out_file.writelines(lines[offset:]) + + debug("\t{}".format(new_path)) + + return True + + +if __name__ == "__main__": + build_parser() diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/latex/_parse_latex_antlr.py b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/_parse_latex_antlr.py new file mode 100644 index 0000000000000000000000000000000000000000..26604375b3a9622f8c1dacdb1d678d09c2c3ad41 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/_parse_latex_antlr.py @@ -0,0 +1,607 @@ +# Ported from latex2sympy by @augustt198 +# https://github.com/augustt198/latex2sympy +# See license in LICENSE.txt +from importlib.metadata import version +import sympy +from sympy.external import import_module +from sympy.printing.str import StrPrinter +from sympy.physics.quantum.state import Bra, Ket + +from .errors import LaTeXParsingError + + +LaTeXParser = LaTeXLexer = MathErrorListener = None + +try: + LaTeXParser = import_module('sympy.parsing.latex._antlr.latexparser', + import_kwargs={'fromlist': ['LaTeXParser']}).LaTeXParser + LaTeXLexer = import_module('sympy.parsing.latex._antlr.latexlexer', + import_kwargs={'fromlist': ['LaTeXLexer']}).LaTeXLexer +except Exception: + pass + +ErrorListener = import_module('antlr4.error.ErrorListener', + warn_not_installed=True, + import_kwargs={'fromlist': ['ErrorListener']} + ) + + + +if ErrorListener: + class MathErrorListener(ErrorListener.ErrorListener): # type:ignore # noqa:F811 + def __init__(self, src): + super(ErrorListener.ErrorListener, self).__init__() + self.src = src + + def syntaxError(self, recog, symbol, line, col, msg, e): + fmt = "%s\n%s\n%s" + marker = "~" * col + "^" + + if msg.startswith("missing"): + err = fmt % (msg, self.src, marker) + elif msg.startswith("no viable"): + err = fmt % ("I expected something else here", self.src, marker) + elif msg.startswith("mismatched"): + names = LaTeXParser.literalNames + expected = [ + names[i] for i in e.getExpectedTokens() if i < len(names) + ] + if len(expected) < 10: + expected = " ".join(expected) + err = (fmt % ("I expected one of these: " + expected, self.src, + marker)) + else: + err = (fmt % ("I expected something else here", self.src, + marker)) + else: + err = fmt % ("I don't understand this", self.src, marker) + raise LaTeXParsingError(err) + + +def parse_latex(sympy, strict=False): + antlr4 = import_module('antlr4') + + if None in [antlr4, MathErrorListener] or \ + not version('antlr4-python3-runtime').startswith('4.11'): + raise ImportError("LaTeX parsing requires the antlr4 Python package," + " provided by pip (antlr4-python3-runtime) or" + " conda (antlr-python-runtime), version 4.11") + + sympy = sympy.strip() + matherror = MathErrorListener(sympy) + + stream = antlr4.InputStream(sympy) + lex = LaTeXLexer(stream) + lex.removeErrorListeners() + lex.addErrorListener(matherror) + + tokens = antlr4.CommonTokenStream(lex) + parser = LaTeXParser(tokens) + + # remove default console error listener + parser.removeErrorListeners() + parser.addErrorListener(matherror) + + relation = parser.math().relation() + if strict and (relation.start.start != 0 or relation.stop.stop != len(sympy) - 1): + raise LaTeXParsingError("Invalid LaTeX") + expr = convert_relation(relation) + + return expr + + +def convert_relation(rel): + if rel.expr(): + return convert_expr(rel.expr()) + + lh = convert_relation(rel.relation(0)) + rh = convert_relation(rel.relation(1)) + if rel.LT(): + return sympy.StrictLessThan(lh, rh) + elif rel.LTE(): + return sympy.LessThan(lh, rh) + elif rel.GT(): + return sympy.StrictGreaterThan(lh, rh) + elif rel.GTE(): + return sympy.GreaterThan(lh, rh) + elif rel.EQUAL(): + return sympy.Eq(lh, rh) + elif rel.NEQ(): + return sympy.Ne(lh, rh) + + +def convert_expr(expr): + return convert_add(expr.additive()) + + +def convert_add(add): + if add.ADD(): + lh = convert_add(add.additive(0)) + rh = convert_add(add.additive(1)) + return sympy.Add(lh, rh, evaluate=False) + elif add.SUB(): + lh = convert_add(add.additive(0)) + rh = convert_add(add.additive(1)) + if hasattr(rh, "is_Atom") and rh.is_Atom: + return sympy.Add(lh, -1 * rh, evaluate=False) + return sympy.Add(lh, sympy.Mul(-1, rh, evaluate=False), evaluate=False) + else: + return convert_mp(add.mp()) + + +def convert_mp(mp): + if hasattr(mp, 'mp'): + mp_left = mp.mp(0) + mp_right = mp.mp(1) + else: + mp_left = mp.mp_nofunc(0) + mp_right = mp.mp_nofunc(1) + + if mp.MUL() or mp.CMD_TIMES() or mp.CMD_CDOT(): + lh = convert_mp(mp_left) + rh = convert_mp(mp_right) + return sympy.Mul(lh, rh, evaluate=False) + elif mp.DIV() or mp.CMD_DIV() or mp.COLON(): + lh = convert_mp(mp_left) + rh = convert_mp(mp_right) + return sympy.Mul(lh, sympy.Pow(rh, -1, evaluate=False), evaluate=False) + else: + if hasattr(mp, 'unary'): + return convert_unary(mp.unary()) + else: + return convert_unary(mp.unary_nofunc()) + + +def convert_unary(unary): + if hasattr(unary, 'unary'): + nested_unary = unary.unary() + else: + nested_unary = unary.unary_nofunc() + if hasattr(unary, 'postfix_nofunc'): + first = unary.postfix() + tail = unary.postfix_nofunc() + postfix = [first] + tail + else: + postfix = unary.postfix() + + if unary.ADD(): + return convert_unary(nested_unary) + elif unary.SUB(): + numabs = convert_unary(nested_unary) + # Use Integer(-n) instead of Mul(-1, n) + return -numabs + elif postfix: + return convert_postfix_list(postfix) + + +def convert_postfix_list(arr, i=0): + if i >= len(arr): + raise LaTeXParsingError("Index out of bounds") + + res = convert_postfix(arr[i]) + if isinstance(res, sympy.Expr): + if i == len(arr) - 1: + return res # nothing to multiply by + else: + if i > 0: + left = convert_postfix(arr[i - 1]) + right = convert_postfix(arr[i + 1]) + if isinstance(left, sympy.Expr) and isinstance( + right, sympy.Expr): + left_syms = convert_postfix(arr[i - 1]).atoms(sympy.Symbol) + right_syms = convert_postfix(arr[i + 1]).atoms( + sympy.Symbol) + # if the left and right sides contain no variables and the + # symbol in between is 'x', treat as multiplication. + if not (left_syms or right_syms) and str(res) == 'x': + return convert_postfix_list(arr, i + 1) + # multiply by next + return sympy.Mul( + res, convert_postfix_list(arr, i + 1), evaluate=False) + else: # must be derivative + wrt = res[0] + if i == len(arr) - 1: + raise LaTeXParsingError("Expected expression for derivative") + else: + expr = convert_postfix_list(arr, i + 1) + return sympy.Derivative(expr, wrt) + + +def do_subs(expr, at): + if at.expr(): + at_expr = convert_expr(at.expr()) + syms = at_expr.atoms(sympy.Symbol) + if len(syms) == 0: + return expr + elif len(syms) > 0: + sym = next(iter(syms)) + return expr.subs(sym, at_expr) + elif at.equality(): + lh = convert_expr(at.equality().expr(0)) + rh = convert_expr(at.equality().expr(1)) + return expr.subs(lh, rh) + + +def convert_postfix(postfix): + if hasattr(postfix, 'exp'): + exp_nested = postfix.exp() + else: + exp_nested = postfix.exp_nofunc() + + exp = convert_exp(exp_nested) + for op in postfix.postfix_op(): + if op.BANG(): + if isinstance(exp, list): + raise LaTeXParsingError("Cannot apply postfix to derivative") + exp = sympy.factorial(exp, evaluate=False) + elif op.eval_at(): + ev = op.eval_at() + at_b = None + at_a = None + if ev.eval_at_sup(): + at_b = do_subs(exp, ev.eval_at_sup()) + if ev.eval_at_sub(): + at_a = do_subs(exp, ev.eval_at_sub()) + if at_b is not None and at_a is not None: + exp = sympy.Add(at_b, -1 * at_a, evaluate=False) + elif at_b is not None: + exp = at_b + elif at_a is not None: + exp = at_a + + return exp + + +def convert_exp(exp): + if hasattr(exp, 'exp'): + exp_nested = exp.exp() + else: + exp_nested = exp.exp_nofunc() + + if exp_nested: + base = convert_exp(exp_nested) + if isinstance(base, list): + raise LaTeXParsingError("Cannot raise derivative to power") + if exp.atom(): + exponent = convert_atom(exp.atom()) + elif exp.expr(): + exponent = convert_expr(exp.expr()) + return sympy.Pow(base, exponent, evaluate=False) + else: + if hasattr(exp, 'comp'): + return convert_comp(exp.comp()) + else: + return convert_comp(exp.comp_nofunc()) + + +def convert_comp(comp): + if comp.group(): + return convert_expr(comp.group().expr()) + elif comp.abs_group(): + return sympy.Abs(convert_expr(comp.abs_group().expr()), evaluate=False) + elif comp.atom(): + return convert_atom(comp.atom()) + elif comp.floor(): + return convert_floor(comp.floor()) + elif comp.ceil(): + return convert_ceil(comp.ceil()) + elif comp.func(): + return convert_func(comp.func()) + + +def convert_atom(atom): + if atom.LETTER(): + sname = atom.LETTER().getText() + if atom.subexpr(): + if atom.subexpr().expr(): # subscript is expr + subscript = convert_expr(atom.subexpr().expr()) + else: # subscript is atom + subscript = convert_atom(atom.subexpr().atom()) + sname += '_{' + StrPrinter().doprint(subscript) + '}' + if atom.SINGLE_QUOTES(): + sname += atom.SINGLE_QUOTES().getText() # put after subscript for easy identify + return sympy.Symbol(sname) + elif atom.SYMBOL(): + s = atom.SYMBOL().getText()[1:] + if s == "infty": + return sympy.oo + else: + if atom.subexpr(): + subscript = None + if atom.subexpr().expr(): # subscript is expr + subscript = convert_expr(atom.subexpr().expr()) + else: # subscript is atom + subscript = convert_atom(atom.subexpr().atom()) + subscriptName = StrPrinter().doprint(subscript) + s += '_{' + subscriptName + '}' + return sympy.Symbol(s) + elif atom.number(): + s = atom.number().getText().replace(",", "") + return sympy.Number(s) + elif atom.DIFFERENTIAL(): + var = get_differential_var(atom.DIFFERENTIAL()) + return sympy.Symbol('d' + var.name) + elif atom.mathit(): + text = rule2text(atom.mathit().mathit_text()) + return sympy.Symbol(text) + elif atom.frac(): + return convert_frac(atom.frac()) + elif atom.binom(): + return convert_binom(atom.binom()) + elif atom.bra(): + val = convert_expr(atom.bra().expr()) + return Bra(val) + elif atom.ket(): + val = convert_expr(atom.ket().expr()) + return Ket(val) + + +def rule2text(ctx): + stream = ctx.start.getInputStream() + # starting index of starting token + startIdx = ctx.start.start + # stopping index of stopping token + stopIdx = ctx.stop.stop + + return stream.getText(startIdx, stopIdx) + + +def convert_frac(frac): + diff_op = False + partial_op = False + if frac.lower and frac.upper: + lower_itv = frac.lower.getSourceInterval() + lower_itv_len = lower_itv[1] - lower_itv[0] + 1 + if (frac.lower.start == frac.lower.stop + and frac.lower.start.type == LaTeXLexer.DIFFERENTIAL): + wrt = get_differential_var_str(frac.lower.start.text) + diff_op = True + elif (lower_itv_len == 2 and frac.lower.start.type == LaTeXLexer.SYMBOL + and frac.lower.start.text == '\\partial' + and (frac.lower.stop.type == LaTeXLexer.LETTER + or frac.lower.stop.type == LaTeXLexer.SYMBOL)): + partial_op = True + wrt = frac.lower.stop.text + if frac.lower.stop.type == LaTeXLexer.SYMBOL: + wrt = wrt[1:] + + if diff_op or partial_op: + wrt = sympy.Symbol(wrt) + if (diff_op and frac.upper.start == frac.upper.stop + and frac.upper.start.type == LaTeXLexer.LETTER + and frac.upper.start.text == 'd'): + return [wrt] + elif (partial_op and frac.upper.start == frac.upper.stop + and frac.upper.start.type == LaTeXLexer.SYMBOL + and frac.upper.start.text == '\\partial'): + return [wrt] + upper_text = rule2text(frac.upper) + + expr_top = None + if diff_op and upper_text.startswith('d'): + expr_top = parse_latex(upper_text[1:]) + elif partial_op and frac.upper.start.text == '\\partial': + expr_top = parse_latex(upper_text[len('\\partial'):]) + if expr_top: + return sympy.Derivative(expr_top, wrt) + if frac.upper: + expr_top = convert_expr(frac.upper) + else: + expr_top = sympy.Number(frac.upperd.text) + if frac.lower: + expr_bot = convert_expr(frac.lower) + else: + expr_bot = sympy.Number(frac.lowerd.text) + inverse_denom = sympy.Pow(expr_bot, -1, evaluate=False) + if expr_top == 1: + return inverse_denom + else: + return sympy.Mul(expr_top, inverse_denom, evaluate=False) + +def convert_binom(binom): + expr_n = convert_expr(binom.n) + expr_k = convert_expr(binom.k) + return sympy.binomial(expr_n, expr_k, evaluate=False) + +def convert_floor(floor): + val = convert_expr(floor.val) + return sympy.floor(val, evaluate=False) + +def convert_ceil(ceil): + val = convert_expr(ceil.val) + return sympy.ceiling(val, evaluate=False) + +def convert_func(func): + if func.func_normal(): + if func.L_PAREN(): # function called with parenthesis + arg = convert_func_arg(func.func_arg()) + else: + arg = convert_func_arg(func.func_arg_noparens()) + + name = func.func_normal().start.text[1:] + + # change arc -> a + if name in [ + "arcsin", "arccos", "arctan", "arccsc", "arcsec", "arccot" + ]: + name = "a" + name[3:] + expr = getattr(sympy.functions, name)(arg, evaluate=False) + if name in ["arsinh", "arcosh", "artanh"]: + name = "a" + name[2:] + expr = getattr(sympy.functions, name)(arg, evaluate=False) + + if name == "exp": + expr = sympy.exp(arg, evaluate=False) + + if name in ("log", "lg", "ln"): + if func.subexpr(): + if func.subexpr().expr(): + base = convert_expr(func.subexpr().expr()) + else: + base = convert_atom(func.subexpr().atom()) + elif name == "lg": # ISO 80000-2:2019 + base = 10 + elif name in ("ln", "log"): # SymPy's latex printer prints ln as log by default + base = sympy.E + expr = sympy.log(arg, base, evaluate=False) + + func_pow = None + should_pow = True + if func.supexpr(): + if func.supexpr().expr(): + func_pow = convert_expr(func.supexpr().expr()) + else: + func_pow = convert_atom(func.supexpr().atom()) + + if name in [ + "sin", "cos", "tan", "csc", "sec", "cot", "sinh", "cosh", + "tanh" + ]: + if func_pow == -1: + name = "a" + name + should_pow = False + expr = getattr(sympy.functions, name)(arg, evaluate=False) + + if func_pow and should_pow: + expr = sympy.Pow(expr, func_pow, evaluate=False) + + return expr + elif func.LETTER() or func.SYMBOL(): + if func.LETTER(): + fname = func.LETTER().getText() + elif func.SYMBOL(): + fname = func.SYMBOL().getText()[1:] + fname = str(fname) # can't be unicode + if func.subexpr(): + if func.subexpr().expr(): # subscript is expr + subscript = convert_expr(func.subexpr().expr()) + else: # subscript is atom + subscript = convert_atom(func.subexpr().atom()) + subscriptName = StrPrinter().doprint(subscript) + fname += '_{' + subscriptName + '}' + if func.SINGLE_QUOTES(): + fname += func.SINGLE_QUOTES().getText() + input_args = func.args() + output_args = [] + while input_args.args(): # handle multiple arguments to function + output_args.append(convert_expr(input_args.expr())) + input_args = input_args.args() + output_args.append(convert_expr(input_args.expr())) + return sympy.Function(fname)(*output_args) + elif func.FUNC_INT(): + return handle_integral(func) + elif func.FUNC_SQRT(): + expr = convert_expr(func.base) + if func.root: + r = convert_expr(func.root) + return sympy.root(expr, r, evaluate=False) + else: + return sympy.sqrt(expr, evaluate=False) + elif func.FUNC_OVERLINE(): + expr = convert_expr(func.base) + return sympy.conjugate(expr, evaluate=False) + elif func.FUNC_SUM(): + return handle_sum_or_prod(func, "summation") + elif func.FUNC_PROD(): + return handle_sum_or_prod(func, "product") + elif func.FUNC_LIM(): + return handle_limit(func) + + +def convert_func_arg(arg): + if hasattr(arg, 'expr'): + return convert_expr(arg.expr()) + else: + return convert_mp(arg.mp_nofunc()) + + +def handle_integral(func): + if func.additive(): + integrand = convert_add(func.additive()) + elif func.frac(): + integrand = convert_frac(func.frac()) + else: + integrand = 1 + + int_var = None + if func.DIFFERENTIAL(): + int_var = get_differential_var(func.DIFFERENTIAL()) + else: + for sym in integrand.atoms(sympy.Symbol): + s = str(sym) + if len(s) > 1 and s[0] == 'd': + if s[1] == '\\': + int_var = sympy.Symbol(s[2:]) + else: + int_var = sympy.Symbol(s[1:]) + int_sym = sym + if int_var: + integrand = integrand.subs(int_sym, 1) + else: + # Assume dx by default + int_var = sympy.Symbol('x') + + if func.subexpr(): + if func.subexpr().atom(): + lower = convert_atom(func.subexpr().atom()) + else: + lower = convert_expr(func.subexpr().expr()) + if func.supexpr().atom(): + upper = convert_atom(func.supexpr().atom()) + else: + upper = convert_expr(func.supexpr().expr()) + return sympy.Integral(integrand, (int_var, lower, upper)) + else: + return sympy.Integral(integrand, int_var) + + +def handle_sum_or_prod(func, name): + val = convert_mp(func.mp()) + iter_var = convert_expr(func.subeq().equality().expr(0)) + start = convert_expr(func.subeq().equality().expr(1)) + if func.supexpr().expr(): # ^{expr} + end = convert_expr(func.supexpr().expr()) + else: # ^atom + end = convert_atom(func.supexpr().atom()) + + if name == "summation": + return sympy.Sum(val, (iter_var, start, end)) + elif name == "product": + return sympy.Product(val, (iter_var, start, end)) + + +def handle_limit(func): + sub = func.limit_sub() + if sub.LETTER(): + var = sympy.Symbol(sub.LETTER().getText()) + elif sub.SYMBOL(): + var = sympy.Symbol(sub.SYMBOL().getText()[1:]) + else: + var = sympy.Symbol('x') + if sub.SUB(): + direction = "-" + elif sub.ADD(): + direction = "+" + else: + direction = "+-" + approaching = convert_expr(sub.expr()) + content = convert_mp(func.mp()) + + return sympy.Limit(content, var, approaching, direction) + + +def get_differential_var(d): + text = get_differential_var_str(d.getText()) + return sympy.Symbol(text) + + +def get_differential_var_str(text): + for i in range(1, len(text)): + c = text[i] + if not (c == " " or c == "\r" or c == "\n" or c == "\t"): + idx = i + break + text = text[idx:] + if text[0] == "\\": + text = text[1:] + return text diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/latex/errors.py b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..d8c3ef9f06279df42d4b2054acc4cfe39b6682a5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/errors.py @@ -0,0 +1,2 @@ +class LaTeXParsingError(Exception): + pass diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/latex/lark/__init__.py b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/lark/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..92e58d3172e100cc376d0b416b3835d164bd5647 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/lark/__init__.py @@ -0,0 +1,2 @@ +from .latex_parser import parse_latex_lark, LarkLaTeXParser # noqa +from .transformer import TransformToSymPyExpr # noqa diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/latex/lark/grammar/greek_symbols.lark b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/lark/grammar/greek_symbols.lark new file mode 100644 index 0000000000000000000000000000000000000000..7439fab9dcac284dc3c9b5fbfa4fc6db8b29dfd2 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/lark/grammar/greek_symbols.lark @@ -0,0 +1,28 @@ +// Greek symbols +// TODO: Shouold we include the uppercase variants for the symbols where the uppercase variant doesn't have a separate meaning? +ALPHA: "\\alpha" +BETA: "\\beta" +GAMMA: "\\gamma" +DELTA: "\\delta" // TODO: Should this be included? Delta usually denotes other things. +EPSILON: "\\epsilon" | "\\varepsilon" +ZETA: "\\zeta" +ETA: "\\eta" +THETA: "\\theta" | "\\vartheta" +// TODO: Should I add iota to the list? +KAPPA: "\\kappa" +LAMBDA: "\\lambda" // TODO: What about the uppercase variant? +MU: "\\mu" +NU: "\\nu" +XI: "\\xi" +// TODO: Should there be a separate note for transforming \pi into sympy.pi? +RHO: "\\rho" | "\\varrho" +// TODO: What should we do about sigma? +TAU: "\\tau" +UPSILON: "\\upsilon" +PHI: "\\phi" | "\\varphi" +CHI: "\\chi" +PSI: "\\psi" +OMEGA: "\\omega" + +GREEK_SYMBOL: ALPHA | BETA | GAMMA | DELTA | EPSILON | ZETA | ETA | THETA | KAPPA + | LAMBDA | MU | NU | XI | RHO | TAU | UPSILON | PHI | CHI | PSI | OMEGA diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/latex/lark/grammar/latex.lark b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/lark/grammar/latex.lark new file mode 100644 index 0000000000000000000000000000000000000000..43e8d0e9105fa4da9bcdd2c0fa6111f6d523c9a9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/lark/grammar/latex.lark @@ -0,0 +1,403 @@ +%ignore /[ \t\n\r]+/ + +%ignore "\\," | "\\thinspace" | "\\:" | "\\medspace" | "\\;" | "\\thickspace" +%ignore "\\quad" | "\\qquad" +%ignore "\\!" | "\\negthinspace" | "\\negmedspace" | "\\negthickspace" +%ignore "\\vrule" | "\\vcenter" | "\\vbox" | "\\vskip" | "\\vspace" | "\\hfill" +%ignore "\\*" | "\\-" | "\\." | "\\/" | "\\(" | "\\=" + +%ignore "\\left" | "\\right" +%ignore "\\limits" | "\\nolimits" +%ignore "\\displaystyle" + +///////////////////// tokens /////////////////////// + +// basic binary operators +ADD: "+" +SUB: "-" +MUL: "*" +DIV: "/" + +// tokens with distinct left and right symbols +L_BRACE: "{" +R_BRACE: "}" +L_BRACE_LITERAL: "\\{" +R_BRACE_LITERAL: "\\}" +L_BRACKET: "[" +R_BRACKET: "]" +L_CEIL: "\\lceil" +R_CEIL: "\\rceil" +L_FLOOR: "\\lfloor" +R_FLOOR: "\\rfloor" +L_PAREN: "(" +R_PAREN: ")" + +// limit, integral, sum, and product symbols +FUNC_LIM: "\\lim" +LIM_APPROACH_SYM: "\\to" | "\\rightarrow" | "\\Rightarrow" | "\\longrightarrow" | "\\Longrightarrow" +FUNC_INT: "\\int" | "\\intop" +FUNC_SUM: "\\sum" +FUNC_PROD: "\\prod" + +// common functions +FUNC_EXP: "\\exp" +FUNC_LOG: "\\log" +FUNC_LN: "\\ln" +FUNC_LG: "\\lg" +FUNC_MIN: "\\min" +FUNC_MAX: "\\max" + +// trigonometric functions +FUNC_SIN: "\\sin" +FUNC_COS: "\\cos" +FUNC_TAN: "\\tan" +FUNC_CSC: "\\csc" +FUNC_SEC: "\\sec" +FUNC_COT: "\\cot" + +// inverse trigonometric functions +FUNC_ARCSIN: "\\arcsin" +FUNC_ARCCOS: "\\arccos" +FUNC_ARCTAN: "\\arctan" +FUNC_ARCCSC: "\\arccsc" +FUNC_ARCSEC: "\\arcsec" +FUNC_ARCCOT: "\\arccot" + +// hyperbolic trigonometric functions +FUNC_SINH: "\\sinh" +FUNC_COSH: "\\cosh" +FUNC_TANH: "\\tanh" +FUNC_ARSINH: "\\arsinh" +FUNC_ARCOSH: "\\arcosh" +FUNC_ARTANH: "\\artanh" + +FUNC_SQRT: "\\sqrt" + +// miscellaneous symbols +CMD_TIMES: "\\times" +CMD_CDOT: "\\cdot" +CMD_DIV: "\\div" +CMD_FRAC: "\\frac" | "\\dfrac" | "\\tfrac" | "\\nicefrac" +CMD_BINOM: "\\binom" | "\\dbinom" | "\\tbinom" +CMD_OVERLINE: "\\overline" +CMD_LANGLE: "\\langle" +CMD_RANGLE: "\\rangle" + +CMD_MATHIT: "\\mathit" + +CMD_INFTY: "\\infty" + +BANG: "!" +BAR: "|" +CARET: "^" +COLON: ":" +UNDERSCORE: "_" + +// relational symbols +EQUAL: "=" +NOT_EQUAL: "\\neq" | "\\ne" +LT: "<" +LTE: "\\leq" | "\\le" | "\\leqslant" +GT: ">" +GTE: "\\geq" | "\\ge" | "\\geqslant" + +DIV_SYMBOL: CMD_DIV | DIV +MUL_SYMBOL: MUL | CMD_TIMES | CMD_CDOT + +%import .greek_symbols.GREEK_SYMBOL + +UPRIGHT_DIFFERENTIAL_SYMBOL: "\\text{d}" | "\\mathrm{d}" +DIFFERENTIAL_SYMBOL: "d" | UPRIGHT_DIFFERENTIAL_SYMBOL + +// disallow "d" as a variable name because we want to parse "d" as a differential symbol. +SYMBOL: /[a-zA-Z]'*/ +GREEK_SYMBOL_WITH_PRIMES: GREEK_SYMBOL "'"* +LATIN_SYMBOL_WITH_LATIN_SUBSCRIPT: /([a-zA-Z]'*)_(([A-Za-z0-9]|[a-zA-Z]+)|\{([A-Za-z0-9]|[a-zA-Z]+'*)\})/ +LATIN_SYMBOL_WITH_GREEK_SUBSCRIPT: /([a-zA-Z]'*)_/ GREEK_SYMBOL | /([a-zA-Z]'*)_/ L_BRACE GREEK_SYMBOL_WITH_PRIMES R_BRACE +// best to define the variant with braces like that instead of shoving it all into one case like in +// /([a-zA-Z])_/ L_BRACE? GREEK_SYMBOL R_BRACE? because then we can easily error out on input like +// r"h_{\theta" +GREEK_SYMBOL_WITH_LATIN_SUBSCRIPT: GREEK_SYMBOL_WITH_PRIMES /_(([A-Za-z0-9]|[a-zA-Z]+)|\{([A-Za-z0-9]|[a-zA-Z]+'*)\})/ +GREEK_SYMBOL_WITH_GREEK_SUBSCRIPT: GREEK_SYMBOL_WITH_PRIMES /_/ (GREEK_SYMBOL | L_BRACE GREEK_SYMBOL_WITH_PRIMES R_BRACE) +MULTI_LETTER_SYMBOL: /[a-zA-Z]+(\s+[a-zA-Z]+)*'*/ + +%import common.DIGIT -> DIGIT + +CMD_PRIME: "\\prime" +CMD_ASTERISK: "\\ast" + +PRIMES: "'"+ +STARS: "*"+ +PRIMES_VIA_CMD: CMD_PRIME+ +STARS_VIA_CMD: CMD_ASTERISK+ + +CMD_IMAGINARY_UNIT: "\\imaginaryunit" + +CMD_BEGIN: "\\begin" +CMD_END: "\\end" + +// matrices +IGNORE_L: /[ \t\n\r]*/ L_BRACE* /[ \t\n\r]*/ +IGNORE_R: /[ \t\n\r]*/ R_BRACE* /[ \t\n\r]*/ +ARRAY_MATRIX_BEGIN: L_BRACE "array" R_BRACE L_BRACE /[^}]*/ R_BRACE +ARRAY_MATRIX_END: L_BRACE "array" R_BRACE +AMSMATH_MATRIX: L_BRACE "matrix" R_BRACE +AMSMATH_PMATRIX: L_BRACE "pmatrix" R_BRACE +AMSMATH_BMATRIX: L_BRACE "bmatrix" R_BRACE +// Without the (L|R)_PARENs and (L|R)_BRACKETs, a matrix defined using +// \begin{array}...\end{array} or \begin{matrix}...\end{matrix} must +// not qualify as a complete matrix expression; this is done so that +// if we have \begin{array}...\end{array} or \begin{matrix}...\end{matrix} +// between BAR pairs, then they should be interpreted as determinants as +// opposed to sympy.Abs (absolute value) applied to a matrix. +CMD_BEGIN_AMSPMATRIX_AMSBMATRIX: CMD_BEGIN (AMSMATH_PMATRIX | AMSMATH_BMATRIX) +CMD_BEGIN_ARRAY_AMSMATRIX: (L_PAREN | L_BRACKET) IGNORE_L CMD_BEGIN (ARRAY_MATRIX_BEGIN | AMSMATH_MATRIX) +CMD_MATRIX_BEGIN: CMD_BEGIN_AMSPMATRIX_AMSBMATRIX | CMD_BEGIN_ARRAY_AMSMATRIX +CMD_END_AMSPMATRIX_AMSBMATRIX: CMD_END (AMSMATH_PMATRIX | AMSMATH_BMATRIX) +CMD_END_ARRAY_AMSMATRIX: CMD_END (ARRAY_MATRIX_END | AMSMATH_MATRIX) IGNORE_R "\\right"? (R_PAREN | R_BRACKET) +CMD_MATRIX_END: CMD_END_AMSPMATRIX_AMSBMATRIX | CMD_END_ARRAY_AMSMATRIX +MATRIX_COL_DELIM: "&" +MATRIX_ROW_DELIM: "\\\\" +FUNC_MATRIX_TRACE: "\\trace" +FUNC_MATRIX_ADJUGATE: "\\adjugate" + +// determinants +AMSMATH_VMATRIX: L_BRACE "vmatrix" R_BRACE +CMD_DETERMINANT_BEGIN_SIMPLE: CMD_BEGIN AMSMATH_VMATRIX +CMD_DETERMINANT_BEGIN_VARIANT: BAR IGNORE_L CMD_BEGIN (ARRAY_MATRIX_BEGIN | AMSMATH_MATRIX) +CMD_DETERMINANT_BEGIN: CMD_DETERMINANT_BEGIN_SIMPLE | CMD_DETERMINANT_BEGIN_VARIANT +CMD_DETERMINANT_END_SIMPLE: CMD_END AMSMATH_VMATRIX +CMD_DETERMINANT_END_VARIANT: CMD_END (ARRAY_MATRIX_END | AMSMATH_MATRIX) IGNORE_R "\\right"? BAR +CMD_DETERMINANT_END: CMD_DETERMINANT_END_SIMPLE | CMD_DETERMINANT_END_VARIANT +FUNC_DETERMINANT: "\\det" + +//////////////////// grammar ////////////////////// + +latex_string: _relation | _expression + +_one_letter_symbol: SYMBOL + | LATIN_SYMBOL_WITH_LATIN_SUBSCRIPT + | LATIN_SYMBOL_WITH_GREEK_SUBSCRIPT + | GREEK_SYMBOL_WITH_LATIN_SUBSCRIPT + | GREEK_SYMBOL_WITH_GREEK_SUBSCRIPT + | GREEK_SYMBOL_WITH_PRIMES +// LuaTeX-generated outputs of \mathit{foo'} and \mathit{foo}' +// seem to be the same on the surface. We allow both styles. +multi_letter_symbol: CMD_MATHIT L_BRACE MULTI_LETTER_SYMBOL R_BRACE + | CMD_MATHIT L_BRACE MULTI_LETTER_SYMBOL R_BRACE /'+/ +number: /\d+(\.\d*)?/ | CMD_IMAGINARY_UNIT + +_atomic_expr: _one_letter_symbol + | multi_letter_symbol + | number + | CMD_INFTY + +group_round_parentheses: L_PAREN _expression R_PAREN +group_square_brackets: L_BRACKET _expression R_BRACKET +group_curly_parentheses: L_BRACE _expression R_BRACE + +_relation: eq | ne | lt | lte | gt | gte + +eq: _expression EQUAL _expression +ne: _expression NOT_EQUAL _expression +lt: _expression LT _expression +lte: _expression LTE _expression +gt: _expression GT _expression +gte: _expression GTE _expression + +_expression_core: _atomic_expr | group_curly_parentheses + +add: _expression ADD _expression_mul + | ADD _expression_mul +sub: _expression SUB _expression_mul + | SUB _expression_mul +mul: _expression_mul MUL_SYMBOL _expression_power +div: _expression_mul DIV_SYMBOL _expression_power + +adjacent_expressions: (_one_letter_symbol | number) _expression_mul + | group_round_parentheses (group_round_parentheses | _one_letter_symbol) + | _function _function + | fraction _expression_mul + +_expression_func: _expression_core + | group_round_parentheses + | fraction + | binomial + | _function + | _integral// | derivative + | limit + | matrix + +_expression_power: _expression_func | superscript | matrix_prime | symbol_prime + +_expression_mul: _expression_power + | mul | div | adjacent_expressions + | summation | product + +_expression: _expression_mul | add | sub + +_limit_dir: "+" | "-" | L_BRACE ("+" | "-") R_BRACE + +limit_dir_expr: _expression CARET _limit_dir + +group_curly_parentheses_lim: L_BRACE _expression LIM_APPROACH_SYM (limit_dir_expr | _expression) R_BRACE + +limit: FUNC_LIM UNDERSCORE group_curly_parentheses_lim _expression + +differential: DIFFERENTIAL_SYMBOL _one_letter_symbol + +//_derivative_operator: CMD_FRAC L_BRACE DIFFERENTIAL_SYMBOL R_BRACE L_BRACE differential R_BRACE + +//derivative: _derivative_operator _expression + +_integral: normal_integral | integral_with_special_fraction + +normal_integral: FUNC_INT _expression DIFFERENTIAL_SYMBOL _one_letter_symbol + | FUNC_INT (CARET _expression_core UNDERSCORE _expression_core)? _expression? DIFFERENTIAL_SYMBOL _one_letter_symbol + | FUNC_INT (UNDERSCORE _expression_core CARET _expression_core)? _expression? DIFFERENTIAL_SYMBOL _one_letter_symbol + +group_curly_parentheses_int: L_BRACE _expression? differential R_BRACE + +special_fraction: CMD_FRAC group_curly_parentheses_int group_curly_parentheses + +integral_with_special_fraction: FUNC_INT special_fraction + | FUNC_INT (CARET _expression_core UNDERSCORE _expression_core)? special_fraction + | FUNC_INT (UNDERSCORE _expression_core CARET _expression_core)? special_fraction + +group_curly_parentheses_special: UNDERSCORE L_BRACE _atomic_expr EQUAL _atomic_expr R_BRACE CARET _expression_core + | CARET _expression_core UNDERSCORE L_BRACE _atomic_expr EQUAL _atomic_expr R_BRACE + +summation: FUNC_SUM group_curly_parentheses_special _expression + | FUNC_SUM group_curly_parentheses_special _expression + +product: FUNC_PROD group_curly_parentheses_special _expression + | FUNC_PROD group_curly_parentheses_special _expression + +superscript: _expression_func CARET (_expression_power | CMD_PRIME | CMD_ASTERISK) + | _expression_func CARET L_BRACE (PRIMES | STARS | PRIMES_VIA_CMD | STARS_VIA_CMD) R_BRACE + +matrix_prime: (matrix | group_round_parentheses) PRIMES + +symbol_prime: (LATIN_SYMBOL_WITH_LATIN_SUBSCRIPT + | LATIN_SYMBOL_WITH_GREEK_SUBSCRIPT + | GREEK_SYMBOL_WITH_LATIN_SUBSCRIPT + | GREEK_SYMBOL_WITH_GREEK_SUBSCRIPT) PRIMES + +fraction: _basic_fraction + | _simple_fraction + | _general_fraction + +_basic_fraction: CMD_FRAC DIGIT (DIGIT | SYMBOL | GREEK_SYMBOL_WITH_PRIMES) + +_simple_fraction: CMD_FRAC DIGIT group_curly_parentheses + | CMD_FRAC group_curly_parentheses (DIGIT | SYMBOL | GREEK_SYMBOL_WITH_PRIMES) + +_general_fraction: CMD_FRAC group_curly_parentheses group_curly_parentheses + +binomial: _basic_binomial + | _simple_binomial + | _general_binomial + +_basic_binomial: CMD_BINOM DIGIT (DIGIT | SYMBOL | GREEK_SYMBOL_WITH_PRIMES) + +_simple_binomial: CMD_BINOM DIGIT group_curly_parentheses + | CMD_BINOM group_curly_parentheses (DIGIT | SYMBOL | GREEK_SYMBOL_WITH_PRIMES) + +_general_binomial: CMD_BINOM group_curly_parentheses group_curly_parentheses + +list_of_expressions: _expression ("," _expression)* + +function_applied: _one_letter_symbol L_PAREN list_of_expressions R_PAREN + +min: FUNC_MIN L_PAREN list_of_expressions R_PAREN + +max: FUNC_MAX L_PAREN list_of_expressions R_PAREN + +bra: CMD_LANGLE _expression BAR + +ket: BAR _expression CMD_RANGLE + +inner_product: CMD_LANGLE _expression BAR _expression CMD_RANGLE + +_function: function_applied + | abs | floor | ceil + | _trigonometric_function | _inverse_trigonometric_function + | _trigonometric_function_power + | _hyperbolic_trigonometric_function | _inverse_hyperbolic_trigonometric_function + | exponential + | log + | square_root + | factorial + | conjugate + | max | min + | bra | ket | inner_product + | determinant + | trace + | adjugate + +exponential: FUNC_EXP _expression + +log: FUNC_LOG _expression + | FUNC_LN _expression + | FUNC_LG _expression + | FUNC_LOG UNDERSCORE (DIGIT | _one_letter_symbol) _expression + | FUNC_LOG UNDERSCORE group_curly_parentheses _expression + +square_root: FUNC_SQRT group_curly_parentheses + | FUNC_SQRT group_square_brackets group_curly_parentheses + +factorial: _expression_func BANG + +conjugate: CMD_OVERLINE group_curly_parentheses + | CMD_OVERLINE DIGIT + +_trigonometric_function: sin | cos | tan | csc | sec | cot + +sin: FUNC_SIN _expression +cos: FUNC_COS _expression +tan: FUNC_TAN _expression +csc: FUNC_CSC _expression +sec: FUNC_SEC _expression +cot: FUNC_COT _expression + +_trigonometric_function_power: sin_power | cos_power | tan_power | csc_power | sec_power | cot_power + +sin_power: FUNC_SIN CARET _expression_core _expression +cos_power: FUNC_COS CARET _expression_core _expression +tan_power: FUNC_TAN CARET _expression_core _expression +csc_power: FUNC_CSC CARET _expression_core _expression +sec_power: FUNC_SEC CARET _expression_core _expression +cot_power: FUNC_COT CARET _expression_core _expression + +_hyperbolic_trigonometric_function: sinh | cosh | tanh + +sinh: FUNC_SINH _expression +cosh: FUNC_COSH _expression +tanh: FUNC_TANH _expression + +_inverse_trigonometric_function: arcsin | arccos | arctan | arccsc | arcsec | arccot + +arcsin: FUNC_ARCSIN _expression +arccos: FUNC_ARCCOS _expression +arctan: FUNC_ARCTAN _expression +arccsc: FUNC_ARCCSC _expression +arcsec: FUNC_ARCSEC _expression +arccot: FUNC_ARCCOT _expression + +_inverse_hyperbolic_trigonometric_function: asinh | acosh | atanh + +asinh: FUNC_ARSINH _expression +acosh: FUNC_ARCOSH _expression +atanh: FUNC_ARTANH _expression + +abs: BAR _expression BAR +floor: L_FLOOR _expression R_FLOOR +ceil: L_CEIL _expression R_CEIL + +matrix: CMD_MATRIX_BEGIN matrix_body CMD_MATRIX_END +matrix_body: matrix_row (MATRIX_ROW_DELIM matrix_row)* (MATRIX_ROW_DELIM)? +matrix_row: _expression (MATRIX_COL_DELIM _expression)* +determinant: (CMD_DETERMINANT_BEGIN matrix_body CMD_DETERMINANT_END) + | FUNC_DETERMINANT _expression +trace: FUNC_MATRIX_TRACE _expression +adjugate: FUNC_MATRIX_ADJUGATE _expression diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/latex/lark/latex_parser.py b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/lark/latex_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..29f594b0de4bfd4648df1554d5863a37afff035f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/lark/latex_parser.py @@ -0,0 +1,145 @@ +import os +import logging +import re +from pathlib import Path + +from sympy.external import import_module +from sympy.parsing.latex.lark.transformer import TransformToSymPyExpr + +_lark = import_module("lark") + + +class LarkLaTeXParser: + r"""Class for converting input `\mathrm{\LaTeX}` strings into SymPy Expressions. + It holds all the necessary internal data for doing so, and exposes hooks for + customizing its behavior. + + Parameters + ========== + + print_debug_output : bool, optional + + If set to ``True``, prints debug output to the logger. Defaults to ``False``. + + transform : bool, optional + + If set to ``True``, the class runs the Transformer class on the parse tree + generated by running ``Lark.parse`` on the input string. Defaults to ``True``. + + Setting it to ``False`` can help with debugging the `\mathrm{\LaTeX}` grammar. + + grammar_file : str, optional + + The path to the grammar file that the parser should use. If set to ``None``, + it uses the default grammar, which is in ``grammar/latex.lark``, relative to + the ``sympy/parsing/latex/lark/`` directory. + + transformer : str, optional + + The name of the Transformer class to use. If set to ``None``, it uses the + default transformer class, which is :py:func:`TransformToSymPyExpr`. + + """ + def __init__(self, print_debug_output=False, transform=True, grammar_file=None, transformer=None): + grammar_dir_path = os.path.join(os.path.dirname(__file__), "grammar/") + + if grammar_file is None: + latex_grammar = Path(os.path.join(grammar_dir_path, "latex.lark")).read_text(encoding="utf-8") + else: + latex_grammar = Path(grammar_file).read_text(encoding="utf-8") + + self.parser = _lark.Lark( + latex_grammar, + source_path=grammar_dir_path, + parser="earley", + start="latex_string", + lexer="auto", + ambiguity="explicit", + propagate_positions=False, + maybe_placeholders=False, + keep_all_tokens=True) + + self.print_debug_output = print_debug_output + self.transform_expr = transform + + if transformer is None: + self.transformer = TransformToSymPyExpr() + else: + self.transformer = transformer() + + def doparse(self, s: str): + if self.print_debug_output: + _lark.logger.setLevel(logging.DEBUG) + + parse_tree = self.parser.parse(s) + + if not self.transform_expr: + # exit early and return the parse tree + _lark.logger.debug("expression = %s", s) + _lark.logger.debug(parse_tree) + _lark.logger.debug(parse_tree.pretty()) + return parse_tree + + if self.print_debug_output: + # print this stuff before attempting to run the transformer + _lark.logger.debug("expression = %s", s) + # print the `parse_tree` variable + _lark.logger.debug(parse_tree.pretty()) + + sympy_expression = self.transformer.transform(parse_tree) + + if self.print_debug_output: + _lark.logger.debug("SymPy expression = %s", sympy_expression) + + return sympy_expression + + +if _lark is not None: + _lark_latex_parser = LarkLaTeXParser() + + +def parse_latex_lark(s: str): + """ + Experimental LaTeX parser using Lark. + + This function is still under development and its API may change with the + next releases of SymPy. + """ + if _lark is None: + raise ImportError("Lark is probably not installed") + return _lark_latex_parser.doparse(s) + + +def _pretty_print_lark_trees(tree, indent=0, show_expr=True): + if isinstance(tree, _lark.Token): + return tree.value + + data = str(tree.data) + + is_expr = data.startswith("expression") + + if is_expr: + data = re.sub(r"^expression", "E", data) + + is_ambig = (data == "_ambig") + + if is_ambig: + new_indent = indent + 2 + else: + new_indent = indent + + output = "" + show_node = not is_expr or show_expr + + if show_node: + output += str(data) + "(" + + if is_ambig: + output += "\n" + "\n".join([" " * new_indent + _pretty_print_lark_trees(i, new_indent, show_expr) for i in tree.children]) + else: + output += ",".join([_pretty_print_lark_trees(i, new_indent, show_expr) for i in tree.children]) + + if show_node: + output += ")" + + return output diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/latex/lark/transformer.py b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/lark/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..cbd514b6517336207a57de6d28bcce25858071dc --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/latex/lark/transformer.py @@ -0,0 +1,730 @@ +import re + +import sympy +from sympy.external import import_module +from sympy.parsing.latex.errors import LaTeXParsingError + +lark = import_module("lark") + +if lark: + from lark import Transformer, Token, Tree # type: ignore +else: + class Transformer: # type: ignore + def transform(self, *args): + pass + + + class Token: # type: ignore + pass + + + class Tree: # type: ignore + pass + + +# noinspection PyPep8Naming,PyMethodMayBeStatic +class TransformToSymPyExpr(Transformer): + """Returns a SymPy expression that is generated by traversing the ``lark.Tree`` + passed to the ``.transform()`` function. + + Notes + ===== + + **This class is never supposed to be used directly.** + + In order to tweak the behavior of this class, it has to be subclassed and then after + the required modifications are made, the name of the new class should be passed to + the :py:class:`LarkLaTeXParser` class by using the ``transformer`` argument in the + constructor. + + Parameters + ========== + + visit_tokens : bool, optional + For information about what this option does, see `here + `_. + + Note that the option must be set to ``True`` for the default parser to work. + """ + + SYMBOL = sympy.Symbol + DIGIT = sympy.core.numbers.Integer + + def CMD_INFTY(self, tokens): + return sympy.oo + + def GREEK_SYMBOL_WITH_PRIMES(self, tokens): + # we omit the first character because it is a backslash. Also, if the variable name has "var" in it, + # like "varphi" or "varepsilon", we remove that too + variable_name = re.sub("var", "", tokens[1:]) + + return sympy.Symbol(variable_name) + + def LATIN_SYMBOL_WITH_LATIN_SUBSCRIPT(self, tokens): + base, sub = tokens.value.split("_") + if sub.startswith("{"): + return sympy.Symbol("%s_{%s}" % (base, sub[1:-1])) + else: + return sympy.Symbol("%s_{%s}" % (base, sub)) + + def GREEK_SYMBOL_WITH_LATIN_SUBSCRIPT(self, tokens): + base, sub = tokens.value.split("_") + greek_letter = re.sub("var", "", base[1:]) + + if sub.startswith("{"): + return sympy.Symbol("%s_{%s}" % (greek_letter, sub[1:-1])) + else: + return sympy.Symbol("%s_{%s}" % (greek_letter, sub)) + + def LATIN_SYMBOL_WITH_GREEK_SUBSCRIPT(self, tokens): + base, sub = tokens.value.split("_") + if sub.startswith("{"): + greek_letter = sub[2:-1] + else: + greek_letter = sub[1:] + + greek_letter = re.sub("var", "", greek_letter) + return sympy.Symbol("%s_{%s}" % (base, greek_letter)) + + + def GREEK_SYMBOL_WITH_GREEK_SUBSCRIPT(self, tokens): + base, sub = tokens.value.split("_") + greek_base = re.sub("var", "", base[1:]) + + if sub.startswith("{"): + greek_sub = sub[2:-1] + else: + greek_sub = sub[1:] + + greek_sub = re.sub("var", "", greek_sub) + return sympy.Symbol("%s_{%s}" % (greek_base, greek_sub)) + + def multi_letter_symbol(self, tokens): + if len(tokens) == 4: # no primes (single quotes) on symbol + return sympy.Symbol(tokens[2]) + if len(tokens) == 5: # there are primes on the symbol + return sympy.Symbol(tokens[2] + tokens[4]) + + def number(self, tokens): + if tokens[0].type == "CMD_IMAGINARY_UNIT": + return sympy.I + + if "." in tokens[0]: + return sympy.core.numbers.Float(tokens[0]) + else: + return sympy.core.numbers.Integer(tokens[0]) + + def latex_string(self, tokens): + return tokens[0] + + def group_round_parentheses(self, tokens): + return tokens[1] + + def group_square_brackets(self, tokens): + return tokens[1] + + def group_curly_parentheses(self, tokens): + return tokens[1] + + def eq(self, tokens): + return sympy.Eq(tokens[0], tokens[2]) + + def ne(self, tokens): + return sympy.Ne(tokens[0], tokens[2]) + + def lt(self, tokens): + return sympy.Lt(tokens[0], tokens[2]) + + def lte(self, tokens): + return sympy.Le(tokens[0], tokens[2]) + + def gt(self, tokens): + return sympy.Gt(tokens[0], tokens[2]) + + def gte(self, tokens): + return sympy.Ge(tokens[0], tokens[2]) + + def add(self, tokens): + if len(tokens) == 2: # +a + return tokens[1] + if len(tokens) == 3: # a + b + lh = tokens[0] + rh = tokens[2] + + if self._obj_is_sympy_Matrix(lh) or self._obj_is_sympy_Matrix(rh): + return sympy.MatAdd(lh, rh) + + return sympy.Add(lh, rh) + + def sub(self, tokens): + if len(tokens) == 2: # -a + x = tokens[1] + + if self._obj_is_sympy_Matrix(x): + return sympy.MatMul(-1, x) + + return -x + if len(tokens) == 3: # a - b + lh = tokens[0] + rh = tokens[2] + + if self._obj_is_sympy_Matrix(lh) or self._obj_is_sympy_Matrix(rh): + return sympy.MatAdd(lh, sympy.MatMul(-1, rh)) + + return sympy.Add(lh, -rh) + + def mul(self, tokens): + lh = tokens[0] + rh = tokens[2] + + if self._obj_is_sympy_Matrix(lh) or self._obj_is_sympy_Matrix(rh): + return sympy.MatMul(lh, rh) + + return sympy.Mul(lh, rh) + + def div(self, tokens): + return self._handle_division(tokens[0], tokens[2]) + + def adjacent_expressions(self, tokens): + # Most of the time, if two expressions are next to each other, it means implicit multiplication, + # but not always + from sympy.physics.quantum import Bra, Ket + if isinstance(tokens[0], Ket) and isinstance(tokens[1], Bra): + from sympy.physics.quantum import OuterProduct + return OuterProduct(tokens[0], tokens[1]) + elif tokens[0] == sympy.Symbol("d"): + # If the leftmost token is a "d", then it is highly likely that this is a differential + return tokens[0], tokens[1] + elif isinstance(tokens[0], tuple): + # then we have a derivative + return sympy.Derivative(tokens[1], tokens[0][1]) + else: + return sympy.Mul(tokens[0], tokens[1]) + + def superscript(self, tokens): + def isprime(x): + return isinstance(x, Token) and x.type == "PRIMES" + + def iscmdprime(x): + return isinstance(x, Token) and (x.type == "PRIMES_VIA_CMD" + or x.type == "CMD_PRIME") + + def isstar(x): + return isinstance(x, Token) and x.type == "STARS" + + def iscmdstar(x): + return isinstance(x, Token) and (x.type == "STARS_VIA_CMD" + or x.type == "CMD_ASTERISK") + + base = tokens[0] + if len(tokens) == 3: # a^b OR a^\prime OR a^\ast + sup = tokens[2] + if len(tokens) == 5: + # a^{'}, a^{''}, ... OR + # a^{*}, a^{**}, ... OR + # a^{\prime}, a^{\prime\prime}, ... OR + # a^{\ast}, a^{\ast\ast}, ... + sup = tokens[3] + + if self._obj_is_sympy_Matrix(base): + if sup == sympy.Symbol("T"): + return sympy.Transpose(base) + if sup == sympy.Symbol("H"): + return sympy.adjoint(base) + if isprime(sup): + sup = sup.value + if len(sup) % 2 == 0: + return base + return sympy.Transpose(base) + if iscmdprime(sup): + sup = sup.value + if (len(sup)/len(r"\prime")) % 2 == 0: + return base + return sympy.Transpose(base) + if isstar(sup): + sup = sup.value + # need .doit() in order to be consistent with + # sympy.adjoint() which returns the evaluated adjoint + # of a matrix + if len(sup) % 2 == 0: + return base.doit() + return sympy.adjoint(base) + if iscmdstar(sup): + sup = sup.value + # need .doit() for same reason as above + if (len(sup)/len(r"\ast")) % 2 == 0: + return base.doit() + return sympy.adjoint(base) + + if isprime(sup) or iscmdprime(sup) or isstar(sup) or iscmdstar(sup): + raise LaTeXParsingError(f"{base} with superscript {sup} is not understood.") + + return sympy.Pow(base, sup) + + def matrix_prime(self, tokens): + base = tokens[0] + primes = tokens[1].value + + if not self._obj_is_sympy_Matrix(base): + raise LaTeXParsingError(f"({base}){primes} is not understood.") + + if len(primes) % 2 == 0: + return base + + return sympy.Transpose(base) + + def symbol_prime(self, tokens): + base = tokens[0] + primes = tokens[1].value + + return sympy.Symbol(f"{base.name}{primes}") + + def fraction(self, tokens): + numerator = tokens[1] + if isinstance(tokens[2], tuple): + # we only need the variable w.r.t. which we are differentiating + _, variable = tokens[2] + + # we will pass this information upwards + return "derivative", variable + else: + denominator = tokens[2] + return self._handle_division(numerator, denominator) + + def binomial(self, tokens): + return sympy.binomial(tokens[1], tokens[2]) + + def normal_integral(self, tokens): + underscore_index = None + caret_index = None + + if "_" in tokens: + # we need to know the index because the next item in the list is the + # arguments for the lower bound of the integral + underscore_index = tokens.index("_") + + if "^" in tokens: + # we need to know the index because the next item in the list is the + # arguments for the upper bound of the integral + caret_index = tokens.index("^") + + lower_bound = tokens[underscore_index + 1] if underscore_index else None + upper_bound = tokens[caret_index + 1] if caret_index else None + + differential_symbol = self._extract_differential_symbol(tokens) + + if differential_symbol is None: + raise LaTeXParsingError("Differential symbol was not found in the expression." + "Valid differential symbols are \"d\", \"\\text{d}, and \"\\mathrm{d}\".") + + # else we can assume that a differential symbol was found + differential_variable_index = tokens.index(differential_symbol) + 1 + differential_variable = tokens[differential_variable_index] + + # we can't simply do something like `if (lower_bound and not upper_bound) ...` because this would + # evaluate to `True` if the `lower_bound` is 0 and upper bound is non-zero + if lower_bound is not None and upper_bound is None: + # then one was given and the other wasn't + raise LaTeXParsingError("Lower bound for the integral was found, but upper bound was not found.") + + if upper_bound is not None and lower_bound is None: + # then one was given and the other wasn't + raise LaTeXParsingError("Upper bound for the integral was found, but lower bound was not found.") + + # check if any expression was given or not. If it wasn't, then set the integrand to 1. + if underscore_index is not None and underscore_index == differential_variable_index - 3: + # The Token at differential_variable_index - 2 should be the integrand. However, if going one more step + # backwards after that gives us the underscore, then that means that there _was_ no integrand. + # Example: \int^7_0 dx + integrand = 1 + elif caret_index is not None and caret_index == differential_variable_index - 3: + # The Token at differential_variable_index - 2 should be the integrand. However, if going one more step + # backwards after that gives us the caret, then that means that there _was_ no integrand. + # Example: \int_0^7 dx + integrand = 1 + elif differential_variable_index == 2: + # this means we have something like "\int dx", because the "\int" symbol will always be + # at index 0 in `tokens` + integrand = 1 + else: + # The Token at differential_variable_index - 1 is the differential symbol itself, so we need to go one + # more step before that. + integrand = tokens[differential_variable_index - 2] + + if lower_bound is not None: + # then we have a definite integral + + # we can assume that either both the lower and upper bounds are given, or + # neither of them are + return sympy.Integral(integrand, (differential_variable, lower_bound, upper_bound)) + else: + # we have an indefinite integral + return sympy.Integral(integrand, differential_variable) + + def group_curly_parentheses_int(self, tokens): + # return signature is a tuple consisting of the expression in the numerator, along with the variable of + # integration + if len(tokens) == 3: + return 1, tokens[1] + elif len(tokens) == 4: + return tokens[1], tokens[2] + # there are no other possibilities + + def special_fraction(self, tokens): + numerator, variable = tokens[1] + denominator = tokens[2] + + # We pass the integrand, along with information about the variable of integration, upw + return sympy.Mul(numerator, sympy.Pow(denominator, -1)), variable + + def integral_with_special_fraction(self, tokens): + underscore_index = None + caret_index = None + + if "_" in tokens: + # we need to know the index because the next item in the list is the + # arguments for the lower bound of the integral + underscore_index = tokens.index("_") + + if "^" in tokens: + # we need to know the index because the next item in the list is the + # arguments for the upper bound of the integral + caret_index = tokens.index("^") + + lower_bound = tokens[underscore_index + 1] if underscore_index else None + upper_bound = tokens[caret_index + 1] if caret_index else None + + # we can't simply do something like `if (lower_bound and not upper_bound) ...` because this would + # evaluate to `True` if the `lower_bound` is 0 and upper bound is non-zero + if lower_bound is not None and upper_bound is None: + # then one was given and the other wasn't + raise LaTeXParsingError("Lower bound for the integral was found, but upper bound was not found.") + + if upper_bound is not None and lower_bound is None: + # then one was given and the other wasn't + raise LaTeXParsingError("Upper bound for the integral was found, but lower bound was not found.") + + integrand, differential_variable = tokens[-1] + + if lower_bound is not None: + # then we have a definite integral + + # we can assume that either both the lower and upper bounds are given, or + # neither of them are + return sympy.Integral(integrand, (differential_variable, lower_bound, upper_bound)) + else: + # we have an indefinite integral + return sympy.Integral(integrand, differential_variable) + + def group_curly_parentheses_special(self, tokens): + underscore_index = tokens.index("_") + caret_index = tokens.index("^") + + # given the type of expressions we are parsing, we can assume that the lower limit + # will always use braces around its arguments. This is because we don't support + # converting unconstrained sums into SymPy expressions. + + # first we isolate the bottom limit + left_brace_index = tokens.index("{", underscore_index) + right_brace_index = tokens.index("}", underscore_index) + + bottom_limit = tokens[left_brace_index + 1: right_brace_index] + + # next, we isolate the upper limit + top_limit = tokens[caret_index + 1:] + + # the code below will be useful for supporting things like `\sum_{n = 0}^{n = 5} n^2` + # if "{" in top_limit: + # left_brace_index = tokens.index("{", caret_index) + # if left_brace_index != -1: + # # then there's a left brace in the string, and we need to find the closing right brace + # right_brace_index = tokens.index("}", caret_index) + # top_limit = tokens[left_brace_index + 1: right_brace_index] + + # print(f"top limit = {top_limit}") + + index_variable = bottom_limit[0] + lower_limit = bottom_limit[-1] + upper_limit = top_limit[0] # for now, the index will always be 0 + + # print(f"return value = ({index_variable}, {lower_limit}, {upper_limit})") + + return index_variable, lower_limit, upper_limit + + def summation(self, tokens): + return sympy.Sum(tokens[2], tokens[1]) + + def product(self, tokens): + return sympy.Product(tokens[2], tokens[1]) + + def limit_dir_expr(self, tokens): + caret_index = tokens.index("^") + + if "{" in tokens: + left_curly_brace_index = tokens.index("{", caret_index) + direction = tokens[left_curly_brace_index + 1] + else: + direction = tokens[caret_index + 1] + + if direction == "+": + return tokens[0], "+" + elif direction == "-": + return tokens[0], "-" + else: + return tokens[0], "+-" + + def group_curly_parentheses_lim(self, tokens): + limit_variable = tokens[1] + if isinstance(tokens[3], tuple): + destination, direction = tokens[3] + else: + destination = tokens[3] + direction = "+-" + + return limit_variable, destination, direction + + def limit(self, tokens): + limit_variable, destination, direction = tokens[2] + + return sympy.Limit(tokens[-1], limit_variable, destination, direction) + + def differential(self, tokens): + return tokens[1] + + def derivative(self, tokens): + return sympy.Derivative(tokens[-1], tokens[5]) + + def list_of_expressions(self, tokens): + if len(tokens) == 1: + # we return it verbatim because the function_applied node expects + # a list + return tokens + else: + def remove_tokens(args): + if isinstance(args, Token): + if args.type != "COMMA": + # An unexpected token was encountered + raise LaTeXParsingError("A comma token was expected, but some other token was encountered.") + return False + return True + + return filter(remove_tokens, tokens) + + def function_applied(self, tokens): + return sympy.Function(tokens[0])(*tokens[2]) + + def min(self, tokens): + return sympy.Min(*tokens[2]) + + def max(self, tokens): + return sympy.Max(*tokens[2]) + + def bra(self, tokens): + from sympy.physics.quantum import Bra + return Bra(tokens[1]) + + def ket(self, tokens): + from sympy.physics.quantum import Ket + return Ket(tokens[1]) + + def inner_product(self, tokens): + from sympy.physics.quantum import Bra, Ket, InnerProduct + return InnerProduct(Bra(tokens[1]), Ket(tokens[3])) + + def sin(self, tokens): + return sympy.sin(tokens[1]) + + def cos(self, tokens): + return sympy.cos(tokens[1]) + + def tan(self, tokens): + return sympy.tan(tokens[1]) + + def csc(self, tokens): + return sympy.csc(tokens[1]) + + def sec(self, tokens): + return sympy.sec(tokens[1]) + + def cot(self, tokens): + return sympy.cot(tokens[1]) + + def sin_power(self, tokens): + exponent = tokens[2] + if exponent == -1: + return sympy.asin(tokens[-1]) + else: + return sympy.Pow(sympy.sin(tokens[-1]), exponent) + + def cos_power(self, tokens): + exponent = tokens[2] + if exponent == -1: + return sympy.acos(tokens[-1]) + else: + return sympy.Pow(sympy.cos(tokens[-1]), exponent) + + def tan_power(self, tokens): + exponent = tokens[2] + if exponent == -1: + return sympy.atan(tokens[-1]) + else: + return sympy.Pow(sympy.tan(tokens[-1]), exponent) + + def csc_power(self, tokens): + exponent = tokens[2] + if exponent == -1: + return sympy.acsc(tokens[-1]) + else: + return sympy.Pow(sympy.csc(tokens[-1]), exponent) + + def sec_power(self, tokens): + exponent = tokens[2] + if exponent == -1: + return sympy.asec(tokens[-1]) + else: + return sympy.Pow(sympy.sec(tokens[-1]), exponent) + + def cot_power(self, tokens): + exponent = tokens[2] + if exponent == -1: + return sympy.acot(tokens[-1]) + else: + return sympy.Pow(sympy.cot(tokens[-1]), exponent) + + def arcsin(self, tokens): + return sympy.asin(tokens[1]) + + def arccos(self, tokens): + return sympy.acos(tokens[1]) + + def arctan(self, tokens): + return sympy.atan(tokens[1]) + + def arccsc(self, tokens): + return sympy.acsc(tokens[1]) + + def arcsec(self, tokens): + return sympy.asec(tokens[1]) + + def arccot(self, tokens): + return sympy.acot(tokens[1]) + + def sinh(self, tokens): + return sympy.sinh(tokens[1]) + + def cosh(self, tokens): + return sympy.cosh(tokens[1]) + + def tanh(self, tokens): + return sympy.tanh(tokens[1]) + + def asinh(self, tokens): + return sympy.asinh(tokens[1]) + + def acosh(self, tokens): + return sympy.acosh(tokens[1]) + + def atanh(self, tokens): + return sympy.atanh(tokens[1]) + + def abs(self, tokens): + return sympy.Abs(tokens[1]) + + def floor(self, tokens): + return sympy.floor(tokens[1]) + + def ceil(self, tokens): + return sympy.ceiling(tokens[1]) + + def factorial(self, tokens): + return sympy.factorial(tokens[0]) + + def conjugate(self, tokens): + return sympy.conjugate(tokens[1]) + + def square_root(self, tokens): + if len(tokens) == 2: + # then there was no square bracket argument + return sympy.sqrt(tokens[1]) + elif len(tokens) == 3: + # then there _was_ a square bracket argument + return sympy.root(tokens[2], tokens[1]) + + def exponential(self, tokens): + return sympy.exp(tokens[1]) + + def log(self, tokens): + if tokens[0].type == "FUNC_LG": + # we don't need to check if there's an underscore or not because having one + # in this case would be meaningless + # TODO: ANTLR refers to ISO 80000-2:2019. should we keep base 10 or base 2? + return sympy.log(tokens[1], 10) + elif tokens[0].type == "FUNC_LN": + return sympy.log(tokens[1]) + elif tokens[0].type == "FUNC_LOG": + # we check if a base was specified or not + if "_" in tokens: + # then a base was specified + return sympy.log(tokens[3], tokens[2]) + else: + # a base was not specified + return sympy.log(tokens[1]) + + def _extract_differential_symbol(self, s: str): + differential_symbols = {"d", r"\text{d}", r"\mathrm{d}"} + + differential_symbol = next((symbol for symbol in differential_symbols if symbol in s), None) + + return differential_symbol + + def matrix(self, tokens): + def is_matrix_row(x): + return (isinstance(x, Tree) and x.data == "matrix_row") + + def is_not_col_delim(y): + return (not isinstance(y, Token) or y.type != "MATRIX_COL_DELIM") + + matrix_body = tokens[1].children + return sympy.Matrix([[y for y in x.children if is_not_col_delim(y)] + for x in matrix_body if is_matrix_row(x)]) + + def determinant(self, tokens): + if len(tokens) == 2: # \det A + if not self._obj_is_sympy_Matrix(tokens[1]): + raise LaTeXParsingError("Cannot take determinant of non-matrix.") + + return tokens[1].det() + + if len(tokens) == 3: # | A | + return self.matrix(tokens).det() + + def trace(self, tokens): + if not self._obj_is_sympy_Matrix(tokens[1]): + raise LaTeXParsingError("Cannot take trace of non-matrix.") + + return sympy.Trace(tokens[1]) + + def adjugate(self, tokens): + if not self._obj_is_sympy_Matrix(tokens[1]): + raise LaTeXParsingError("Cannot take adjugate of non-matrix.") + + # need .doit() since MatAdd does not support .adjugate() method + return tokens[1].doit().adjugate() + + def _obj_is_sympy_Matrix(self, obj): + if hasattr(obj, "is_Matrix"): + return obj.is_Matrix + + return isinstance(obj, sympy.Matrix) + + def _handle_division(self, numerator, denominator): + if self._obj_is_sympy_Matrix(denominator): + raise LaTeXParsingError("Cannot divide by matrices like this since " + "it is not clear if left or right multiplication " + "by the inverse is intended. Try explicitly " + "multiplying by the inverse instead.") + + if self._obj_is_sympy_Matrix(numerator): + return sympy.MatMul(numerator, sympy.Pow(denominator, -1)) + + return sympy.Mul(numerator, sympy.Pow(denominator, -1)) diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/mathematica.py b/.venv/lib/python3.13/site-packages/sympy/parsing/mathematica.py new file mode 100644 index 0000000000000000000000000000000000000000..b5824a8c33ee402d03e6c5617eeeea21d4a457d1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/mathematica.py @@ -0,0 +1,1085 @@ +from __future__ import annotations +import re +import typing +from itertools import product +from typing import Any, Callable + +import sympy +from sympy import Mul, Add, Pow, Rational, log, exp, sqrt, cos, sin, tan, asin, acos, acot, asec, acsc, sinh, cosh, tanh, asinh, \ + acosh, atanh, acoth, asech, acsch, expand, im, flatten, polylog, cancel, expand_trig, sign, simplify, \ + UnevaluatedExpr, S, atan, atan2, Mod, Max, Min, rf, Ei, Si, Ci, airyai, airyaiprime, airybi, primepi, prime, \ + isprime, cot, sec, csc, csch, sech, coth, Function, I, pi, Tuple, GreaterThan, StrictGreaterThan, StrictLessThan, \ + LessThan, Equality, Or, And, Lambda, Integer, Dummy, symbols +from sympy.core.sympify import sympify, _sympify +from sympy.functions.special.bessel import airybiprime +from sympy.functions.special.error_functions import li +from sympy.utilities.exceptions import sympy_deprecation_warning + + +def mathematica(s, additional_translations=None): + sympy_deprecation_warning( + """The ``mathematica`` function for the Mathematica parser is now +deprecated. Use ``parse_mathematica`` instead. +The parameter ``additional_translation`` can be replaced by SymPy's +.replace( ) or .subs( ) methods on the output expression instead.""", + deprecated_since_version="1.11", + active_deprecations_target="mathematica-parser-new", + ) + parser = MathematicaParser(additional_translations) + return sympify(parser._parse_old(s)) + + +def parse_mathematica(s): + """ + Translate a string containing a Wolfram Mathematica expression to a SymPy + expression. + + If the translator is unable to find a suitable SymPy expression, the + ``FullForm`` of the Mathematica expression will be output, using SymPy + ``Function`` objects as nodes of the syntax tree. + + Examples + ======== + + >>> from sympy.parsing.mathematica import parse_mathematica + >>> parse_mathematica("Sin[x]^2 Tan[y]") + sin(x)**2*tan(y) + >>> e = parse_mathematica("F[7,5,3]") + >>> e + F(7, 5, 3) + >>> from sympy import Function, Max, Min + >>> e.replace(Function("F"), lambda *x: Max(*x)*Min(*x)) + 21 + + Both standard input form and Mathematica full form are supported: + + >>> parse_mathematica("x*(a + b)") + x*(a + b) + >>> parse_mathematica("Times[x, Plus[a, b]]") + x*(a + b) + + To get a matrix from Wolfram's code: + + >>> m = parse_mathematica("{{a, b}, {c, d}}") + >>> m + ((a, b), (c, d)) + >>> from sympy import Matrix + >>> Matrix(m) + Matrix([ + [a, b], + [c, d]]) + + If the translation into equivalent SymPy expressions fails, an SymPy + expression equivalent to Wolfram Mathematica's "FullForm" will be created: + + >>> parse_mathematica("x_.") + Optional(Pattern(x, Blank())) + >>> parse_mathematica("Plus @@ {x, y, z}") + Apply(Plus, (x, y, z)) + >>> parse_mathematica("f[x_, 3] := x^3 /; x > 0") + SetDelayed(f(Pattern(x, Blank()), 3), Condition(x**3, x > 0)) + """ + parser = MathematicaParser() + return parser.parse(s) + + +def _parse_Function(*args): + if len(args) == 1: + arg = args[0] + Slot = Function("Slot") + slots = arg.atoms(Slot) + numbers = [a.args[0] for a in slots] + number_of_arguments = max(numbers) + if isinstance(number_of_arguments, Integer): + variables = symbols(f"dummy0:{number_of_arguments}", cls=Dummy) + return Lambda(variables, arg.xreplace({Slot(i+1): v for i, v in enumerate(variables)})) + return Lambda((), arg) + elif len(args) == 2: + variables = args[0] + body = args[1] + return Lambda(variables, body) + else: + raise SyntaxError("Function node expects 1 or 2 arguments") + + +def _deco(cls): + cls._initialize_class() + return cls + + +@_deco +class MathematicaParser: + """ + An instance of this class converts a string of a Wolfram Mathematica + expression to a SymPy expression. + + The main parser acts internally in three stages: + + 1. tokenizer: tokenizes the Mathematica expression and adds the missing * + operators. Handled by ``_from_mathematica_to_tokens(...)`` + 2. full form list: sort the list of strings output by the tokenizer into a + syntax tree of nested lists and strings, equivalent to Mathematica's + ``FullForm`` expression output. This is handled by the function + ``_from_tokens_to_fullformlist(...)``. + 3. SymPy expression: the syntax tree expressed as full form list is visited + and the nodes with equivalent classes in SymPy are replaced. Unknown + syntax tree nodes are cast to SymPy ``Function`` objects. This is + handled by ``_from_fullformlist_to_sympy(...)``. + + """ + + # left: Mathematica, right: SymPy + CORRESPONDENCES = { + 'Sqrt[x]': 'sqrt(x)', + 'Rational[x,y]': 'Rational(x,y)', + 'Exp[x]': 'exp(x)', + 'Log[x]': 'log(x)', + 'Log[x,y]': 'log(y,x)', + 'Log2[x]': 'log(x,2)', + 'Log10[x]': 'log(x,10)', + 'Mod[x,y]': 'Mod(x,y)', + 'Max[*x]': 'Max(*x)', + 'Min[*x]': 'Min(*x)', + 'Pochhammer[x,y]':'rf(x,y)', + 'ArcTan[x,y]':'atan2(y,x)', + 'ExpIntegralEi[x]': 'Ei(x)', + 'SinIntegral[x]': 'Si(x)', + 'CosIntegral[x]': 'Ci(x)', + 'AiryAi[x]': 'airyai(x)', + 'AiryAiPrime[x]': 'airyaiprime(x)', + 'AiryBi[x]' :'airybi(x)', + 'AiryBiPrime[x]' :'airybiprime(x)', + 'LogIntegral[x]':' li(x)', + 'PrimePi[x]': 'primepi(x)', + 'Prime[x]': 'prime(x)', + 'PrimeQ[x]': 'isprime(x)' + } + + # trigonometric, e.t.c. + for arc, tri, h in product(('', 'Arc'), ( + 'Sin', 'Cos', 'Tan', 'Cot', 'Sec', 'Csc'), ('', 'h')): + fm = arc + tri + h + '[x]' + if arc: # arc func + fs = 'a' + tri.lower() + h + '(x)' + else: # non-arc func + fs = tri.lower() + h + '(x)' + CORRESPONDENCES.update({fm: fs}) + + REPLACEMENTS = { + ' ': '', + '^': '**', + '{': '[', + '}': ']', + } + + RULES = { + # a single whitespace to '*' + 'whitespace': ( + re.compile(r''' + (?:(?<=[a-zA-Z\d])|(?<=\d\.)) # a letter or a number + \s+ # any number of whitespaces + (?:(?=[a-zA-Z\d])|(?=\.\d)) # a letter or a number + ''', re.VERBOSE), + '*'), + + # add omitted '*' character + 'add*_1': ( + re.compile(r''' + (?:(?<=[])\d])|(?<=\d\.)) # ], ) or a number + # '' + (?=[(a-zA-Z]) # ( or a single letter + ''', re.VERBOSE), + '*'), + + # add omitted '*' character (variable letter preceding) + 'add*_2': ( + re.compile(r''' + (?<=[a-zA-Z]) # a letter + \( # ( as a character + (?=.) # any characters + ''', re.VERBOSE), + '*('), + + # convert 'Pi' to 'pi' + 'Pi': ( + re.compile(r''' + (?: + \A|(?<=[^a-zA-Z]) + ) + Pi # 'Pi' is 3.14159... in Mathematica + (?=[^a-zA-Z]) + ''', re.VERBOSE), + 'pi'), + } + + # Mathematica function name pattern + FM_PATTERN = re.compile(r''' + (?: + \A|(?<=[^a-zA-Z]) # at the top or a non-letter + ) + [A-Z][a-zA-Z\d]* # Function + (?=\[) # [ as a character + ''', re.VERBOSE) + + # list or matrix pattern (for future usage) + ARG_MTRX_PATTERN = re.compile(r''' + \{.*\} + ''', re.VERBOSE) + + # regex string for function argument pattern + ARGS_PATTERN_TEMPLATE = r''' + (?: + \A|(?<=[^a-zA-Z]) + ) + {arguments} # model argument like x, y,... + (?=[^a-zA-Z]) + ''' + + # will contain transformed CORRESPONDENCES dictionary + TRANSLATIONS: dict[tuple[str, int], dict[str, Any]] = {} + + # cache for a raw users' translation dictionary + cache_original: dict[tuple[str, int], dict[str, Any]] = {} + + # cache for a compiled users' translation dictionary + cache_compiled: dict[tuple[str, int], dict[str, Any]] = {} + + @classmethod + def _initialize_class(cls): + # get a transformed CORRESPONDENCES dictionary + d = cls._compile_dictionary(cls.CORRESPONDENCES) + cls.TRANSLATIONS.update(d) + + def __init__(self, additional_translations=None): + self.translations = {} + + # update with TRANSLATIONS (class constant) + self.translations.update(self.TRANSLATIONS) + + if additional_translations is None: + additional_translations = {} + + # check the latest added translations + if self.__class__.cache_original != additional_translations: + if not isinstance(additional_translations, dict): + raise ValueError('The argument must be dict type') + + # get a transformed additional_translations dictionary + d = self._compile_dictionary(additional_translations) + + # update cache + self.__class__.cache_original = additional_translations + self.__class__.cache_compiled = d + + # merge user's own translations + self.translations.update(self.__class__.cache_compiled) + + @classmethod + def _compile_dictionary(cls, dic): + # for return + d = {} + + for fm, fs in dic.items(): + # check function form + cls._check_input(fm) + cls._check_input(fs) + + # uncover '*' hiding behind a whitespace + fm = cls._apply_rules(fm, 'whitespace') + fs = cls._apply_rules(fs, 'whitespace') + + # remove whitespace(s) + fm = cls._replace(fm, ' ') + fs = cls._replace(fs, ' ') + + # search Mathematica function name + m = cls.FM_PATTERN.search(fm) + + # if no-hit + if m is None: + err = "'{f}' function form is invalid.".format(f=fm) + raise ValueError(err) + + # get Mathematica function name like 'Log' + fm_name = m.group() + + # get arguments of Mathematica function + args, end = cls._get_args(m) + + # function side check. (e.g.) '2*Func[x]' is invalid. + if m.start() != 0 or end != len(fm): + err = "'{f}' function form is invalid.".format(f=fm) + raise ValueError(err) + + # check the last argument's 1st character + if args[-1][0] == '*': + key_arg = '*' + else: + key_arg = len(args) + + key = (fm_name, key_arg) + + # convert '*x' to '\\*x' for regex + re_args = [x if x[0] != '*' else '\\' + x for x in args] + + # for regex. Example: (?:(x|y|z)) + xyz = '(?:(' + '|'.join(re_args) + '))' + + # string for regex compile + patStr = cls.ARGS_PATTERN_TEMPLATE.format(arguments=xyz) + + pat = re.compile(patStr, re.VERBOSE) + + # update dictionary + d[key] = {} + d[key]['fs'] = fs # SymPy function template + d[key]['args'] = args # args are ['x', 'y'] for example + d[key]['pat'] = pat + + return d + + def _convert_function(self, s): + '''Parse Mathematica function to SymPy one''' + + # compiled regex object + pat = self.FM_PATTERN + + scanned = '' # converted string + cur = 0 # position cursor + while True: + m = pat.search(s) + + if m is None: + # append the rest of string + scanned += s + break + + # get Mathematica function name + fm = m.group() + + # get arguments, and the end position of fm function + args, end = self._get_args(m) + + # the start position of fm function + bgn = m.start() + + # convert Mathematica function to SymPy one + s = self._convert_one_function(s, fm, args, bgn, end) + + # update cursor + cur = bgn + + # append converted part + scanned += s[:cur] + + # shrink s + s = s[cur:] + + return scanned + + def _convert_one_function(self, s, fm, args, bgn, end): + # no variable-length argument + if (fm, len(args)) in self.translations: + key = (fm, len(args)) + + # x, y,... model arguments + x_args = self.translations[key]['args'] + + # make CORRESPONDENCES between model arguments and actual ones + d = dict(zip(x_args, args)) + + # with variable-length argument + elif (fm, '*') in self.translations: + key = (fm, '*') + + # x, y,..*args (model arguments) + x_args = self.translations[key]['args'] + + # make CORRESPONDENCES between model arguments and actual ones + d = {} + for i, x in enumerate(x_args): + if x[0] == '*': + d[x] = ','.join(args[i:]) + break + d[x] = args[i] + + # out of self.translations + else: + err = "'{f}' is out of the whitelist.".format(f=fm) + raise ValueError(err) + + # template string of converted function + template = self.translations[key]['fs'] + + # regex pattern for x_args + pat = self.translations[key]['pat'] + + scanned = '' + cur = 0 + while True: + m = pat.search(template) + + if m is None: + scanned += template + break + + # get model argument + x = m.group() + + # get a start position of the model argument + xbgn = m.start() + + # add the corresponding actual argument + scanned += template[:xbgn] + d[x] + + # update cursor to the end of the model argument + cur = m.end() + + # shrink template + template = template[cur:] + + # update to swapped string + s = s[:bgn] + scanned + s[end:] + + return s + + @classmethod + def _get_args(cls, m): + '''Get arguments of a Mathematica function''' + + s = m.string # whole string + anc = m.end() + 1 # pointing the first letter of arguments + square, curly = [], [] # stack for brackets + args = [] + + # current cursor + cur = anc + for i, c in enumerate(s[anc:], anc): + # extract one argument + if c == ',' and (not square) and (not curly): + args.append(s[cur:i]) # add an argument + cur = i + 1 # move cursor + + # handle list or matrix (for future usage) + if c == '{': + curly.append(c) + elif c == '}': + curly.pop() + + # seek corresponding ']' with skipping irrevant ones + if c == '[': + square.append(c) + elif c == ']': + if square: + square.pop() + else: # empty stack + args.append(s[cur:i]) + break + + # the next position to ']' bracket (the function end) + func_end = i + 1 + + return args, func_end + + @classmethod + def _replace(cls, s, bef): + aft = cls.REPLACEMENTS[bef] + s = s.replace(bef, aft) + return s + + @classmethod + def _apply_rules(cls, s, bef): + pat, aft = cls.RULES[bef] + return pat.sub(aft, s) + + @classmethod + def _check_input(cls, s): + for bracket in (('[', ']'), ('{', '}'), ('(', ')')): + if s.count(bracket[0]) != s.count(bracket[1]): + err = "'{f}' function form is invalid.".format(f=s) + raise ValueError(err) + + if '{' in s: + err = "Currently list is not supported." + raise ValueError(err) + + def _parse_old(self, s): + # input check + self._check_input(s) + + # uncover '*' hiding behind a whitespace + s = self._apply_rules(s, 'whitespace') + + # remove whitespace(s) + s = self._replace(s, ' ') + + # add omitted '*' character + s = self._apply_rules(s, 'add*_1') + s = self._apply_rules(s, 'add*_2') + + # translate function + s = self._convert_function(s) + + # '^' to '**' + s = self._replace(s, '^') + + # 'Pi' to 'pi' + s = self._apply_rules(s, 'Pi') + + # '{', '}' to '[', ']', respectively +# s = cls._replace(s, '{') # currently list is not taken into account +# s = cls._replace(s, '}') + + return s + + def parse(self, s): + s2 = self._from_mathematica_to_tokens(s) + s3 = self._from_tokens_to_fullformlist(s2) + s4 = self._from_fullformlist_to_sympy(s3) + return s4 + + INFIX = "Infix" + PREFIX = "Prefix" + POSTFIX = "Postfix" + FLAT = "Flat" + RIGHT = "Right" + LEFT = "Left" + + _mathematica_op_precedence: list[tuple[str, str | None, dict[str, str | Callable]]] = [ + (POSTFIX, None, {";": lambda x: x + ["Null"] if isinstance(x, list) and x and x[0] == "CompoundExpression" else ["CompoundExpression", x, "Null"]}), + (INFIX, FLAT, {";": "CompoundExpression"}), + (INFIX, RIGHT, {"=": "Set", ":=": "SetDelayed", "+=": "AddTo", "-=": "SubtractFrom", "*=": "TimesBy", "/=": "DivideBy"}), + (INFIX, LEFT, {"//": lambda x, y: [x, y]}), + (POSTFIX, None, {"&": "Function"}), + (INFIX, LEFT, {"/.": "ReplaceAll"}), + (INFIX, RIGHT, {"->": "Rule", ":>": "RuleDelayed"}), + (INFIX, LEFT, {"/;": "Condition"}), + (INFIX, FLAT, {"|": "Alternatives"}), + (POSTFIX, None, {"..": "Repeated", "...": "RepeatedNull"}), + (INFIX, FLAT, {"||": "Or"}), + (INFIX, FLAT, {"&&": "And"}), + (PREFIX, None, {"!": "Not"}), + (INFIX, FLAT, {"===": "SameQ", "=!=": "UnsameQ"}), + (INFIX, FLAT, {"==": "Equal", "!=": "Unequal", "<=": "LessEqual", "<": "Less", ">=": "GreaterEqual", ">": "Greater"}), + (INFIX, None, {";;": "Span"}), + (INFIX, FLAT, {"+": "Plus", "-": "Plus"}), + (INFIX, FLAT, {"*": "Times", "/": "Times"}), + (INFIX, FLAT, {".": "Dot"}), + (PREFIX, None, {"-": lambda x: MathematicaParser._get_neg(x), + "+": lambda x: x}), + (INFIX, RIGHT, {"^": "Power"}), + (INFIX, RIGHT, {"@@": "Apply", "/@": "Map", "//@": "MapAll", "@@@": lambda x, y: ["Apply", x, y, ["List", "1"]]}), + (POSTFIX, None, {"'": "Derivative", "!": "Factorial", "!!": "Factorial2", "--": "Decrement"}), + (INFIX, None, {"[": lambda x, y: [x, *y], "[[": lambda x, y: ["Part", x, *y]}), + (PREFIX, None, {"{": lambda x: ["List", *x], "(": lambda x: x[0]}), + (INFIX, None, {"?": "PatternTest"}), + (POSTFIX, None, { + "_": lambda x: ["Pattern", x, ["Blank"]], + "_.": lambda x: ["Optional", ["Pattern", x, ["Blank"]]], + "__": lambda x: ["Pattern", x, ["BlankSequence"]], + "___": lambda x: ["Pattern", x, ["BlankNullSequence"]], + }), + (INFIX, None, {"_": lambda x, y: ["Pattern", x, ["Blank", y]]}), + (PREFIX, None, {"#": "Slot", "##": "SlotSequence"}), + ] + + _missing_arguments_default = { + "#": lambda: ["Slot", "1"], + "##": lambda: ["SlotSequence", "1"], + } + + _literal = r"[A-Za-z][A-Za-z0-9]*" + _number = r"(?:[0-9]+(?:\.[0-9]*)?|\.[0-9]+)" + + _enclosure_open = ["(", "[", "[[", "{"] + _enclosure_close = [")", "]", "]]", "}"] + + @classmethod + def _get_neg(cls, x): + return f"-{x}" if isinstance(x, str) and re.match(MathematicaParser._number, x) else ["Times", "-1", x] + + @classmethod + def _get_inv(cls, x): + return ["Power", x, "-1"] + + _regex_tokenizer = None + + def _get_tokenizer(self): + if self._regex_tokenizer is not None: + # Check if the regular expression has already been compiled: + return self._regex_tokenizer + tokens = [self._literal, self._number] + tokens_escape = self._enclosure_open[:] + self._enclosure_close[:] + for typ, strat, symdict in self._mathematica_op_precedence: + for k in symdict: + tokens_escape.append(k) + tokens_escape.sort(key=lambda x: -len(x)) + tokens.extend(map(re.escape, tokens_escape)) + tokens.append(",") + tokens.append("\n") + tokenizer = re.compile("(" + "|".join(tokens) + ")") + self._regex_tokenizer = tokenizer + return self._regex_tokenizer + + def _from_mathematica_to_tokens(self, code: str): + tokenizer = self._get_tokenizer() + + # Find strings: + code_splits: list[str | list] = [] + while True: + string_start = code.find("\"") + if string_start == -1: + if len(code) > 0: + code_splits.append(code) + break + match_end = re.search(r'(? 0: + code_splits.append(code[:string_start]) + code_splits.append(["_Str", code[string_start+1:string_end].replace('\\"', '"')]) + code = code[string_end+1:] + + # Remove comments: + for i, code_split in enumerate(code_splits): + if isinstance(code_split, list): + continue + while True: + pos_comment_start = code_split.find("(*") + if pos_comment_start == -1: + break + pos_comment_end = code_split.find("*)") + if pos_comment_end == -1 or pos_comment_end < pos_comment_start: + raise SyntaxError("mismatch in comment (* *) code") + code_split = code_split[:pos_comment_start] + code_split[pos_comment_end+2:] + code_splits[i] = code_split + + # Tokenize the input strings with a regular expression: + token_lists = [tokenizer.findall(i) if isinstance(i, str) and i.isascii() else [i] for i in code_splits] + tokens = [j for i in token_lists for j in i] + + # Remove newlines at the beginning + while tokens and tokens[0] == "\n": + tokens.pop(0) + # Remove newlines at the end + while tokens and tokens[-1] == "\n": + tokens.pop(-1) + + return tokens + + def _is_op(self, token: str | list) -> bool: + if isinstance(token, list): + return False + if re.match(self._literal, token): + return False + if re.match("-?" + self._number, token): + return False + return True + + def _is_valid_star1(self, token: str | list) -> bool: + if token in (")", "}"): + return True + return not self._is_op(token) + + def _is_valid_star2(self, token: str | list) -> bool: + if token in ("(", "{"): + return True + return not self._is_op(token) + + def _from_tokens_to_fullformlist(self, tokens: list): + stack: list[list] = [[]] + open_seq = [] + pointer: int = 0 + while pointer < len(tokens): + token = tokens[pointer] + if token in self._enclosure_open: + stack[-1].append(token) + open_seq.append(token) + stack.append([]) + elif token == ",": + if len(stack[-1]) == 0 and stack[-2][-1] == open_seq[-1]: + raise SyntaxError("%s cannot be followed by comma ," % open_seq[-1]) + stack[-1] = self._parse_after_braces(stack[-1]) + stack.append([]) + elif token in self._enclosure_close: + ind = self._enclosure_close.index(token) + if self._enclosure_open[ind] != open_seq[-1]: + unmatched_enclosure = SyntaxError("unmatched enclosure") + if token == "]]" and open_seq[-1] == "[": + if open_seq[-2] == "[": + # These two lines would be logically correct, but are + # unnecessary: + # token = "]" + # tokens[pointer] = "]" + tokens.insert(pointer+1, "]") + elif open_seq[-2] == "[[": + if tokens[pointer+1] == "]": + tokens[pointer+1] = "]]" + elif tokens[pointer+1] == "]]": + tokens[pointer+1] = "]]" + tokens.insert(pointer+2, "]") + else: + raise unmatched_enclosure + else: + raise unmatched_enclosure + if len(stack[-1]) == 0 and stack[-2][-1] == "(": + raise SyntaxError("( ) not valid syntax") + last_stack = self._parse_after_braces(stack[-1], True) + stack[-1] = last_stack + new_stack_element = [] + while stack[-1][-1] != open_seq[-1]: + new_stack_element.append(stack.pop()) + new_stack_element.reverse() + if open_seq[-1] == "(" and len(new_stack_element) != 1: + raise SyntaxError("( must be followed by one expression, %i detected" % len(new_stack_element)) + stack[-1].append(new_stack_element) + open_seq.pop(-1) + else: + stack[-1].append(token) + pointer += 1 + if len(stack) != 1: + raise RuntimeError("Stack should have only one element") + return self._parse_after_braces(stack[0]) + + def _util_remove_newlines(self, lines: list, tokens: list, inside_enclosure: bool): + pointer = 0 + size = len(tokens) + while pointer < size: + token = tokens[pointer] + if token == "\n": + if inside_enclosure: + # Ignore newlines inside enclosures + tokens.pop(pointer) + size -= 1 + continue + if pointer == 0: + tokens.pop(0) + size -= 1 + continue + if pointer > 1: + try: + prev_expr = self._parse_after_braces(tokens[:pointer], inside_enclosure) + except SyntaxError: + tokens.pop(pointer) + size -= 1 + continue + else: + prev_expr = tokens[0] + if len(prev_expr) > 0 and prev_expr[0] == "CompoundExpression": + lines.extend(prev_expr[1:]) + else: + lines.append(prev_expr) + for i in range(pointer): + tokens.pop(0) + size -= pointer + pointer = 0 + continue + pointer += 1 + + def _util_add_missing_asterisks(self, tokens: list): + size: int = len(tokens) + pointer: int = 0 + while pointer < size: + if (pointer > 0 and + self._is_valid_star1(tokens[pointer - 1]) and + self._is_valid_star2(tokens[pointer])): + # This is a trick to add missing * operators in the expression, + # `"*" in op_dict` makes sure the precedence level is the same as "*", + # while `not self._is_op( ... )` makes sure this and the previous + # expression are not operators. + if tokens[pointer] == "(": + # ( has already been processed by now, replace: + tokens[pointer] = "*" + tokens[pointer + 1] = tokens[pointer + 1][0] + else: + tokens.insert(pointer, "*") + pointer += 1 + size += 1 + pointer += 1 + + def _parse_after_braces(self, tokens: list, inside_enclosure: bool = False): + op_dict: dict + changed: bool = False + lines: list = [] + + self._util_remove_newlines(lines, tokens, inside_enclosure) + + for op_type, grouping_strat, op_dict in reversed(self._mathematica_op_precedence): + if "*" in op_dict: + self._util_add_missing_asterisks(tokens) + size: int = len(tokens) + pointer: int = 0 + while pointer < size: + token = tokens[pointer] + if isinstance(token, str) and token in op_dict: + op_name: str | Callable = op_dict[token] + node: list + first_index: int + if isinstance(op_name, str): + node = [op_name] + first_index = 1 + else: + node = [] + first_index = 0 + if token in ("+", "-") and op_type == self.PREFIX and pointer > 0 and not self._is_op(tokens[pointer - 1]): + # Make sure that PREFIX + - don't match expressions like a + b or a - b, + # the INFIX + - are supposed to match that expression: + pointer += 1 + continue + if op_type == self.INFIX: + if pointer == 0 or pointer == size - 1 or self._is_op(tokens[pointer - 1]) or self._is_op(tokens[pointer + 1]): + pointer += 1 + continue + changed = True + tokens[pointer] = node + if op_type == self.INFIX: + arg1 = tokens.pop(pointer-1) + arg2 = tokens.pop(pointer) + if token == "/": + arg2 = self._get_inv(arg2) + elif token == "-": + arg2 = self._get_neg(arg2) + pointer -= 1 + size -= 2 + node.append(arg1) + node_p = node + if grouping_strat == self.FLAT: + while pointer + 2 < size and self._check_op_compatible(tokens[pointer+1], token): + node_p.append(arg2) + other_op = tokens.pop(pointer+1) + arg2 = tokens.pop(pointer+1) + if other_op == "/": + arg2 = self._get_inv(arg2) + elif other_op == "-": + arg2 = self._get_neg(arg2) + size -= 2 + node_p.append(arg2) + elif grouping_strat == self.RIGHT: + while pointer + 2 < size and tokens[pointer+1] == token: + node_p.append([op_name, arg2]) + node_p = node_p[-1] + tokens.pop(pointer+1) + arg2 = tokens.pop(pointer+1) + size -= 2 + node_p.append(arg2) + elif grouping_strat == self.LEFT: + while pointer + 1 < size and tokens[pointer+1] == token: + if isinstance(op_name, str): + node_p[first_index] = [op_name, node_p[first_index], arg2] + else: + node_p[first_index] = op_name(node_p[first_index], arg2) + tokens.pop(pointer+1) + arg2 = tokens.pop(pointer+1) + size -= 2 + node_p.append(arg2) + else: + node.append(arg2) + elif op_type == self.PREFIX: + if grouping_strat is not None: + raise TypeError("'Prefix' op_type should not have a grouping strat") + if pointer == size - 1 or self._is_op(tokens[pointer + 1]): + tokens[pointer] = self._missing_arguments_default[token]() + else: + node.append(tokens.pop(pointer+1)) + size -= 1 + elif op_type == self.POSTFIX: + if grouping_strat is not None: + raise TypeError("'Prefix' op_type should not have a grouping strat") + if pointer == 0 or self._is_op(tokens[pointer - 1]): + tokens[pointer] = self._missing_arguments_default[token]() + else: + node.append(tokens.pop(pointer-1)) + pointer -= 1 + size -= 1 + if isinstance(op_name, Callable): # type: ignore + op_call: Callable = typing.cast(Callable, op_name) + new_node = op_call(*node) + node.clear() + if isinstance(new_node, list): + node.extend(new_node) + else: + tokens[pointer] = new_node + pointer += 1 + if len(tokens) > 1 or (len(lines) == 0 and len(tokens) == 0): + if changed: + # Trick to deal with cases in which an operator with lower + # precedence should be transformed before an operator of higher + # precedence. Such as in the case of `#&[x]` (that is + # equivalent to `Lambda(d_, d_)(x)` in SymPy). In this case the + # operator `&` has lower precedence than `[`, but needs to be + # evaluated first because otherwise `# (&[x])` is not a valid + # expression: + return self._parse_after_braces(tokens, inside_enclosure) + raise SyntaxError("unable to create a single AST for the expression") + if len(lines) > 0: + if tokens[0] and tokens[0][0] == "CompoundExpression": + tokens = tokens[0][1:] + compound_expression = ["CompoundExpression", *lines, *tokens] + return compound_expression + return tokens[0] + + def _check_op_compatible(self, op1: str, op2: str): + if op1 == op2: + return True + muldiv = {"*", "/"} + addsub = {"+", "-"} + if op1 in muldiv and op2 in muldiv: + return True + if op1 in addsub and op2 in addsub: + return True + return False + + def _from_fullform_to_fullformlist(self, wmexpr: str): + """ + Parses FullForm[Downvalues[]] generated by Mathematica + """ + out: list = [] + stack = [out] + generator = re.finditer(r'[\[\],]', wmexpr) + last_pos = 0 + for match in generator: + if match is None: + break + position = match.start() + last_expr = wmexpr[last_pos:position].replace(',', '').replace(']', '').replace('[', '').strip() + + if match.group() == ',': + if last_expr != '': + stack[-1].append(last_expr) + elif match.group() == ']': + if last_expr != '': + stack[-1].append(last_expr) + stack.pop() + elif match.group() == '[': + stack[-1].append([last_expr]) + stack.append(stack[-1][-1]) + last_pos = match.end() + return out[0] + + def _from_fullformlist_to_fullformsympy(self, pylist: list): + from sympy import Function, Symbol + + def converter(expr): + if isinstance(expr, list): + if len(expr) > 0: + head = expr[0] + args = [converter(arg) for arg in expr[1:]] + return Function(head)(*args) + else: + raise ValueError("Empty list of expressions") + elif isinstance(expr, str): + return Symbol(expr) + else: + return _sympify(expr) + + return converter(pylist) + + _node_conversions = { + "Times": Mul, + "Plus": Add, + "Power": Pow, + "Rational": Rational, + "Log": lambda *a: log(*reversed(a)), + "Log2": lambda x: log(x, 2), + "Log10": lambda x: log(x, 10), + "Exp": exp, + "Sqrt": sqrt, + + "Sin": sin, + "Cos": cos, + "Tan": tan, + "Cot": cot, + "Sec": sec, + "Csc": csc, + + "ArcSin": asin, + "ArcCos": acos, + "ArcTan": lambda *a: atan2(*reversed(a)) if len(a) == 2 else atan(*a), + "ArcCot": acot, + "ArcSec": asec, + "ArcCsc": acsc, + + "Sinh": sinh, + "Cosh": cosh, + "Tanh": tanh, + "Coth": coth, + "Sech": sech, + "Csch": csch, + + "ArcSinh": asinh, + "ArcCosh": acosh, + "ArcTanh": atanh, + "ArcCoth": acoth, + "ArcSech": asech, + "ArcCsch": acsch, + + "Expand": expand, + "Im": im, + "Re": sympy.re, + "Flatten": flatten, + "Polylog": polylog, + "Cancel": cancel, + # Gamma=gamma, + "TrigExpand": expand_trig, + "Sign": sign, + "Simplify": simplify, + "Defer": UnevaluatedExpr, + "Identity": S, + # Sum=Sum_doit, + # Module=With, + # Block=With, + "Null": lambda *a: S.Zero, + "Mod": Mod, + "Max": Max, + "Min": Min, + "Pochhammer": rf, + "ExpIntegralEi": Ei, + "SinIntegral": Si, + "CosIntegral": Ci, + "AiryAi": airyai, + "AiryAiPrime": airyaiprime, + "AiryBi": airybi, + "AiryBiPrime": airybiprime, + "LogIntegral": li, + "PrimePi": primepi, + "Prime": prime, + "PrimeQ": isprime, + + "List": Tuple, + "Greater": StrictGreaterThan, + "GreaterEqual": GreaterThan, + "Less": StrictLessThan, + "LessEqual": LessThan, + "Equal": Equality, + "Or": Or, + "And": And, + + "Function": _parse_Function, + } + + _atom_conversions = { + "I": I, + "Pi": pi, + } + + def _from_fullformlist_to_sympy(self, full_form_list): + + def recurse(expr): + if isinstance(expr, list): + if isinstance(expr[0], list): + head = recurse(expr[0]) + else: + head = self._node_conversions.get(expr[0], Function(expr[0])) + return head(*[recurse(arg) for arg in expr[1:]]) + else: + return self._atom_conversions.get(expr, sympify(expr)) + + return recurse(full_form_list) + + def _from_fullformsympy_to_sympy(self, mform): + + expr = mform + for mma_form, sympy_node in self._node_conversions.items(): + expr = expr.replace(Function(mma_form), sympy_node) + return expr diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/maxima.py b/.venv/lib/python3.13/site-packages/sympy/parsing/maxima.py new file mode 100644 index 0000000000000000000000000000000000000000..7a8ee5b17bb03a36e338803cb10f9ebf22763c2c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/maxima.py @@ -0,0 +1,71 @@ +import re +from sympy.concrete.products import product +from sympy.concrete.summations import Sum +from sympy.core.sympify import sympify +from sympy.functions.elementary.trigonometric import (cos, sin) + + +class MaximaHelpers: + def maxima_expand(expr): + return expr.expand() + + def maxima_float(expr): + return expr.evalf() + + def maxima_trigexpand(expr): + return expr.expand(trig=True) + + def maxima_sum(a1, a2, a3, a4): + return Sum(a1, (a2, a3, a4)).doit() + + def maxima_product(a1, a2, a3, a4): + return product(a1, (a2, a3, a4)) + + def maxima_csc(expr): + return 1/sin(expr) + + def maxima_sec(expr): + return 1/cos(expr) + +sub_dict = { + 'pi': re.compile(r'%pi'), + 'E': re.compile(r'%e'), + 'I': re.compile(r'%i'), + '**': re.compile(r'\^'), + 'oo': re.compile(r'\binf\b'), + '-oo': re.compile(r'\bminf\b'), + "'-'": re.compile(r'\bminus\b'), + 'maxima_expand': re.compile(r'\bexpand\b'), + 'maxima_float': re.compile(r'\bfloat\b'), + 'maxima_trigexpand': re.compile(r'\btrigexpand'), + 'maxima_sum': re.compile(r'\bsum\b'), + 'maxima_product': re.compile(r'\bproduct\b'), + 'cancel': re.compile(r'\bratsimp\b'), + 'maxima_csc': re.compile(r'\bcsc\b'), + 'maxima_sec': re.compile(r'\bsec\b') +} + +var_name = re.compile(r'^\s*(\w+)\s*:') + + +def parse_maxima(str, globals=None, name_dict={}): + str = str.strip() + str = str.rstrip('; ') + + for k, v in sub_dict.items(): + str = v.sub(k, str) + + assign_var = None + var_match = var_name.search(str) + if var_match: + assign_var = var_match.group(1) + str = str[var_match.end():].strip() + + dct = MaximaHelpers.__dict__.copy() + dct.update(name_dict) + obj = sympify(str, locals=dct) + + if assign_var and globals: + globals[assign_var] = obj + + return obj diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/sym_expr.py b/.venv/lib/python3.13/site-packages/sympy/parsing/sym_expr.py new file mode 100644 index 0000000000000000000000000000000000000000..9dbd0e94eb51147b51825fcf15cbec5ae18bb1b6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/sym_expr.py @@ -0,0 +1,279 @@ +from sympy.printing import pycode, ccode, fcode +from sympy.external import import_module +from sympy.utilities.decorator import doctest_depends_on + +lfortran = import_module('lfortran') +cin = import_module('clang.cindex', import_kwargs = {'fromlist': ['cindex']}) + +if lfortran: + from sympy.parsing.fortran.fortran_parser import src_to_sympy +if cin: + from sympy.parsing.c.c_parser import parse_c + +@doctest_depends_on(modules=['lfortran', 'clang.cindex']) +class SymPyExpression: # type: ignore + """Class to store and handle SymPy expressions + + This class will hold SymPy Expressions and handle the API for the + conversion to and from different languages. + + It works with the C and the Fortran Parser to generate SymPy expressions + which are stored here and which can be converted to multiple language's + source code. + + Notes + ===== + + The module and its API are currently under development and experimental + and can be changed during development. + + The Fortran parser does not support numeric assignments, so all the + variables have been Initialized to zero. + + The module also depends on external dependencies: + + - LFortran which is required to use the Fortran parser + - Clang which is required for the C parser + + Examples + ======== + + Example of parsing C code: + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src = ''' + ... int a,b; + ... float c = 2, d =4; + ... ''' + >>> a = SymPyExpression(src, 'c') + >>> a.return_expr() + [Declaration(Variable(a, type=intc)), + Declaration(Variable(b, type=intc)), + Declaration(Variable(c, type=float32, value=2.0)), + Declaration(Variable(d, type=float32, value=4.0))] + + An example of variable definition: + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src2 = ''' + ... integer :: a, b, c, d + ... real :: p, q, r, s + ... ''' + >>> p = SymPyExpression() + >>> p.convert_to_expr(src2, 'f') + >>> p.convert_to_c() + ['int a = 0', 'int b = 0', 'int c = 0', 'int d = 0', 'double p = 0.0', 'double q = 0.0', 'double r = 0.0', 'double s = 0.0'] + + An example of Assignment: + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src3 = ''' + ... integer :: a, b, c, d, e + ... d = a + b - c + ... e = b * d + c * e / a + ... ''' + >>> p = SymPyExpression(src3, 'f') + >>> p.convert_to_python() + ['a = 0', 'b = 0', 'c = 0', 'd = 0', 'e = 0', 'd = a + b - c', 'e = b*d + c*e/a'] + + An example of function definition: + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src = ''' + ... integer function f(a,b) + ... integer, intent(in) :: a, b + ... integer :: r + ... end function + ... ''' + >>> a = SymPyExpression(src, 'f') + >>> a.convert_to_python() + ['def f(a, b):\\n f = 0\\n r = 0\\n return f'] + + """ + + def __init__(self, source_code = None, mode = None): + """Constructor for SymPyExpression class""" + super().__init__() + if not(mode or source_code): + self._expr = [] + elif mode: + if source_code: + if mode.lower() == 'f': + if lfortran: + self._expr = src_to_sympy(source_code) + else: + raise ImportError("LFortran is not installed, cannot parse Fortran code") + elif mode.lower() == 'c': + if cin: + self._expr = parse_c(source_code) + else: + raise ImportError("Clang is not installed, cannot parse C code") + else: + raise NotImplementedError( + 'Parser for specified language is not implemented' + ) + else: + raise ValueError('Source code not present') + else: + raise ValueError('Please specify a mode for conversion') + + def convert_to_expr(self, src_code, mode): + """Converts the given source code to SymPy Expressions + + Attributes + ========== + + src_code : String + the source code or filename of the source code that is to be + converted + + mode: String + the mode to determine which parser is to be used according to + the language of the source code + f or F for Fortran + c or C for C/C++ + + Examples + ======== + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src3 = ''' + ... integer function f(a,b) result(r) + ... integer, intent(in) :: a, b + ... integer :: x + ... r = a + b -x + ... end function + ... ''' + >>> p = SymPyExpression() + >>> p.convert_to_expr(src3, 'f') + >>> p.return_expr() + [FunctionDefinition(integer, name=f, parameters=(Variable(a), Variable(b)), body=CodeBlock( + Declaration(Variable(r, type=integer, value=0)), + Declaration(Variable(x, type=integer, value=0)), + Assignment(Variable(r), a + b - x), + Return(Variable(r)) + ))] + + + + + """ + if mode.lower() == 'f': + if lfortran: + self._expr = src_to_sympy(src_code) + else: + raise ImportError("LFortran is not installed, cannot parse Fortran code") + elif mode.lower() == 'c': + if cin: + self._expr = parse_c(src_code) + else: + raise ImportError("Clang is not installed, cannot parse C code") + else: + raise NotImplementedError( + "Parser for specified language has not been implemented" + ) + + def convert_to_python(self): + """Returns a list with Python code for the SymPy expressions + + Examples + ======== + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src2 = ''' + ... integer :: a, b, c, d + ... real :: p, q, r, s + ... c = a/b + ... d = c/a + ... s = p/q + ... r = q/p + ... ''' + >>> p = SymPyExpression(src2, 'f') + >>> p.convert_to_python() + ['a = 0', 'b = 0', 'c = 0', 'd = 0', 'p = 0.0', 'q = 0.0', 'r = 0.0', 's = 0.0', 'c = a/b', 'd = c/a', 's = p/q', 'r = q/p'] + + """ + self._pycode = [] + for iter in self._expr: + self._pycode.append(pycode(iter)) + return self._pycode + + def convert_to_c(self): + """Returns a list with the c source code for the SymPy expressions + + + Examples + ======== + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src2 = ''' + ... integer :: a, b, c, d + ... real :: p, q, r, s + ... c = a/b + ... d = c/a + ... s = p/q + ... r = q/p + ... ''' + >>> p = SymPyExpression() + >>> p.convert_to_expr(src2, 'f') + >>> p.convert_to_c() + ['int a = 0', 'int b = 0', 'int c = 0', 'int d = 0', 'double p = 0.0', 'double q = 0.0', 'double r = 0.0', 'double s = 0.0', 'c = a/b;', 'd = c/a;', 's = p/q;', 'r = q/p;'] + + """ + self._ccode = [] + for iter in self._expr: + self._ccode.append(ccode(iter)) + return self._ccode + + def convert_to_fortran(self): + """Returns a list with the fortran source code for the SymPy expressions + + Examples + ======== + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src2 = ''' + ... integer :: a, b, c, d + ... real :: p, q, r, s + ... c = a/b + ... d = c/a + ... s = p/q + ... r = q/p + ... ''' + >>> p = SymPyExpression(src2, 'f') + >>> p.convert_to_fortran() + [' integer*4 a', ' integer*4 b', ' integer*4 c', ' integer*4 d', ' real*8 p', ' real*8 q', ' real*8 r', ' real*8 s', ' c = a/b', ' d = c/a', ' s = p/q', ' r = q/p'] + + """ + self._fcode = [] + for iter in self._expr: + self._fcode.append(fcode(iter)) + return self._fcode + + def return_expr(self): + """Returns the expression list + + Examples + ======== + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src3 = ''' + ... integer function f(a,b) + ... integer, intent(in) :: a, b + ... integer :: r + ... r = a+b + ... f = r + ... end function + ... ''' + >>> p = SymPyExpression() + >>> p.convert_to_expr(src3, 'f') + >>> p.return_expr() + [FunctionDefinition(integer, name=f, parameters=(Variable(a), Variable(b)), body=CodeBlock( + Declaration(Variable(f, type=integer, value=0)), + Declaration(Variable(r, type=integer, value=0)), + Assignment(Variable(f), Variable(r)), + Return(Variable(f)) + ))] + + """ + return self._expr diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/sympy_parser.py b/.venv/lib/python3.13/site-packages/sympy/parsing/sympy_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..9cfda9ce0f73ffa3773031c48b9e9c245f69fe0b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/sympy_parser.py @@ -0,0 +1,1270 @@ +"""Transform a string with Python-like source code into SymPy expression. """ +from __future__ import annotations +from tokenize import (generate_tokens, untokenize, TokenError, + NUMBER, STRING, NAME, OP, ENDMARKER, ERRORTOKEN, NEWLINE) + +from keyword import iskeyword + +import ast +import unicodedata +from io import StringIO +import builtins +import types +from typing import Any, Callable +from functools import reduce +from sympy.assumptions.ask import AssumptionKeys +from sympy.core.basic import Basic +from sympy.core import Symbol +from sympy.core.function import Function +from sympy.utilities.misc import func_name +from sympy.functions.elementary.miscellaneous import Max, Min + + +null = '' + +TOKEN = tuple[int, str] +DICT = dict[str, Any] +TRANS = Callable[[list[TOKEN], DICT, DICT], list[TOKEN]] + +def _token_splittable(token_name: str) -> bool: + """ + Predicate for whether a token name can be split into multiple tokens. + + A token is splittable if it does not contain an underscore character and + it is not the name of a Greek letter. This is used to implicitly convert + expressions like 'xyz' into 'x*y*z'. + """ + if '_' in token_name: + return False + try: + return not unicodedata.lookup('GREEK SMALL LETTER ' + token_name) + except KeyError: + return len(token_name) > 1 + + +def _token_callable(token: TOKEN, local_dict: DICT, global_dict: DICT, nextToken=None): + """ + Predicate for whether a token name represents a callable function. + + Essentially wraps ``callable``, but looks up the token name in the + locals and globals. + """ + func = local_dict.get(token[1]) + if not func: + func = global_dict.get(token[1]) + return callable(func) and not isinstance(func, Symbol) + + +def _add_factorial_tokens(name: str, result: list[TOKEN]) -> list[TOKEN]: + if result == [] or result[-1][1] == '(': + raise TokenError() + + beginning = [(NAME, name), (OP, '(')] + end = [(OP, ')')] + + diff = 0 + length = len(result) + + for index, token in enumerate(result[::-1]): + toknum, tokval = token + i = length - index - 1 + + if tokval == ')': + diff += 1 + elif tokval == '(': + diff -= 1 + + if diff == 0: + if i - 1 >= 0 and result[i - 1][0] == NAME: + return result[:i - 1] + beginning + result[i - 1:] + end + else: + return result[:i] + beginning + result[i:] + end + + return result + + +class ParenthesisGroup(list[TOKEN]): + """List of tokens representing an expression in parentheses.""" + pass + + +class AppliedFunction: + """ + A group of tokens representing a function and its arguments. + + `exponent` is for handling the shorthand sin^2, ln^2, etc. + """ + def __init__(self, function: TOKEN, args: ParenthesisGroup, exponent=None): + if exponent is None: + exponent = [] + self.function = function + self.args = args + self.exponent = exponent + self.items = ['function', 'args', 'exponent'] + + def expand(self) -> list[TOKEN]: + """Return a list of tokens representing the function""" + return [self.function, *self.args] + + def __getitem__(self, index): + return getattr(self, self.items[index]) + + def __repr__(self): + return "AppliedFunction(%s, %s, %s)" % (self.function, self.args, + self.exponent) + + +def _flatten(result: list[TOKEN | AppliedFunction]): + result2: list[TOKEN] = [] + for tok in result: + if isinstance(tok, AppliedFunction): + result2.extend(tok.expand()) + else: + result2.append(tok) + return result2 + + +def _group_parentheses(recursor: TRANS): + def _inner(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Group tokens between parentheses with ParenthesisGroup. + + Also processes those tokens recursively. + + """ + result: list[TOKEN | ParenthesisGroup] = [] + stacks: list[ParenthesisGroup] = [] + stacklevel = 0 + for token in tokens: + if token[0] == OP: + if token[1] == '(': + stacks.append(ParenthesisGroup([])) + stacklevel += 1 + elif token[1] == ')': + stacks[-1].append(token) + stack = stacks.pop() + + if len(stacks) > 0: + # We don't recurse here since the upper-level stack + # would reprocess these tokens + stacks[-1].extend(stack) + else: + # Recurse here to handle nested parentheses + # Strip off the outer parentheses to avoid an infinite loop + inner = stack[1:-1] + inner = recursor(inner, + local_dict, + global_dict) + parenGroup = [stack[0]] + inner + [stack[-1]] + result.append(ParenthesisGroup(parenGroup)) + stacklevel -= 1 + continue + if stacklevel: + stacks[-1].append(token) + else: + result.append(token) + if stacklevel: + raise TokenError("Mismatched parentheses") + return result + return _inner + + +def _apply_functions(tokens: list[TOKEN | ParenthesisGroup], local_dict: DICT, global_dict: DICT): + """Convert a NAME token + ParenthesisGroup into an AppliedFunction. + + Note that ParenthesisGroups, if not applied to any function, are + converted back into lists of tokens. + + """ + result: list[TOKEN | AppliedFunction] = [] + symbol = None + for tok in tokens: + if isinstance(tok, ParenthesisGroup): + if symbol and _token_callable(symbol, local_dict, global_dict): + result[-1] = AppliedFunction(symbol, tok) + symbol = None + else: + result.extend(tok) + elif tok[0] == NAME: + symbol = tok + result.append(tok) + else: + symbol = None + result.append(tok) + return result + + +def _implicit_multiplication(tokens: list[TOKEN | AppliedFunction], local_dict: DICT, global_dict: DICT): + """Implicitly adds '*' tokens. + + Cases: + + - Two AppliedFunctions next to each other ("sin(x)cos(x)") + + - AppliedFunction next to an open parenthesis ("sin x (cos x + 1)") + + - A close parenthesis next to an AppliedFunction ("(x+2)sin x")\ + + - A close parenthesis next to an open parenthesis ("(x+2)(x+3)") + + - AppliedFunction next to an implicitly applied function ("sin(x)cos x") + + """ + result: list[TOKEN | AppliedFunction] = [] + skip = False + for tok, nextTok in zip(tokens, tokens[1:]): + result.append(tok) + if skip: + skip = False + continue + if tok[0] == OP and tok[1] == '.' and nextTok[0] == NAME: + # Dotted name. Do not do implicit multiplication + skip = True + continue + if isinstance(tok, AppliedFunction): + if isinstance(nextTok, AppliedFunction): + result.append((OP, '*')) + elif nextTok == (OP, '('): + # Applied function followed by an open parenthesis + if tok.function[1] == "Function": + tok.function = (tok.function[0], 'Symbol') + result.append((OP, '*')) + elif nextTok[0] == NAME: + # Applied function followed by implicitly applied function + result.append((OP, '*')) + else: + if tok == (OP, ')'): + if isinstance(nextTok, AppliedFunction): + # Close parenthesis followed by an applied function + result.append((OP, '*')) + elif nextTok[0] == NAME: + # Close parenthesis followed by an implicitly applied function + result.append((OP, '*')) + elif nextTok == (OP, '('): + # Close parenthesis followed by an open parenthesis + result.append((OP, '*')) + elif tok[0] == NAME and not _token_callable(tok, local_dict, global_dict): + if isinstance(nextTok, AppliedFunction) or \ + (nextTok[0] == NAME and _token_callable(nextTok, local_dict, global_dict)): + # Constant followed by (implicitly applied) function + result.append((OP, '*')) + elif nextTok == (OP, '('): + # Constant followed by parenthesis + result.append((OP, '*')) + elif nextTok[0] == NAME: + # Constant followed by constant + result.append((OP, '*')) + if tokens: + result.append(tokens[-1]) + return result + + +def _implicit_application(tokens: list[TOKEN | AppliedFunction], local_dict: DICT, global_dict: DICT): + """Adds parentheses as needed after functions.""" + result: list[TOKEN | AppliedFunction] = [] + appendParen = 0 # number of closing parentheses to add + skip = 0 # number of tokens to delay before adding a ')' (to + # capture **, ^, etc.) + exponentSkip = False # skipping tokens before inserting parentheses to + # work with function exponentiation + for tok, nextTok in zip(tokens, tokens[1:]): + result.append(tok) + if (tok[0] == NAME and nextTok[0] not in [OP, ENDMARKER, NEWLINE]): + if _token_callable(tok, local_dict, global_dict, nextTok): # type: ignore + result.append((OP, '(')) + appendParen += 1 + # name followed by exponent - function exponentiation + elif (tok[0] == NAME and nextTok[0] == OP and nextTok[1] == '**'): + if _token_callable(tok, local_dict, global_dict): # type: ignore + exponentSkip = True + elif exponentSkip: + # if the last token added was an applied function (i.e. the + # power of the function exponent) OR a multiplication (as + # implicit multiplication would have added an extraneous + # multiplication) + if (isinstance(tok, AppliedFunction) + or (tok[0] == OP and tok[1] == '*')): + # don't add anything if the next token is a multiplication + # or if there's already a parenthesis (if parenthesis, still + # stop skipping tokens) + if not (nextTok[0] == OP and nextTok[1] == '*'): + if not(nextTok[0] == OP and nextTok[1] == '('): + result.append((OP, '(')) + appendParen += 1 + exponentSkip = False + elif appendParen: + if nextTok[0] == OP and nextTok[1] in ('^', '**', '*'): + skip = 1 + continue + if skip: + skip -= 1 + continue + result.append((OP, ')')) + appendParen -= 1 + + if tokens: + result.append(tokens[-1]) + + if appendParen: + result.extend([(OP, ')')] * appendParen) + return result + + +def function_exponentiation(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Allows functions to be exponentiated, e.g. ``cos**2(x)``. + + Examples + ======== + + >>> from sympy.parsing.sympy_parser import (parse_expr, + ... standard_transformations, function_exponentiation) + >>> transformations = standard_transformations + (function_exponentiation,) + >>> parse_expr('sin**4(x)', transformations=transformations) + sin(x)**4 + """ + result: list[TOKEN] = [] + exponent: list[TOKEN] = [] + consuming_exponent = False + level = 0 + for tok, nextTok in zip(tokens, tokens[1:]): + if tok[0] == NAME and nextTok[0] == OP and nextTok[1] == '**': + if _token_callable(tok, local_dict, global_dict): + consuming_exponent = True + elif consuming_exponent: + if tok[0] == NAME and tok[1] == 'Function': + tok = (NAME, 'Symbol') + exponent.append(tok) + + # only want to stop after hitting ) + if tok[0] == nextTok[0] == OP and tok[1] == ')' and nextTok[1] == '(': + consuming_exponent = False + # if implicit multiplication was used, we may have )*( instead + if tok[0] == nextTok[0] == OP and tok[1] == '*' and nextTok[1] == '(': + consuming_exponent = False + del exponent[-1] + continue + elif exponent and not consuming_exponent: + if tok[0] == OP: + if tok[1] == '(': + level += 1 + elif tok[1] == ')': + level -= 1 + if level == 0: + result.append(tok) + result.extend(exponent) + exponent = [] + continue + result.append(tok) + if tokens: + result.append(tokens[-1]) + if exponent: + result.extend(exponent) + return result + + +def split_symbols_custom(predicate: Callable[[str], bool]): + """Creates a transformation that splits symbol names. + + ``predicate`` should return True if the symbol name is to be split. + + For instance, to retain the default behavior but avoid splitting certain + symbol names, a predicate like this would work: + + + >>> from sympy.parsing.sympy_parser import (parse_expr, _token_splittable, + ... standard_transformations, implicit_multiplication, + ... split_symbols_custom) + >>> def can_split(symbol): + ... if symbol not in ('list', 'of', 'unsplittable', 'names'): + ... return _token_splittable(symbol) + ... return False + ... + >>> transformation = split_symbols_custom(can_split) + >>> parse_expr('unsplittable', transformations=standard_transformations + + ... (transformation, implicit_multiplication)) + unsplittable + """ + def _split_symbols(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + result: list[TOKEN] = [] + split = False + split_previous=False + + for tok in tokens: + if split_previous: + # throw out closing parenthesis of Symbol that was split + split_previous=False + continue + split_previous=False + + if tok[0] == NAME and tok[1] in ['Symbol', 'Function']: + split = True + + elif split and tok[0] == NAME: + symbol = tok[1][1:-1] + + if predicate(symbol): + tok_type = result[-2][1] # Symbol or Function + del result[-2:] # Get rid of the call to Symbol + + i = 0 + while i < len(symbol): + char = symbol[i] + if char in local_dict or char in global_dict: + result.append((NAME, "%s" % char)) + elif char.isdigit(): + chars = [char] + for i in range(i + 1, len(symbol)): + if not symbol[i].isdigit(): + i -= 1 + break + chars.append(symbol[i]) + char = ''.join(chars) + result.extend([(NAME, 'Number'), (OP, '('), + (NAME, "'%s'" % char), (OP, ')')]) + else: + use = tok_type if i == len(symbol) else 'Symbol' + result.extend([(NAME, use), (OP, '('), + (NAME, "'%s'" % char), (OP, ')')]) + i += 1 + + # Set split_previous=True so will skip + # the closing parenthesis of the original Symbol + split = False + split_previous = True + continue + + else: + split = False + + result.append(tok) + + return result + + return _split_symbols + + +#: Splits symbol names for implicit multiplication. +#: +#: Intended to let expressions like ``xyz`` be parsed as ``x*y*z``. Does not +#: split Greek character names, so ``theta`` will *not* become +#: ``t*h*e*t*a``. Generally this should be used with +#: ``implicit_multiplication``. +split_symbols = split_symbols_custom(_token_splittable) + + +def implicit_multiplication(tokens: list[TOKEN], local_dict: DICT, + global_dict: DICT) -> list[TOKEN]: + """Makes the multiplication operator optional in most cases. + + Use this before :func:`implicit_application`, otherwise expressions like + ``sin 2x`` will be parsed as ``x * sin(2)`` rather than ``sin(2*x)``. + + Examples + ======== + + >>> from sympy.parsing.sympy_parser import (parse_expr, + ... standard_transformations, implicit_multiplication) + >>> transformations = standard_transformations + (implicit_multiplication,) + >>> parse_expr('3 x y', transformations=transformations) + 3*x*y + """ + # These are interdependent steps, so we don't expose them separately + res1 = _group_parentheses(implicit_multiplication)(tokens, local_dict, global_dict) + res2 = _apply_functions(res1, local_dict, global_dict) + res3 = _implicit_multiplication(res2, local_dict, global_dict) + result = _flatten(res3) + return result + + +def implicit_application(tokens: list[TOKEN], local_dict: DICT, + global_dict: DICT) -> list[TOKEN]: + """Makes parentheses optional in some cases for function calls. + + Use this after :func:`implicit_multiplication`, otherwise expressions + like ``sin 2x`` will be parsed as ``x * sin(2)`` rather than + ``sin(2*x)``. + + Examples + ======== + + >>> from sympy.parsing.sympy_parser import (parse_expr, + ... standard_transformations, implicit_application) + >>> transformations = standard_transformations + (implicit_application,) + >>> parse_expr('cot z + csc z', transformations=transformations) + cot(z) + csc(z) + """ + res1 = _group_parentheses(implicit_application)(tokens, local_dict, global_dict) + res2 = _apply_functions(res1, local_dict, global_dict) + res3 = _implicit_application(res2, local_dict, global_dict) + result = _flatten(res3) + return result + + +def implicit_multiplication_application(result: list[TOKEN], local_dict: DICT, + global_dict: DICT) -> list[TOKEN]: + """Allows a slightly relaxed syntax. + + - Parentheses for single-argument method calls are optional. + + - Multiplication is implicit. + + - Symbol names can be split (i.e. spaces are not needed between + symbols). + + - Functions can be exponentiated. + + Examples + ======== + + >>> from sympy.parsing.sympy_parser import (parse_expr, + ... standard_transformations, implicit_multiplication_application) + >>> parse_expr("10sin**2 x**2 + 3xyz + tan theta", + ... transformations=(standard_transformations + + ... (implicit_multiplication_application,))) + 3*x*y*z + 10*sin(x**2)**2 + tan(theta) + + """ + for step in (split_symbols, implicit_multiplication, + implicit_application, function_exponentiation): + result = step(result, local_dict, global_dict) + + return result + + +def auto_symbol(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Inserts calls to ``Symbol``/``Function`` for undefined variables.""" + result: list[TOKEN] = [] + prevTok = (-1, '') + + tokens.append((-1, '')) # so zip traverses all tokens + for tok, nextTok in zip(tokens, tokens[1:]): + tokNum, tokVal = tok + nextTokNum, nextTokVal = nextTok + if tokNum == NAME: + name = tokVal + + if (name in ['True', 'False', 'None'] + or iskeyword(name) + # Don't convert attribute access + or (prevTok[0] == OP and prevTok[1] == '.') + # Don't convert keyword arguments + or (prevTok[0] == OP and prevTok[1] in ('(', ',') + and nextTokNum == OP and nextTokVal == '=') + # the name has already been defined + or name in local_dict and local_dict[name] is not null): + result.append((NAME, name)) + continue + elif name in local_dict: + local_dict.setdefault(null, set()).add(name) + if nextTokVal == '(': + local_dict[name] = Function(name) + else: + local_dict[name] = Symbol(name) + result.append((NAME, name)) + continue + elif name in global_dict: + obj = global_dict[name] + if isinstance(obj, (AssumptionKeys, Basic, type)) or callable(obj): + result.append((NAME, name)) + continue + + result.extend([ + (NAME, 'Symbol' if nextTokVal != '(' else 'Function'), + (OP, '('), + (NAME, repr(str(name))), + (OP, ')'), + ]) + else: + result.append((tokNum, tokVal)) + + prevTok = (tokNum, tokVal) + + return result + + +def lambda_notation(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Substitutes "lambda" with its SymPy equivalent Lambda(). + However, the conversion does not take place if only "lambda" + is passed because that is a syntax error. + + """ + result: list[TOKEN] = [] + flag = False + toknum, tokval = tokens[0] + tokLen = len(tokens) + + if toknum == NAME and tokval == 'lambda': + if tokLen == 2 or tokLen == 3 and tokens[1][0] == NEWLINE: + # In Python 3.6.7+, inputs without a newline get NEWLINE added to + # the tokens + result.extend(tokens) + elif tokLen > 2: + result.extend([ + (NAME, 'Lambda'), + (OP, '('), + (OP, '('), + (OP, ')'), + (OP, ')'), + ]) + for tokNum, tokVal in tokens[1:]: + if tokNum == OP and tokVal == ':': + tokVal = ',' + flag = True + if not flag and tokNum == OP and tokVal in ('*', '**'): + raise TokenError("Starred arguments in lambda not supported") + if flag: + result.insert(-1, (tokNum, tokVal)) + else: + result.insert(-2, (tokNum, tokVal)) + else: + result.extend(tokens) + + return result + + +def factorial_notation(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Allows standard notation for factorial.""" + result: list[TOKEN] = [] + nfactorial = 0 + for toknum, tokval in tokens: + if toknum == OP and tokval == "!": + # In Python 3.12 "!" are OP instead of ERRORTOKEN + nfactorial += 1 + elif toknum == ERRORTOKEN: + op = tokval + if op == '!': + nfactorial += 1 + else: + nfactorial = 0 + result.append((OP, op)) + else: + if nfactorial == 1: + result = _add_factorial_tokens('factorial', result) + elif nfactorial == 2: + result = _add_factorial_tokens('factorial2', result) + elif nfactorial > 2: + raise TokenError + nfactorial = 0 + result.append((toknum, tokval)) + return result + + +def convert_xor(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Treats XOR, ``^``, as exponentiation, ``**``.""" + result: list[TOKEN] = [] + for toknum, tokval in tokens: + if toknum == OP: + if tokval == '^': + result.append((OP, '**')) + else: + result.append((toknum, tokval)) + else: + result.append((toknum, tokval)) + + return result + + +def repeated_decimals(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """ + Allows 0.2[1] notation to represent the repeated decimal 0.2111... (19/90) + + Run this before auto_number. + + """ + result: list[TOKEN] = [] + + def is_digit(s): + return all(i in '0123456789_' for i in s) + + # num will running match any DECIMAL [ INTEGER ] + num: list[TOKEN] = [] + for toknum, tokval in tokens: + if toknum == NUMBER: + if (not num and '.' in tokval and 'e' not in tokval.lower() and + 'j' not in tokval.lower()): + num.append((toknum, tokval)) + elif is_digit(tokval) and (len(num) == 2 or + len(num) == 3 and is_digit(num[-1][1])): + num.append((toknum, tokval)) + else: + num = [] + elif toknum == OP: + if tokval == '[' and len(num) == 1: + num.append((OP, tokval)) + elif tokval == ']' and len(num) >= 3: + num.append((OP, tokval)) + elif tokval == '.' and not num: + # handle .[1] + num.append((NUMBER, '0.')) + else: + num = [] + else: + num = [] + + result.append((toknum, tokval)) + + if num and num[-1][1] == ']': + # pre.post[repetend] = a + b/c + d/e where a = pre, b/c = post, + # and d/e = repetend + result = result[:-len(num)] + pre, post = num[0][1].split('.') + repetend = num[2][1] + if len(num) == 5: + repetend += num[3][1] + + pre = pre.replace('_', '') + post = post.replace('_', '') + repetend = repetend.replace('_', '') + + zeros = '0'*len(post) + post, repetends = [w.lstrip('0') for w in [post, repetend]] + # or else interpreted as octal + + a = pre or '0' + b, c = post or '0', '1' + zeros + d, e = repetends, ('9'*len(repetend)) + zeros + + seq = [ + (OP, '('), + (NAME, 'Integer'), + (OP, '('), + (NUMBER, a), + (OP, ')'), + (OP, '+'), + (NAME, 'Rational'), + (OP, '('), + (NUMBER, b), + (OP, ','), + (NUMBER, c), + (OP, ')'), + (OP, '+'), + (NAME, 'Rational'), + (OP, '('), + (NUMBER, d), + (OP, ','), + (NUMBER, e), + (OP, ')'), + (OP, ')'), + ] + result.extend(seq) + num = [] + + return result + + +def auto_number(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """ + Converts numeric literals to use SymPy equivalents. + + Complex numbers use ``I``, integer literals use ``Integer``, and float + literals use ``Float``. + + """ + result: list[TOKEN] = [] + + for toknum, tokval in tokens: + if toknum == NUMBER: + number = tokval + postfix = [] + + if number.endswith(('j', 'J')): + number = number[:-1] + postfix = [(OP, '*'), (NAME, 'I')] + + if '.' in number or (('e' in number or 'E' in number) and + not (number.startswith(('0x', '0X')))): + seq = [(NAME, 'Float'), (OP, '('), + (NUMBER, repr(str(number))), (OP, ')')] + else: + seq = [(NAME, 'Integer'), (OP, '('), ( + NUMBER, number), (OP, ')')] + + result.extend(seq + postfix) + else: + result.append((toknum, tokval)) + + return result + + +def rationalize(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Converts floats into ``Rational``. Run AFTER ``auto_number``.""" + result: list[TOKEN] = [] + passed_float = False + for toknum, tokval in tokens: + if toknum == NAME: + if tokval == 'Float': + passed_float = True + tokval = 'Rational' + result.append((toknum, tokval)) + elif passed_float == True and toknum == NUMBER: + passed_float = False + result.append((STRING, tokval)) + else: + result.append((toknum, tokval)) + + return result + + +def _transform_equals_sign(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Transforms the equals sign ``=`` to instances of Eq. + + This is a helper function for ``convert_equals_signs``. + Works with expressions containing one equals sign and no + nesting. Expressions like ``(1=2)=False`` will not work with this + and should be used with ``convert_equals_signs``. + + Examples: 1=2 to Eq(1,2) + 1*2=x to Eq(1*2, x) + + This does not deal with function arguments yet. + + """ + result: list[TOKEN] = [] + if (OP, "=") in tokens: + result.append((NAME, "Eq")) + result.append((OP, "(")) + for token in tokens: + if token == (OP, "="): + result.append((OP, ",")) + continue + result.append(token) + result.append((OP, ")")) + else: + result = tokens + return result + + +def convert_equals_signs(tokens: list[TOKEN], local_dict: DICT, + global_dict: DICT) -> list[TOKEN]: + """ Transforms all the equals signs ``=`` to instances of Eq. + + Parses the equals signs in the expression and replaces them with + appropriate Eq instances. Also works with nested equals signs. + + Does not yet play well with function arguments. + For example, the expression ``(x=y)`` is ambiguous and can be interpreted + as x being an argument to a function and ``convert_equals_signs`` will not + work for this. + + See also + ======== + convert_equality_operators + + Examples + ======== + + >>> from sympy.parsing.sympy_parser import (parse_expr, + ... standard_transformations, convert_equals_signs) + >>> parse_expr("1*2=x", transformations=( + ... standard_transformations + (convert_equals_signs,))) + Eq(2, x) + >>> parse_expr("(1*2=x)=False", transformations=( + ... standard_transformations + (convert_equals_signs,))) + Eq(Eq(2, x), False) + + """ + res1 = _group_parentheses(convert_equals_signs)(tokens, local_dict, global_dict) + res2 = _apply_functions(res1, local_dict, global_dict) + res3 = _transform_equals_sign(res2, local_dict, global_dict) + result = _flatten(res3) + return result + + +#: Standard transformations for :func:`parse_expr`. +#: Inserts calls to :class:`~.Symbol`, :class:`~.Integer`, and other SymPy +#: datatypes and allows the use of standard factorial notation (e.g. ``x!``). +standard_transformations: tuple[TRANS, ...] \ + = (lambda_notation, auto_symbol, repeated_decimals, auto_number, + factorial_notation) + + +def stringify_expr(s: str, local_dict: DICT, global_dict: DICT, + transformations: tuple[TRANS, ...]) -> str: + """ + Converts the string ``s`` to Python code, in ``local_dict`` + + Generally, ``parse_expr`` should be used. + """ + + tokens = [] + input_code = StringIO(s.strip()) + for toknum, tokval, _, _, _ in generate_tokens(input_code.readline): + tokens.append((toknum, tokval)) + + for transform in transformations: + tokens = transform(tokens, local_dict, global_dict) + + return untokenize(tokens) + + +def eval_expr(code, local_dict: DICT, global_dict: DICT): + """ + Evaluate Python code generated by ``stringify_expr``. + + Generally, ``parse_expr`` should be used. + """ + expr = eval( + code, global_dict, local_dict) # take local objects in preference + return expr + + +def parse_expr(s: str, local_dict: DICT | None = None, + transformations: tuple[TRANS, ...] | str \ + = standard_transformations, + global_dict: DICT | None = None, evaluate=True): + """Converts the string ``s`` to a SymPy expression, in ``local_dict``. + + .. warning:: + Note that this function uses ``eval``, and thus shouldn't be used on + unsanitized input. + + Parameters + ========== + + s : str + The string to parse. + + local_dict : dict, optional + A dictionary of local variables to use when parsing. + + global_dict : dict, optional + A dictionary of global variables. By default, this is initialized + with ``from sympy import *``; provide this parameter to override + this behavior (for instance, to parse ``"Q & S"``). + + transformations : tuple or str + A tuple of transformation functions used to modify the tokens of the + parsed expression before evaluation. The default transformations + convert numeric literals into their SymPy equivalents, convert + undefined variables into SymPy symbols, and allow the use of standard + mathematical factorial notation (e.g. ``x!``). Selection via + string is available (see below). + + evaluate : bool, optional + When False, the order of the arguments will remain as they were in the + string and automatic simplification that would normally occur is + suppressed. (see examples) + + Examples + ======== + + >>> from sympy.parsing.sympy_parser import parse_expr + >>> parse_expr("1/2") + 1/2 + >>> type(_) + + >>> from sympy.parsing.sympy_parser import standard_transformations,\\ + ... implicit_multiplication_application + >>> transformations = (standard_transformations + + ... (implicit_multiplication_application,)) + >>> parse_expr("2x", transformations=transformations) + 2*x + + When evaluate=False, some automatic simplifications will not occur: + + >>> parse_expr("2**3"), parse_expr("2**3", evaluate=False) + (8, 2**3) + + In addition the order of the arguments will not be made canonical. + This feature allows one to tell exactly how the expression was entered: + + >>> a = parse_expr('1 + x', evaluate=False) + >>> b = parse_expr('x + 1', evaluate=False) + >>> a == b + False + >>> a.args + (1, x) + >>> b.args + (x, 1) + + Note, however, that when these expressions are printed they will + appear the same: + + >>> assert str(a) == str(b) + + As a convenience, transformations can be seen by printing ``transformations``: + + >>> from sympy.parsing.sympy_parser import transformations + + >>> print(transformations) + 0: lambda_notation + 1: auto_symbol + 2: repeated_decimals + 3: auto_number + 4: factorial_notation + 5: implicit_multiplication_application + 6: convert_xor + 7: implicit_application + 8: implicit_multiplication + 9: convert_equals_signs + 10: function_exponentiation + 11: rationalize + + The ``T`` object provides a way to select these transformations: + + >>> from sympy.parsing.sympy_parser import T + + If you print it, you will see the same list as shown above. + + >>> str(T) == str(transformations) + True + + Standard slicing will return a tuple of transformations: + + >>> T[:5] == standard_transformations + True + + So ``T`` can be used to specify the parsing transformations: + + >>> parse_expr("2x", transformations=T[:5]) + Traceback (most recent call last): + ... + SyntaxError: invalid syntax + >>> parse_expr("2x", transformations=T[:6]) + 2*x + >>> parse_expr('.3', transformations=T[3, 11]) + 3/10 + >>> parse_expr('.3x', transformations=T[:]) + 3*x/10 + + As a further convenience, strings 'implicit' and 'all' can be used + to select 0-5 and all the transformations, respectively. + + >>> parse_expr('.3x', transformations='all') + 3*x/10 + + See Also + ======== + + stringify_expr, eval_expr, standard_transformations, + implicit_multiplication_application + + """ + + if local_dict is None: + local_dict = {} + elif not isinstance(local_dict, dict): + raise TypeError('expecting local_dict to be a dict') + elif null in local_dict: + raise ValueError('cannot use "" in local_dict') + + if global_dict is None: + global_dict = {} + exec('from sympy import *', global_dict) + + builtins_dict = vars(builtins) + for name, obj in builtins_dict.items(): + if isinstance(obj, types.BuiltinFunctionType): + global_dict[name] = obj + global_dict['max'] = Max + global_dict['min'] = Min + + elif not isinstance(global_dict, dict): + raise TypeError('expecting global_dict to be a dict') + + transformations = transformations or () + if isinstance(transformations, str): + if transformations == 'all': + _transformations = T[:] + elif transformations == 'implicit': + _transformations = T[:6] + else: + raise ValueError('unknown transformation group name') + else: + _transformations = transformations + + code = stringify_expr(s, local_dict, global_dict, _transformations) + + if not evaluate: + code = compile(evaluateFalse(code), '', 'eval') # type: ignore + + try: + rv = eval_expr(code, local_dict, global_dict) + # restore neutral definitions for names + for i in local_dict.pop(null, ()): + local_dict[i] = null + return rv + except Exception as e: + # restore neutral definitions for names + for i in local_dict.pop(null, ()): + local_dict[i] = null + raise e from ValueError(f"Error from parse_expr with transformed code: {code!r}") + + +def evaluateFalse(s: str): + """ + Replaces operators with the SymPy equivalent and sets evaluate=False. + """ + node = ast.parse(s) + transformed_node = EvaluateFalseTransformer().visit(node) + # node is a Module, we want an Expression + transformed_node = ast.Expression(transformed_node.body[0].value) + + return ast.fix_missing_locations(transformed_node) + + +class EvaluateFalseTransformer(ast.NodeTransformer): + operators = { + ast.Add: 'Add', + ast.Mult: 'Mul', + ast.Pow: 'Pow', + ast.Sub: 'Add', + ast.Div: 'Mul', + ast.BitOr: 'Or', + ast.BitAnd: 'And', + ast.BitXor: 'Not', + } + functions = ( + 'Abs', 'im', 're', 'sign', 'arg', 'conjugate', + 'acos', 'acot', 'acsc', 'asec', 'asin', 'atan', + 'acosh', 'acoth', 'acsch', 'asech', 'asinh', 'atanh', + 'cos', 'cot', 'csc', 'sec', 'sin', 'tan', + 'cosh', 'coth', 'csch', 'sech', 'sinh', 'tanh', + 'exp', 'ln', 'log', 'sqrt', 'cbrt', + ) + + relational_operators = { + ast.NotEq: 'Ne', + ast.Lt: 'Lt', + ast.LtE: 'Le', + ast.Gt: 'Gt', + ast.GtE: 'Ge', + ast.Eq: 'Eq' + } + def visit_Compare(self, node): + def reducer(acc, op_right): + result, left = acc + op, right = op_right + if op.__class__ not in self.relational_operators: + raise ValueError("Only equation or inequality operators are supported") + new = ast.Call( + func=ast.Name( + id=self.relational_operators[op.__class__], ctx=ast.Load() + ), + args=[self.visit(left), self.visit(right)], + keywords=[ast.keyword(arg="evaluate", value=ast.Constant(value=False))], + ) + return result + [new], right + + args, _ = reduce( + reducer, zip(node.ops, node.comparators), ([], node.left) + ) + if len(args) == 1: + return args[0] + return ast.Call( + func=ast.Name(id=self.operators[ast.BitAnd], ctx=ast.Load()), + args=args, + keywords=[ast.keyword(arg="evaluate", value=ast.Constant(value=False))], + ) + + def flatten(self, args, func): + result = [] + for arg in args: + if isinstance(arg, ast.Call): + arg_func = arg.func + if isinstance(arg_func, ast.Call): + arg_func = arg_func.func + if arg_func.id == func: + result.extend(self.flatten(arg.args, func)) + else: + result.append(arg) + else: + result.append(arg) + return result + + def visit_BinOp(self, node): + if node.op.__class__ in self.operators: + sympy_class = self.operators[node.op.__class__] + right = self.visit(node.right) + left = self.visit(node.left) + + rev = False + if isinstance(node.op, ast.Sub): + right = ast.Call( + func=ast.Name(id='Mul', ctx=ast.Load()), + args=[ast.UnaryOp(op=ast.USub(), operand=ast.Constant(1)), right], + keywords=[ast.keyword(arg='evaluate', value=ast.Constant(value=False))] + ) + elif isinstance(node.op, ast.Div): + if isinstance(node.left, ast.UnaryOp): + left, right = right, left + rev = True + left = ast.Call( + func=ast.Name(id='Pow', ctx=ast.Load()), + args=[left, ast.UnaryOp(op=ast.USub(), operand=ast.Constant(1))], + keywords=[ast.keyword(arg='evaluate', value=ast.Constant(value=False))] + ) + else: + right = ast.Call( + func=ast.Name(id='Pow', ctx=ast.Load()), + args=[right, ast.UnaryOp(op=ast.USub(), operand=ast.Constant(1))], + keywords=[ast.keyword(arg='evaluate', value=ast.Constant(value=False))] + ) + + if rev: # undo reversal + left, right = right, left + new_node = ast.Call( + func=ast.Name(id=sympy_class, ctx=ast.Load()), + args=[left, right], + keywords=[ast.keyword(arg='evaluate', value=ast.Constant(value=False))] + ) + + if sympy_class in ('Add', 'Mul'): + # Denest Add or Mul as appropriate + new_node.args = self.flatten(new_node.args, sympy_class) + + return new_node + return node + + def visit_Call(self, node): + new_node = self.generic_visit(node) + if isinstance(node.func, ast.Name) and node.func.id in self.functions: + new_node.keywords.append(ast.keyword(arg='evaluate', value=ast.Constant(value=False))) + return new_node + + +_transformation = { # items can be added but never re-ordered +0: lambda_notation, +1: auto_symbol, +2: repeated_decimals, +3: auto_number, +4: factorial_notation, +5: implicit_multiplication_application, +6: convert_xor, +7: implicit_application, +8: implicit_multiplication, +9: convert_equals_signs, +10: function_exponentiation, +11: rationalize} + +transformations = '\n'.join('%s: %s' % (i, func_name(f)) for i, f in _transformation.items()) + + +class _T(): + """class to retrieve transformations from a given slice + + EXAMPLES + ======== + + >>> from sympy.parsing.sympy_parser import T, standard_transformations + >>> assert T[:5] == standard_transformations + """ + def __init__(self): + self.N = len(_transformation) + + def __str__(self): + return transformations + + def __getitem__(self, t): + if not type(t) is tuple: + t = (t,) + i = [] + for ti in t: + if type(ti) is int: + i.append(range(self.N)[ti]) + elif type(ti) is slice: + i.extend(range(*ti.indices(self.N))) + else: + raise TypeError('unexpected slice arg') + return tuple([_transformation[_] for _ in i]) + +T = _T() diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_ast_parser.py b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_ast_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..24572190df72f9be11b5830355b0d6b9e3bb53ad --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_ast_parser.py @@ -0,0 +1,25 @@ +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.parsing.ast_parser import parse_expr +from sympy.testing.pytest import raises +from sympy.core.sympify import SympifyError +import warnings + +def test_parse_expr(): + a, b = symbols('a, b') + # tests issue_16393 + assert parse_expr('a + b', {}) == a + b + raises(SympifyError, lambda: parse_expr('a + ', {})) + + # tests Transform.visit_Constant + assert parse_expr('1 + 2', {}) == S(3) + assert parse_expr('1 + 2.0', {}) == S(3.0) + + # tests Transform.visit_Name + assert parse_expr('Rational(1, 2)', {}) == S(1)/2 + assert parse_expr('a', {'a': a}) == a + + # tests issue_23092 + with warnings.catch_warnings(): + warnings.simplefilter('error') + assert parse_expr('6 * 7', {}) == S(42) diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_autolev.py b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_autolev.py new file mode 100644 index 0000000000000000000000000000000000000000..dfcaef13565c5e2187dc6e90113b407a7967c331 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_autolev.py @@ -0,0 +1,178 @@ +import os + +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.external import import_module +from sympy.testing.pytest import skip +from sympy.parsing.autolev import parse_autolev + +antlr4 = import_module("antlr4") + +if not antlr4: + disabled = True + +FILE_DIR = os.path.dirname( + os.path.dirname(os.path.abspath(os.path.realpath(__file__)))) + + +def _test_examples(in_filename, out_filename, test_name=""): + + in_file_path = os.path.join(FILE_DIR, 'autolev', 'test-examples', + in_filename) + correct_file_path = os.path.join(FILE_DIR, 'autolev', 'test-examples', + out_filename) + with open(in_file_path) as f: + generated_code = parse_autolev(f, include_numeric=True) + + with open(correct_file_path) as f: + for idx, line1 in enumerate(f): + if line1.startswith("#"): + break + try: + line2 = generated_code.split('\n')[idx] + assert line1.rstrip() == line2.rstrip() + except Exception: + msg = 'mismatch in ' + test_name + ' in line no: {0}' + raise AssertionError(msg.format(idx+1)) + + +def test_rule_tests(): + + l = ["ruletest1", "ruletest2", "ruletest3", "ruletest4", "ruletest5", + "ruletest6", "ruletest7", "ruletest8", "ruletest9", "ruletest10", + "ruletest11", "ruletest12"] + + for i in l: + in_filepath = i + ".al" + out_filepath = i + ".py" + _test_examples(in_filepath, out_filepath, i) + + +def test_pydy_examples(): + + l = ["mass_spring_damper", "chaos_pendulum", "double_pendulum", + "non_min_pendulum"] + + for i in l: + in_filepath = os.path.join("pydy-example-repo", i + ".al") + out_filepath = os.path.join("pydy-example-repo", i + ".py") + _test_examples(in_filepath, out_filepath, i) + + +def test_autolev_tutorial(): + + dir_path = os.path.join(FILE_DIR, 'autolev', 'test-examples', + 'autolev-tutorial') + + if os.path.isdir(dir_path): + l = ["tutor1", "tutor2", "tutor3", "tutor4", "tutor5", "tutor6", + "tutor7"] + for i in l: + in_filepath = os.path.join("autolev-tutorial", i + ".al") + out_filepath = os.path.join("autolev-tutorial", i + ".py") + _test_examples(in_filepath, out_filepath, i) + + +def test_dynamics_online(): + + dir_path = os.path.join(FILE_DIR, 'autolev', 'test-examples', + 'dynamics-online') + + if os.path.isdir(dir_path): + ch1 = ["1-4", "1-5", "1-6", "1-7", "1-8", "1-9_1", "1-9_2", "1-9_3"] + ch2 = ["2-1", "2-2", "2-3", "2-4", "2-5", "2-6", "2-7", "2-8", "2-9", + "circular"] + ch3 = ["3-1_1", "3-1_2", "3-2_1", "3-2_2", "3-2_3", "3-2_4", "3-2_5", + "3-3"] + ch4 = ["4-1_1", "4-2_1", "4-4_1", "4-4_2", "4-5_1", "4-5_2"] + chapters = [(ch1, "ch1"), (ch2, "ch2"), (ch3, "ch3"), (ch4, "ch4")] + for ch, name in chapters: + for i in ch: + in_filepath = os.path.join("dynamics-online", name, i + ".al") + out_filepath = os.path.join("dynamics-online", name, i + ".py") + _test_examples(in_filepath, out_filepath, i) + + +def test_output_01(): + """Autolev example calculates the position, velocity, and acceleration of a + point and expresses in a single reference frame:: + + (1) FRAMES C,D,F + (2) VARIABLES FD'',DC'' + (3) CONSTANTS R,L + (4) POINTS O,E + (5) SIMPROT(F,D,1,FD) + -> (6) F_D = [1, 0, 0; 0, COS(FD), -SIN(FD); 0, SIN(FD), COS(FD)] + (7) SIMPROT(D,C,2,DC) + -> (8) D_C = [COS(DC), 0, SIN(DC); 0, 1, 0; -SIN(DC), 0, COS(DC)] + (9) W_C_F> = EXPRESS(W_C_F>, F) + -> (10) W_C_F> = FD'*F1> + COS(FD)*DC'*F2> + SIN(FD)*DC'*F3> + (11) P_O_E>=R*D2>-L*C1> + (12) P_O_E>=EXPRESS(P_O_E>, D) + -> (13) P_O_E> = -L*COS(DC)*D1> + R*D2> + L*SIN(DC)*D3> + (14) V_E_F>=EXPRESS(DT(P_O_E>,F),D) + -> (15) V_E_F> = L*SIN(DC)*DC'*D1> - L*SIN(DC)*FD'*D2> + (R*FD'+L*COS(DC)*DC')*D3> + (16) A_E_F>=EXPRESS(DT(V_E_F>,F),D) + -> (17) A_E_F> = L*(COS(DC)*DC'^2+SIN(DC)*DC'')*D1> + (-R*FD'^2-2*L*COS(DC)*DC'*FD'-L*SIN(DC)*FD'')*D2> + (R*FD''+L*COS(DC)*DC''-L*SIN(DC)*DC'^2-L*SIN(DC)*FD'^2)*D3> + + """ + + if not antlr4: + skip('Test skipped: antlr4 is not installed.') + + autolev_input = """\ +FRAMES C,D,F +VARIABLES FD'',DC'' +CONSTANTS R,L +POINTS O,E +SIMPROT(F,D,1,FD) +SIMPROT(D,C,2,DC) +W_C_F>=EXPRESS(W_C_F>,F) +P_O_E>=R*D2>-L*C1> +P_O_E>=EXPRESS(P_O_E>,D) +V_E_F>=EXPRESS(DT(P_O_E>,F),D) +A_E_F>=EXPRESS(DT(V_E_F>,F),D)\ +""" + + sympy_input = parse_autolev(autolev_input) + + g = {} + l = {} + exec(sympy_input, g, l) + + w_c_f = l['frame_c'].ang_vel_in(l['frame_f']) + # P_O_E> means "the position of point E wrt to point O" + p_o_e = l['point_e'].pos_from(l['point_o']) + v_e_f = l['point_e'].vel(l['frame_f']) + a_e_f = l['point_e'].acc(l['frame_f']) + + # NOTE : The Autolev outputs above were manually transformed into + # equivalent SymPy physics vector expressions. Would be nice to automate + # this transformation. + expected_w_c_f = (l['fd'].diff()*l['frame_f'].x + + cos(l['fd'])*l['dc'].diff()*l['frame_f'].y + + sin(l['fd'])*l['dc'].diff()*l['frame_f'].z) + + assert (w_c_f - expected_w_c_f).simplify() == 0 + + expected_p_o_e = (-l['l']*cos(l['dc'])*l['frame_d'].x + + l['r']*l['frame_d'].y + + l['l']*sin(l['dc'])*l['frame_d'].z) + + assert (p_o_e - expected_p_o_e).simplify() == 0 + + expected_v_e_f = (l['l']*sin(l['dc'])*l['dc'].diff()*l['frame_d'].x - + l['l']*sin(l['dc'])*l['fd'].diff()*l['frame_d'].y + + (l['r']*l['fd'].diff() + + l['l']*cos(l['dc'])*l['dc'].diff())*l['frame_d'].z) + assert (v_e_f - expected_v_e_f).simplify() == 0 + + expected_a_e_f = (l['l']*(cos(l['dc'])*l['dc'].diff()**2 + + sin(l['dc'])*l['dc'].diff().diff())*l['frame_d'].x + + (-l['r']*l['fd'].diff()**2 - + 2*l['l']*cos(l['dc'])*l['dc'].diff()*l['fd'].diff() - + l['l']*sin(l['dc'])*l['fd'].diff().diff())*l['frame_d'].y + + (l['r']*l['fd'].diff().diff() + + l['l']*cos(l['dc'])*l['dc'].diff().diff() - + l['l']*sin(l['dc'])*l['dc'].diff()**2 - + l['l']*sin(l['dc'])*l['fd'].diff()**2)*l['frame_d'].z) + assert (a_e_f - expected_a_e_f).simplify() == 0 diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_c_parser.py b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_c_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..b74622e40030cba180cb4fc354216ccca119baec --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_c_parser.py @@ -0,0 +1,5248 @@ +from sympy.parsing.sym_expr import SymPyExpression +from sympy.testing.pytest import raises, XFAIL +from sympy.external import import_module + +cin = import_module('clang.cindex', import_kwargs = {'fromlist': ['cindex']}) + +if cin: + from sympy.codegen.ast import (Variable, String, Return, + FunctionDefinition, Integer, Float, Declaration, CodeBlock, + FunctionPrototype, FunctionCall, NoneToken, Assignment, Type, + IntBaseType, SignedIntType, UnsignedIntType, FloatType, + AddAugmentedAssignment, SubAugmentedAssignment, + MulAugmentedAssignment, DivAugmentedAssignment, + ModAugmentedAssignment, While) + from sympy.codegen.cnodes import (PreDecrement, PostDecrement, + PreIncrement, PostIncrement) + from sympy.core import (Add, Mul, Mod, Pow, Rational, + StrictLessThan, LessThan, StrictGreaterThan, GreaterThan, + Equality, Unequality) + from sympy.logic.boolalg import And, Not, Or + from sympy.core.symbol import Symbol + from sympy.logic.boolalg import (false, true) + import os + + def test_variable(): + c_src1 = ( + 'int a;' + '\n' + + 'int b;' + '\n' + ) + c_src2 = ( + 'float a;' + '\n' + + 'float b;' + '\n' + ) + c_src3 = ( + 'int a;' + '\n' + + 'float b;' + '\n' + + 'int c;' + ) + c_src4 = ( + 'int x = 1, y = 6.78;' + '\n' + + 'float p = 2, q = 9.67;' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + + assert res1[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ) + + assert res1[1] == Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('intc')) + ) + ) + + assert res2[0] == Declaration( + Variable( + Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ) + assert res2[1] == Declaration( + Variable( + Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ) + + assert res3[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ) + + assert res3[1] == Declaration( + Variable( + Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ) + + assert res3[2] == Declaration( + Variable( + Symbol('c'), + type=IntBaseType(String('intc')) + ) + ) + + assert res4[0] == Declaration( + Variable( + Symbol('x'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res4[1] == Declaration( + Variable( + Symbol('y'), + type=IntBaseType(String('intc')), + value=Integer(6) + ) + ) + + assert res4[2] == Declaration( + Variable( + Symbol('p'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.0', precision=53) + ) + ) + + assert res4[3] == Declaration( + Variable( + Symbol('q'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('9.67', precision=53) + ) + ) + + + def test_int(): + c_src1 = 'int a = 1;' + c_src2 = ( + 'int a = 1;' + '\n' + + 'int b = 2;' + '\n' + ) + c_src3 = 'int a = 2.345, b = 5.67;' + c_src4 = 'int p = 6, q = 23.45;' + c_src5 = "int x = '0', y = 'a';" + c_src6 = "int r = true, s = false;" + + # cin.TypeKind.UCHAR + c_src_type1 = ( + "signed char a = 1, b = 5.1;" + ) + + # cin.TypeKind.SHORT + c_src_type2 = ( + "short a = 1, b = 5.1;" + "signed short c = 1, d = 5.1;" + "short int e = 1, f = 5.1;" + "signed short int g = 1, h = 5.1;" + ) + + # cin.TypeKind.INT + c_src_type3 = ( + "signed int a = 1, b = 5.1;" + "int c = 1, d = 5.1;" + ) + + # cin.TypeKind.LONG + c_src_type4 = ( + "long a = 1, b = 5.1;" + "long int c = 1, d = 5.1;" + ) + + # cin.TypeKind.UCHAR + c_src_type5 = "unsigned char a = 1, b = 5.1;" + + # cin.TypeKind.USHORT + c_src_type6 = ( + "unsigned short a = 1, b = 5.1;" + "unsigned short int c = 1, d = 5.1;" + ) + + # cin.TypeKind.UINT + c_src_type7 = "unsigned int a = 1, b = 5.1;" + + # cin.TypeKind.ULONG + c_src_type8 = ( + "unsigned long a = 1, b = 5.1;" + "unsigned long int c = 1, d = 5.1;" + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + res5 = SymPyExpression(c_src5, 'c').return_expr() + res6 = SymPyExpression(c_src6, 'c').return_expr() + + res_type1 = SymPyExpression(c_src_type1, 'c').return_expr() + res_type2 = SymPyExpression(c_src_type2, 'c').return_expr() + res_type3 = SymPyExpression(c_src_type3, 'c').return_expr() + res_type4 = SymPyExpression(c_src_type4, 'c').return_expr() + res_type5 = SymPyExpression(c_src_type5, 'c').return_expr() + res_type6 = SymPyExpression(c_src_type6, 'c').return_expr() + res_type7 = SymPyExpression(c_src_type7, 'c').return_expr() + res_type8 = SymPyExpression(c_src_type8, 'c').return_expr() + + assert res1[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res2[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res2[1] == Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res3[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res3[1] == Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(5) + ) + ) + + assert res4[0] == Declaration( + Variable( + Symbol('p'), + type=IntBaseType(String('intc')), + value=Integer(6) + ) + ) + + assert res4[1] == Declaration( + Variable( + Symbol('q'), + type=IntBaseType(String('intc')), + value=Integer(23) + ) + ) + + assert res5[0] == Declaration( + Variable( + Symbol('x'), + type=IntBaseType(String('intc')), + value=Integer(48) + ) + ) + + assert res5[1] == Declaration( + Variable( + Symbol('y'), + type=IntBaseType(String('intc')), + value=Integer(97) + ) + ) + + assert res6[0] == Declaration( + Variable( + Symbol('r'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res6[1] == Declaration( + Variable( + Symbol('s'), + type=IntBaseType(String('intc')), + value=Integer(0) + ) + ) + + assert res_type1[0] == Declaration( + Variable( + Symbol('a'), + type=SignedIntType( + String('int8'), + nbits=Integer(8) + ), + value=Integer(1) + ) + ) + + assert res_type1[1] == Declaration( + Variable( + Symbol('b'), + type=SignedIntType( + String('int8'), + nbits=Integer(8) + ), + value=Integer(5) + ) + ) + + assert res_type2[0] == Declaration( + Variable( + Symbol('a'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(1) + ) + ) + + assert res_type2[1] == Declaration( + Variable( + Symbol('b'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(5) + ) + ) + + assert res_type2[2] == Declaration( + Variable(Symbol('c'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(1) + ) + ) + + assert res_type2[3] == Declaration( + Variable( + Symbol('d'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(5) + ) + ) + + assert res_type2[4] == Declaration( + Variable( + Symbol('e'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(1) + ) + ) + + assert res_type2[5] == Declaration( + Variable( + Symbol('f'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(5) + ) + ) + + assert res_type2[6] == Declaration( + Variable( + Symbol('g'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(1) + ) + ) + + assert res_type2[7] == Declaration( + Variable( + Symbol('h'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(5) + ) + ) + + assert res_type3[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res_type3[1] == Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(5) + ) + ) + + assert res_type3[2] == Declaration( + Variable( + Symbol('c'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res_type3[3] == Declaration( + Variable( + Symbol('d'), + type=IntBaseType(String('intc')), + value=Integer(5) + ) + ) + + assert res_type4[0] == Declaration( + Variable( + Symbol('a'), + type=SignedIntType( + String('int64'), + nbits=Integer(64) + ), + value=Integer(1) + ) + ) + + assert res_type4[1] == Declaration( + Variable( + Symbol('b'), + type=SignedIntType( + String('int64'), + nbits=Integer(64) + ), + value=Integer(5) + ) + ) + + assert res_type4[2] == Declaration( + Variable( + Symbol('c'), + type=SignedIntType( + String('int64'), + nbits=Integer(64) + ), + value=Integer(1) + ) + ) + + assert res_type4[3] == Declaration( + Variable( + Symbol('d'), + type=SignedIntType( + String('int64'), + nbits=Integer(64) + ), + value=Integer(5) + ) + ) + + assert res_type5[0] == Declaration( + Variable( + Symbol('a'), + type=UnsignedIntType( + String('uint8'), + nbits=Integer(8) + ), + value=Integer(1) + ) + ) + + assert res_type5[1] == Declaration( + Variable( + Symbol('b'), + type=UnsignedIntType( + String('uint8'), + nbits=Integer(8) + ), + value=Integer(5) + ) + ) + + assert res_type6[0] == Declaration( + Variable( + Symbol('a'), + type=UnsignedIntType( + String('uint16'), + nbits=Integer(16) + ), + value=Integer(1) + ) + ) + + assert res_type6[1] == Declaration( + Variable( + Symbol('b'), + type=UnsignedIntType( + String('uint16'), + nbits=Integer(16) + ), + value=Integer(5) + ) + ) + + assert res_type6[2] == Declaration( + Variable( + Symbol('c'), + type=UnsignedIntType( + String('uint16'), + nbits=Integer(16) + ), + value=Integer(1) + ) + ) + + assert res_type6[3] == Declaration( + Variable( + Symbol('d'), + type=UnsignedIntType( + String('uint16'), + nbits=Integer(16) + ), + value=Integer(5) + ) + ) + + assert res_type7[0] == Declaration( + Variable( + Symbol('a'), + type=UnsignedIntType( + String('uint32'), + nbits=Integer(32) + ), + value=Integer(1) + ) + ) + + assert res_type7[1] == Declaration( + Variable( + Symbol('b'), + type=UnsignedIntType( + String('uint32'), + nbits=Integer(32) + ), + value=Integer(5) + ) + ) + + assert res_type8[0] == Declaration( + Variable( + Symbol('a'), + type=UnsignedIntType( + String('uint64'), + nbits=Integer(64) + ), + value=Integer(1) + ) + ) + + assert res_type8[1] == Declaration( + Variable( + Symbol('b'), + type=UnsignedIntType( + String('uint64'), + nbits=Integer(64) + ), + value=Integer(5) + ) + ) + + assert res_type8[2] == Declaration( + Variable( + Symbol('c'), + type=UnsignedIntType( + String('uint64'), + nbits=Integer(64) + ), + value=Integer(1) + ) + ) + + assert res_type8[3] == Declaration( + Variable( + Symbol('d'), + type=UnsignedIntType( + String('uint64'), + nbits=Integer(64) + ), + value=Integer(5) + ) + ) + + + def test_float(): + c_src1 = 'float a = 1.0;' + c_src2 = ( + 'float a = 1.25;' + '\n' + + 'float b = 2.39;' + '\n' + ) + c_src3 = 'float x = 1, y = 2;' + c_src4 = 'float p = 5, e = 7.89;' + c_src5 = 'float r = true, s = false;' + + # cin.TypeKind.FLOAT + c_src_type1 = 'float x = 1, y = 2.5;' + + # cin.TypeKind.DOUBLE + c_src_type2 = 'double x = 1, y = 2.5;' + + # cin.TypeKind.LONGDOUBLE + c_src_type3 = 'long double x = 1, y = 2.5;' + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + res5 = SymPyExpression(c_src5, 'c').return_expr() + + res_type1 = SymPyExpression(c_src_type1, 'c').return_expr() + res_type2 = SymPyExpression(c_src_type2, 'c').return_expr() + res_type3 = SymPyExpression(c_src_type3, 'c').return_expr() + + assert res1[0] == Declaration( + Variable( + Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.0', precision=53) + ) + ) + + assert res2[0] == Declaration( + Variable( + Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.25', precision=53) + ) + ) + + assert res2[1] == Declaration( + Variable( + Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.3900000000000001', precision=53) + ) + ) + + assert res3[0] == Declaration( + Variable( + Symbol('x'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.0', precision=53) + ) + ) + + assert res3[1] == Declaration( + Variable( + Symbol('y'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.0', precision=53) + ) + ) + + assert res4[0] == Declaration( + Variable( + Symbol('p'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('5.0', precision=53) + ) + ) + + assert res4[1] == Declaration( + Variable( + Symbol('e'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('7.89', precision=53) + ) + ) + + assert res5[0] == Declaration( + Variable( + Symbol('r'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.0', precision=53) + ) + ) + + assert res5[1] == Declaration( + Variable( + Symbol('s'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('0.0', precision=53) + ) + ) + + assert res_type1[0] == Declaration( + Variable( + Symbol('x'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.0', precision=53) + ) + ) + + assert res_type1[1] == Declaration( + Variable( + Symbol('y'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.5', precision=53) + ) + ) + assert res_type2[0] == Declaration( + Variable( + Symbol('x'), + type=FloatType( + String('float64'), + nbits=Integer(64), + nmant=Integer(52), + nexp=Integer(11) + ), + value=Float('1.0', precision=53) + ) + ) + + assert res_type2[1] == Declaration( + Variable( + Symbol('y'), + type=FloatType( + String('float64'), + nbits=Integer(64), + nmant=Integer(52), + nexp=Integer(11) + ), + value=Float('2.5', precision=53) + ) + ) + + assert res_type3[0] == Declaration( + Variable( + Symbol('x'), + type=FloatType( + String('float80'), + nbits=Integer(80), + nmant=Integer(63), + nexp=Integer(15) + ), + value=Float('1.0', precision=53) + ) + ) + + assert res_type3[1] == Declaration( + Variable( + Symbol('y'), + type=FloatType( + String('float80'), + nbits=Integer(80), + nmant=Integer(63), + nexp=Integer(15) + ), + value=Float('2.5', precision=53) + ) + ) + + + def test_bool(): + c_src1 = ( + 'bool a = true, b = false;' + ) + + c_src2 = ( + 'bool a = 1, b = 0;' + ) + + c_src3 = ( + 'bool a = 10, b = 20;' + ) + + c_src4 = ( + 'bool a = 19.1, b = 9.0, c = 0.0;' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + + assert res1[0] == Declaration( + Variable(Symbol('a'), + type=Type(String('bool')), + value=true + ) + ) + + assert res1[1] == Declaration( + Variable(Symbol('b'), + type=Type(String('bool')), + value=false + ) + ) + + assert res2[0] == Declaration( + Variable(Symbol('a'), + type=Type(String('bool')), + value=true) + ) + + assert res2[1] == Declaration( + Variable(Symbol('b'), + type=Type(String('bool')), + value=false + ) + ) + + assert res3[0] == Declaration( + Variable(Symbol('a'), + type=Type(String('bool')), + value=true + ) + ) + + assert res3[1] == Declaration( + Variable(Symbol('b'), + type=Type(String('bool')), + value=true + ) + ) + + assert res4[0] == Declaration( + Variable(Symbol('a'), + type=Type(String('bool')), + value=true) + ) + + assert res4[1] == Declaration( + Variable(Symbol('b'), + type=Type(String('bool')), + value=true + ) + ) + + assert res4[2] == Declaration( + Variable(Symbol('c'), + type=Type(String('bool')), + value=false + ) + ) + + @XFAIL # this is expected to fail because of a bug in the C parser. + def test_function(): + c_src1 = ( + 'void fun1()' + '\n' + + '{' + '\n' + + 'int a;' + '\n' + + '}' + ) + c_src2 = ( + 'int fun2()' + '\n' + + '{'+ '\n' + + 'int a;' + '\n' + + 'return a;' + '\n' + + '}' + ) + c_src3 = ( + 'float fun3()' + '\n' + + '{' + '\n' + + 'float b;' + '\n' + + 'return b;' + '\n' + + '}' + ) + c_src4 = ( + 'float fun4()' + '\n' + + '{}' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + + assert res1[0] == FunctionDefinition( + NoneToken(), + name=String('fun1'), + parameters=(), + body=CodeBlock( + Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ) + ) + ) + + assert res2[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('fun2'), + parameters=(), + body=CodeBlock( + Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Return('a') + ) + ) + + assert res3[0] == FunctionDefinition( + FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + name=String('fun3'), + parameters=(), + body=CodeBlock( + Declaration( + Variable( + Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Return('b') + ) + ) + + assert res4[0] == FunctionPrototype( + FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + name=String('fun4'), + parameters=() + ) + + @XFAIL # this is expected to fail because of a bug in the C parser. + def test_parameters(): + c_src1 = ( + 'void fun1( int a)' + '\n' + + '{' + '\n' + + 'int i;' + '\n' + + '}' + ) + c_src2 = ( + 'int fun2(float x, float y)' + '\n' + + '{'+ '\n' + + 'int a;' + '\n' + + 'return a;' + '\n' + + '}' + ) + c_src3 = ( + 'float fun3(int p, float q, int r)' + '\n' + + '{' + '\n' + + 'float b;' + '\n' + + 'return b;' + '\n' + + '}' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + + assert res1[0] == FunctionDefinition( + NoneToken(), + name=String('fun1'), + parameters=( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ), + ), + body=CodeBlock( + Declaration( + Variable( + Symbol('i'), + type=IntBaseType(String('intc')) + ) + ) + ) + ) + + assert res2[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('fun2'), + parameters=( + Variable( + Symbol('x'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ), + Variable( + Symbol('y'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + body=CodeBlock( + Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Return('a') + ) + ) + + assert res3[0] == FunctionDefinition( + FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + name=String('fun3'), + parameters=( + Variable( + Symbol('p'), + type=IntBaseType(String('intc')) + ), + Variable( + Symbol('q'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ), + Variable( + Symbol('r'), + type=IntBaseType(String('intc')) + ) + ), + body=CodeBlock( + Declaration( + Variable( + Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Return('b') + ) + ) + + @XFAIL # this is expected to fail because of a bug in the C parser. + def test_function_call(): + c_src1 = ( + 'int fun1(int x)' + '\n' + + '{' + '\n' + + 'return x;' + '\n' + + '}' + '\n' + + 'void caller()' + '\n' + + '{' + '\n' + + 'int x = fun1(2);' + '\n' + + '}' + ) + + c_src2 = ( + 'int fun2(int a, int b, int c)' + '\n' + + '{' + '\n' + + 'return a;' + '\n' + + '}' + '\n' + + 'void caller()' + '\n' + + '{' + '\n' + + 'int y = fun2(2, 3, 4);' + '\n' + + '}' + ) + + c_src3 = ( + 'int fun3(int a, int b, int c)' + '\n' + + '{' + '\n' + + 'return b;' + '\n' + + '}' + '\n' + + 'void caller()' + '\n' + + '{' + '\n' + + 'int p;' + '\n' + + 'int q;' + '\n' + + 'int r;' + '\n' + + 'int z = fun3(p, q, r);' + '\n' + + '}' + ) + + c_src4 = ( + 'int fun4(float a, float b, int c)' + '\n' + + '{' + '\n' + + 'return c;' + '\n' + + '}' + '\n' + + 'void caller()' + '\n' + + '{' + '\n' + + 'float x;' + '\n' + + 'float y;' + '\n' + + 'int z;' + '\n' + + 'int i = fun4(x, y, z)' + '\n' + + '}' + ) + + c_src5 = ( + 'int fun()' + '\n' + + '{' + '\n' + + 'return 1;' + '\n' + + '}' + '\n' + + 'void caller()' + '\n' + + '{' + '\n' + + 'int a = fun()' + '\n' + + '}' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + res5 = SymPyExpression(c_src5, 'c').return_expr() + + + assert res1[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('fun1'), + parameters=(Variable(Symbol('x'), + type=IntBaseType(String('intc')) + ), + ), + body=CodeBlock( + Return('x') + ) + ) + + assert res1[1] == FunctionDefinition( + NoneToken(), + name=String('caller'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('x'), + value=FunctionCall(String('fun1'), + function_args=( + Integer(2), + ) + ) + ) + ) + ) + ) + + assert res2[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('fun2'), + parameters=(Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ), + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ), + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + body=CodeBlock( + Return('a') + ) + ) + + assert res2[1] == FunctionDefinition( + NoneToken(), + name=String('caller'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('y'), + value=FunctionCall( + String('fun2'), + function_args=( + Integer(2), + Integer(3), + Integer(4) + ) + ) + ) + ) + ) + ) + + assert res3[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('fun3'), + parameters=( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ), + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ), + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + body=CodeBlock( + Return('b') + ) + ) + + assert res3[1] == FunctionDefinition( + NoneToken(), + name=String('caller'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('p'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('q'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('r'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('z'), + value=FunctionCall( + String('fun3'), + function_args=( + Symbol('p'), + Symbol('q'), + Symbol('r') + ) + ) + ) + ) + ) + ) + + assert res4[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('fun4'), + parameters=(Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ), + Variable(Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ), + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + body=CodeBlock( + Return('c') + ) + ) + + assert res4[1] == FunctionDefinition( + NoneToken(), + name=String('caller'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('x'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Declaration( + Variable(Symbol('y'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Declaration( + Variable(Symbol('z'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('i'), + value=FunctionCall(String('fun4'), + function_args=( + Symbol('x'), + Symbol('y'), + Symbol('z') + ) + ) + ) + ) + ) + ) + + assert res5[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('fun'), + parameters=(), + body=CodeBlock( + Return('') + ) + ) + + assert res5[1] == FunctionDefinition( + NoneToken(), + name=String('caller'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + value=FunctionCall(String('fun'), + function_args=() + ) + ) + ) + ) + ) + + + def test_parse(): + c_src1 = ( + 'int a;' + '\n' + + 'int b;' + '\n' + ) + c_src2 = ( + 'void fun1()' + '\n' + + '{' + '\n' + + 'int a;' + '\n' + + '}' + ) + + f1 = open('..a.h', 'w') + f2 = open('..b.h', 'w') + + f1.write(c_src1) + f2. write(c_src2) + + f1.close() + f2.close() + + res1 = SymPyExpression('..a.h', 'c').return_expr() + res2 = SymPyExpression('..b.h', 'c').return_expr() + + os.remove('..a.h') + os.remove('..b.h') + + assert res1[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ) + assert res1[1] == Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('intc')) + ) + ) + assert res2[0] == FunctionDefinition( + NoneToken(), + name=String('fun1'), + parameters=(), + body=CodeBlock( + Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ) + ) + ) + + + def test_binary_operators(): + c_src1 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = 1;' + '\n' + + '}' + ) + c_src2 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 0;' + '\n' + + 'a = a + 1;' + '\n' + + 'a = 3*a - 10;' + '\n' + + '}' + ) + c_src3 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 10;' + '\n' + + 'a = 1 + a - 3 * 6;' + '\n' + + '}' + ) + c_src4 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'int b;' + '\n' + + 'a = 100;' + '\n' + + 'b = a*a + a*a + a + 19*a + 1 + 24;' + '\n' + + '}' + ) + c_src5 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'int b;' + '\n' + + 'int c;' + '\n' + + 'int d;' + '\n' + + 'a = 1;' + '\n' + + 'b = 2;' + '\n' + + 'c = b;' + '\n' + + 'd = ((a+b)*(a+c))*((c-d)*(a+c));' + '\n' + + '}' + ) + c_src6 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'int b;' + '\n' + + 'int c;' + '\n' + + 'int d;' + '\n' + + 'a = 1;' + '\n' + + 'b = 2;' + '\n' + + 'c = 3;' + '\n' + + 'd = (a*a*a*a + 3*b*b + b + b + c*d);' + '\n' + + '}' + ) + c_src7 = ( + 'void func()'+ + '{' + '\n' + + 'float a;' + '\n' + + 'a = 1.01;' + '\n' + + '}' + ) + + c_src8 = ( + 'void func()'+ + '{' + '\n' + + 'float a;' + '\n' + + 'a = 10.0 + 2.5;' + '\n' + + '}' + ) + + c_src9 = ( + 'void func()'+ + '{' + '\n' + + 'float a;' + '\n' + + 'a = 10.0 / 2.5;' + '\n' + + '}' + ) + + c_src10 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = 100 / 4;' + '\n' + + '}' + ) + + c_src11 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = 20 - 100 / 4 * 5 + 10;' + '\n' + + '}' + ) + + c_src12 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = (20 - 100) / 4 * (5 + 10);' + '\n' + + '}' + ) + + c_src13 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'int b;' + '\n' + + 'float c;' + '\n' + + 'c = b/a;' + '\n' + + '}' + ) + + c_src14 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 2;' + '\n' + + 'int d = 5;' + '\n' + + 'int n = 10;' + '\n' + + 'int s;' + '\n' + + 's = (a/2)*(2*a + (n-1)*d);' + '\n' + + '}' + ) + + c_src15 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = 1 % 2;' + '\n' + + '}' + ) + + c_src16 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 2;' + '\n' + + 'int b;' + '\n' + + 'b = a % 3;' + '\n' + + '}' + ) + + c_src17 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 100;' + '\n' + + 'int b = 3;' + '\n' + + 'int c;' + '\n' + + 'c = a % b;' + '\n' + + '}' + ) + + c_src18 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 100;' + '\n' + + 'int b = 3;' + '\n' + + 'int mod = 1000000007;' + '\n' + + 'int c;' + '\n' + + 'c = (a + b * (100/a)) % mod;' + '\n' + + '}' + ) + + c_src19 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 100;' + '\n' + + 'int b = 3;' + '\n' + + 'int mod = 1000000007;' + '\n' + + 'int c;' + '\n' + + 'c = ((a % mod + b % mod) % mod' \ + '* (a % mod - b % mod) % mod) % mod;' + '\n' + + '}' + ) + + c_src20 = ( + 'void func()'+ + '{' + '\n' + + 'bool a' + '\n' + + 'bool b;' + '\n' + + 'a = 1 == 2;' + '\n' + + 'b = 1 != 2;' + '\n' + + '}' + ) + + c_src21 = ( + 'void func()'+ + '{' + '\n' + + 'bool a;' + '\n' + + 'bool b;' + '\n' + + 'bool c;' + '\n' + + 'bool d;' + '\n' + + 'a = 1 == 2;' + '\n' + + 'b = 1 <= 2;' + '\n' + + 'c = 1 > 2;' + '\n' + + 'd = 1 >= 2;' + '\n' + + '}' + ) + + c_src22 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 1;' + '\n' + + 'int b = 2;' + '\n' + + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + 'bool c5;' + '\n' + + 'bool c6;' + '\n' + + 'bool c7;' + '\n' + + 'bool c8;' + '\n' + + + 'c1 = a == 1;' + '\n' + + 'c2 = b == 2;' + '\n' + + + 'c3 = 1 != a;' + '\n' + + 'c4 = 1 != b;' + '\n' + + + 'c5 = a < 0;' + '\n' + + 'c6 = b <= 10;' + '\n' + + 'c7 = a > 0;' + '\n' + + 'c8 = b >= 11;' + '\n' + + '}' + ) + + c_src23 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 3;' + '\n' + + 'int b = 4;' + '\n' + + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + 'bool c5;' + '\n' + + 'bool c6;' + '\n' + + + 'c1 = a == b;' + '\n' + + 'c2 = a != b;' + '\n' + + 'c3 = a < b;' + '\n' + + 'c4 = a <= b;' + '\n' + + 'c5 = a > b;' + '\n' + + 'c6 = a >= b;' + '\n' + + '}' + ) + + c_src24 = ( + 'void func()'+ + '{' + '\n' + + 'float a = 1.25' + 'float b = 2.5;' + '\n' + + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + + 'c1 = a == 1.25;' + '\n' + + 'c2 = b == 2.54;' + '\n' + + + 'c3 = 1.2 != a;' + '\n' + + 'c4 = 1.5 != b;' + '\n' + + '}' + ) + + c_src25 = ( + 'void func()'+ + '{' + '\n' + + 'float a = 1.25' + '\n' + + 'float b = 2.5;' + '\n' + + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + 'bool c5;' + '\n' + + 'bool c6;' + '\n' + + + 'c1 = a == b;' + '\n' + + 'c2 = a != b;' + '\n' + + 'c3 = a < b;' + '\n' + + 'c4 = a <= b;' + '\n' + + 'c5 = a > b;' + '\n' + + 'c6 = a >= b;' + '\n' + + '}' + ) + + c_src26 = ( + 'void func()'+ + '{' + '\n' + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + 'bool c5;' + '\n' + + 'bool c6;' + '\n' + + + 'c1 = true == true;' + '\n' + + 'c2 = true == false;' + '\n' + + 'c3 = false == false;' + '\n' + + + 'c4 = true != true;' + '\n' + + 'c5 = true != false;' + '\n' + + 'c6 = false != false;' + '\n' + + '}' + ) + + c_src27 = ( + 'void func()'+ + '{' + '\n' + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + 'bool c5;' + '\n' + + 'bool c6;' + '\n' + + + 'c1 = true && true;' + '\n' + + 'c2 = true && false;' + '\n' + + 'c3 = false && false;' + '\n' + + + 'c4 = true || true;' + '\n' + + 'c5 = true || false;' + '\n' + + 'c6 = false || false;' + '\n' + + '}' + ) + + c_src28 = ( + 'void func()'+ + '{' + '\n' + + 'bool a;' + '\n' + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + + 'c1 = a && true;' + '\n' + + 'c2 = false && a;' + '\n' + + + 'c3 = true || a;' + '\n' + + 'c4 = a || false;' + '\n' + + '}' + ) + + c_src29 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + + 'c1 = a && 1;' + '\n' + + 'c2 = a && 0;' + '\n' + + + 'c3 = a || 1;' + '\n' + + 'c4 = 0 || a;' + '\n' + + '}' + ) + + c_src30 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'int b;' + '\n' + + 'bool c;'+ '\n' + + 'bool d;'+ '\n' + + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + 'bool c5;' + '\n' + + 'bool c6;' + '\n' + + + 'c1 = a && b;' + '\n' + + 'c2 = a && c;' + '\n' + + 'c3 = c && d;' + '\n' + + + 'c4 = a || b;' + '\n' + + 'c5 = a || c;' + '\n' + + 'c6 = c || d;' + '\n' + + '}' + ) + + c_src_raise1 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = -1;' + '\n' + + '}' + ) + + c_src_raise2 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = -+1;' + '\n' + + '}' + ) + + c_src_raise3 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = 2*-2;' + '\n' + + '}' + ) + + c_src_raise4 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = (int)2.0;' + '\n' + + '}' + ) + + c_src_raise5 = ( + 'void func()'+ + '{' + '\n' + + 'int a=100;' + '\n' + + 'a = (a==100)?(1):(0);' + '\n' + + '}' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + res5 = SymPyExpression(c_src5, 'c').return_expr() + res6 = SymPyExpression(c_src6, 'c').return_expr() + res7 = SymPyExpression(c_src7, 'c').return_expr() + res8 = SymPyExpression(c_src8, 'c').return_expr() + res9 = SymPyExpression(c_src9, 'c').return_expr() + res10 = SymPyExpression(c_src10, 'c').return_expr() + res11 = SymPyExpression(c_src11, 'c').return_expr() + res12 = SymPyExpression(c_src12, 'c').return_expr() + res13 = SymPyExpression(c_src13, 'c').return_expr() + res14 = SymPyExpression(c_src14, 'c').return_expr() + res15 = SymPyExpression(c_src15, 'c').return_expr() + res16 = SymPyExpression(c_src16, 'c').return_expr() + res17 = SymPyExpression(c_src17, 'c').return_expr() + res18 = SymPyExpression(c_src18, 'c').return_expr() + res19 = SymPyExpression(c_src19, 'c').return_expr() + res20 = SymPyExpression(c_src20, 'c').return_expr() + res21 = SymPyExpression(c_src21, 'c').return_expr() + res22 = SymPyExpression(c_src22, 'c').return_expr() + res23 = SymPyExpression(c_src23, 'c').return_expr() + res24 = SymPyExpression(c_src24, 'c').return_expr() + res25 = SymPyExpression(c_src25, 'c').return_expr() + res26 = SymPyExpression(c_src26, 'c').return_expr() + res27 = SymPyExpression(c_src27, 'c').return_expr() + res28 = SymPyExpression(c_src28, 'c').return_expr() + res29 = SymPyExpression(c_src29, 'c').return_expr() + res30 = SymPyExpression(c_src30, 'c').return_expr() + + assert res1[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Assignment(Variable(Symbol('a')), Integer(1)) + ) + ) + + assert res2[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(0))), + Assignment( + Variable(Symbol('a')), + Add(Symbol('a'), + Integer(1)) + ), + Assignment(Variable(Symbol('a')), + Add( + Mul( + Integer(3), + Symbol('a')), + Integer(-10) + ) + ) + ) + ) + + assert res3[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(10) + ) + ), + Assignment( + Variable(Symbol('a')), + Add( + Symbol('a'), + Integer(-17) + ) + ) + ) + ) + + assert res4[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('a')), + Integer(100)), + Assignment( + Variable(Symbol('b')), + Add( + Mul( + Integer(2), + Pow( + Symbol('a'), + Integer(2)) + ), + Mul( + Integer(20), + Symbol('a')), + Integer(25) + ) + ) + ) + ) + + assert res5[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('d'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('a')), + Integer(1)), + Assignment( + Variable(Symbol('b')), + Integer(2) + ), + Assignment( + Variable(Symbol('c')), + Symbol('b')), + Assignment( + Variable(Symbol('d')), + Mul( + Add( + Symbol('a'), + Symbol('b')), + Pow( + Add( + Symbol('a'), + Symbol('c') + ), + Integer(2) + ), + Add( + Symbol('c'), + Mul( + Integer(-1), + Symbol('d') + ) + ) + ) + ) + ) + ) + + assert res6[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('d'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('a')), + Integer(1) + ), + Assignment( + Variable(Symbol('b')), + Integer(2) + ), + Assignment( + Variable(Symbol('c')), + Integer(3) + ), + Assignment( + Variable(Symbol('d')), + Add( + Pow( + Symbol('a'), + Integer(4) + ), + Mul( + Integer(3), + Pow( + Symbol('b'), + Integer(2) + ) + ), + Mul( + Integer(2), + Symbol('b') + ), + Mul( + Symbol('c'), + Symbol('d') + ) + ) + ) + ) + ) + + assert res7[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Assignment( + Variable(Symbol('a')), + Float('1.01', precision=53) + ) + ) + ) + + assert res8[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Assignment( + Variable(Symbol('a')), + Float('12.5', precision=53) + ) + ) + ) + + assert res9[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Assignment( + Variable(Symbol('a')), + Float('4.0', precision=53) + ) + ) + ) + + assert res10[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('a')), + Integer(25) + ) + ) + ) + + assert res11[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('a')), + Integer(-95) + ) + ) + ) + + assert res12[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('a')), + Integer(-300) + ) + ) + ) + + assert res13[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('c'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Assignment( + Variable(Symbol('c')), + Mul( + Pow( + Symbol('a'), + Integer(-1) + ), + Symbol('b') + ) + ) + ) + ) + + assert res14[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ), + Declaration( + Variable(Symbol('d'), + type=IntBaseType(String('intc')), + value=Integer(5) + ) + ), + Declaration( + Variable(Symbol('n'), + type=IntBaseType(String('intc')), + value=Integer(10) + ) + ), + Declaration( + Variable(Symbol('s'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('s')), + Mul( + Rational(1, 2), + Symbol('a'), + Add( + Mul( + Integer(2), + Symbol('a') + ), + Mul( + Symbol('d'), + Add( + Symbol('n'), + Integer(-1) + ) + ) + ) + ) + ) + ) + ) + + assert res15[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('a')), + Integer(1) + ) + ) + ) + + assert res16[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('b')), + Mod( + Symbol('a'), + Integer(3) + ) + ) + ) + ) + + assert res17[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ), + Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('c')), + Mod( + Symbol('a'), + Symbol('b') + ) + ) + ) + ) + + assert res18[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ), + Declaration( + Variable(Symbol('mod'), + type=IntBaseType(String('intc')), + value=Integer(1000000007) + ) + ), + Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('c')), + Mod( + Add( + Symbol('a'), + Mul( + Integer(100), + Pow( + Symbol('a'), + Integer(-1) + ), + Symbol('b') + ) + ), + Symbol('mod') + ) + ) + ) + ) + + assert res19[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ), + Declaration( + Variable(Symbol('mod'), + type=IntBaseType(String('intc')), + value=Integer(1000000007) + ) + ), + Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('c')), + Mod( + Mul( + Add( + Mod( + Symbol('a'), + Symbol('mod') + ), + Mul( + Integer(-1), + Mod( + Symbol('b'), + Symbol('mod') + ) + ) + ), + Mod( + Add( + Symbol('a'), + Symbol('b') + ), + Symbol('mod') + ) + ), + Symbol('mod') + ) + ) + ) + ) + + assert res20[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('b'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('a')), + false + ), + Assignment( + Variable(Symbol('b')), + true + ) + ) + ) + + assert res21[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('b'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('d'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('a')), + false + ), + Assignment( + Variable(Symbol('b')), + true + ), + Assignment( + Variable(Symbol('c')), + false + ), + Assignment( + Variable(Symbol('d')), + false + ) + ) + ) + + assert res22[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ), + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c7'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c8'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + Equality( + Symbol('a'), + Integer(1) + ) + ), + Assignment( + Variable(Symbol('c2')), + Equality( + Symbol('b'), + Integer(2) + ) + ), + Assignment( + Variable(Symbol('c3')), + Unequality( + Integer(1), + Symbol('a') + ) + ), + Assignment( + Variable(Symbol('c4')), + Unequality( + Integer(1), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c5')), + StrictLessThan( + Symbol('a'), + Integer(0) + ) + ), + Assignment( + Variable(Symbol('c6')), + LessThan( + Symbol('b'), + Integer(10) + ) + ), + Assignment( + Variable(Symbol('c7')), + StrictGreaterThan( + Symbol('a'), + Integer(0) + ) + ), + Assignment( + Variable(Symbol('c8')), + GreaterThan( + Symbol('b'), + Integer(11) + ) + ) + ) + ) + + assert res23[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(4) + ) + ), + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + Equality( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c2')), + Unequality( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c3')), + StrictLessThan( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c4')), + LessThan( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c5')), + StrictGreaterThan( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c6')), + GreaterThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + ) + + assert res24[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + Equality( + Symbol('a'), + Float('1.25', precision=53) + ) + ), + Assignment( + Variable(Symbol('c3')), + Unequality( + Float('1.2', precision=53), + Symbol('a') + ) + ) + ) + ) + + + assert res25[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.25', precision=53) + ) + ), + Declaration( + Variable(Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.5', precision=53) + ) + ), + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool') + ) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + Equality( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c2')), + Unequality( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c3')), + StrictLessThan( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c4')), + LessThan( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c5')), + StrictGreaterThan( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c6')), + GreaterThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + ) + + assert res26[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), body=CodeBlock( + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + true + ), + Assignment( + Variable(Symbol('c2')), + false + ), + Assignment( + Variable(Symbol('c3')), + true + ), + Assignment( + Variable(Symbol('c4')), + false + ), + Assignment( + Variable(Symbol('c5')), + true + ), + Assignment( + Variable(Symbol('c6')), + false + ) + ) + ) + + assert res27[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + true + ), + Assignment( + Variable(Symbol('c2')), + false + ), + Assignment( + Variable(Symbol('c3')), + false + ), + Assignment( + Variable(Symbol('c4')), + true + ), + Assignment( + Variable(Symbol('c5')), + true + ), + Assignment( + Variable(Symbol('c6')), + false) + ) + ) + + assert res28[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + Symbol('a') + ), + Assignment( + Variable(Symbol('c2')), + false + ), + Assignment( + Variable(Symbol('c3')), + true + ), + Assignment( + Variable(Symbol('c4')), + Symbol('a') + ) + ) + ) + + assert res29[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + Symbol('a') + ), + Assignment( + Variable(Symbol('c2')), + false + ), + Assignment( + Variable(Symbol('c3')), + true + ), + Assignment( + Variable(Symbol('c4')), + Symbol('a') + ) + ) + ) + + assert res30[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('c'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('d'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + And( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c2')), + And( + Symbol('a'), + Symbol('c') + ) + ), + Assignment( + Variable(Symbol('c3')), + And( + Symbol('c'), + Symbol('d') + ) + ), + Assignment( + Variable(Symbol('c4')), + Or( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c5')), + Or( + Symbol('a'), + Symbol('c') + ) + ), + Assignment( + Variable(Symbol('c6')), + Or( + Symbol('c'), + Symbol('d') + ) + ) + ) + ) + + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise1, 'c')) + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise2, 'c')) + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise3, 'c')) + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise4, 'c')) + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise5, 'c')) + + + @XFAIL + def test_var_decl(): + c_src1 = ( + 'int b = 100;' + '\n' + + 'int a = b;' + '\n' + ) + + c_src2 = ( + 'int a = 1;' + '\n' + + 'int b = a + 1;' + '\n' + ) + + c_src3 = ( + 'float a = 10.0 + 2.5;' + '\n' + + 'float b = a * 20.0;' + '\n' + ) + + c_src4 = ( + 'int a = 1 + 100 - 3 * 6;' + '\n' + ) + + c_src5 = ( + 'int a = (((1 + 100) * 12) - 3) * (6 - 10);' + '\n' + ) + + c_src6 = ( + 'int b = 2;' + '\n' + + 'int c = 3;' + '\n' + + 'int a = b + c * 4;' + '\n' + ) + + c_src7 = ( + 'int b = 1;' + '\n' + + 'int c = b + 2;' + '\n' + + 'int a = 10 * b * b * c;' + '\n' + ) + + c_src8 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 1;' + '\n' + + 'int b = 2;' + '\n' + + 'int temp = a;' + '\n' + + 'a = b;' + '\n' + + 'b = temp;' + '\n' + + '}' + ) + + c_src9 = ( + 'int a = 1;' + '\n' + + 'int b = 2;' + '\n' + + 'int c = a;' + '\n' + + 'int d = a + b + c;' + '\n' + + 'int e = a*a*a + 3*a*a*b + 3*a*b*b + b*b*b;' + '\n' + 'int f = (a + b + c) * (a + b - c);' + '\n' + + 'int g = (a + b + c + d)*(a + b + c + d)*(a * (b - c));' + + '\n' + ) + + c_src10 = ( + 'float a = 10.0;' + '\n' + + 'float b = 2.5;' + '\n' + + 'float c = a*a + 2*a*b + b*b;' + '\n' + ) + + c_src11 = ( + 'float a = 10.0 / 2.5;' + '\n' + ) + + c_src12 = ( + 'int a = 100 / 4;' + '\n' + ) + + c_src13 = ( + 'int a = 20 - 100 / 4 * 5 + 10;' + '\n' + ) + + c_src14 = ( + 'int a = (20 - 100) / 4 * (5 + 10);' + '\n' + ) + + c_src15 = ( + 'int a = 4;' + '\n' + + 'int b = 2;' + '\n' + + 'float c = b/a;' + '\n' + ) + + c_src16 = ( + 'int a = 2;' + '\n' + + 'int d = 5;' + '\n' + + 'int n = 10;' + '\n' + + 'int s = (a/2)*(2*a + (n-1)*d);' + '\n' + ) + + c_src17 = ( + 'int a = 1 % 2;' + '\n' + ) + + c_src18 = ( + 'int a = 2;' + '\n' + + 'int b = a % 3;' + '\n' + ) + + c_src19 = ( + 'int a = 100;' + '\n' + + 'int b = 3;' + '\n' + + 'int c = a % b;' + '\n' + ) + + c_src20 = ( + 'int a = 100;' + '\n' + + 'int b = 3;' + '\n' + + 'int mod = 1000000007;' + '\n' + + 'int c = (a + b * (100/a)) % mod;' + '\n' + ) + + c_src21 = ( + 'int a = 100;' + '\n' + + 'int b = 3;' + '\n' + + 'int mod = 1000000007;' + '\n' + + 'int c = ((a % mod + b % mod) % mod *' \ + '(a % mod - b % mod) % mod) % mod;' + '\n' + ) + + c_src22 = ( + 'bool a = 1 == 2, b = 1 != 2;' + ) + + c_src23 = ( + 'bool a = 1 < 2, b = 1 <= 2, c = 1 > 2, d = 1 >= 2;' + ) + + c_src24 = ( + 'int a = 1, b = 2;' + '\n' + + + 'bool c1 = a == 1;' + '\n' + + 'bool c2 = b == 2;' + '\n' + + + 'bool c3 = 1 != a;' + '\n' + + 'bool c4 = 1 != b;' + '\n' + + + 'bool c5 = a < 0;' + '\n' + + 'bool c6 = b <= 10;' + '\n' + + 'bool c7 = a > 0;' + '\n' + + 'bool c8 = b >= 11;' + + ) + + c_src25 = ( + 'int a = 3, b = 4;' + '\n' + + + 'bool c1 = a == b;' + '\n' + + 'bool c2 = a != b;' + '\n' + + 'bool c3 = a < b;' + '\n' + + 'bool c4 = a <= b;' + '\n' + + 'bool c5 = a > b;' + '\n' + + 'bool c6 = a >= b;' + ) + + c_src26 = ( + 'float a = 1.25, b = 2.5;' + '\n' + + + 'bool c1 = a == 1.25;' + '\n' + + 'bool c2 = b == 2.54;' + '\n' + + + 'bool c3 = 1.2 != a;' + '\n' + + 'bool c4 = 1.5 != b;' + ) + + c_src27 = ( + 'float a = 1.25, b = 2.5;' + '\n' + + + 'bool c1 = a == b;' + '\n' + + 'bool c2 = a != b;' + '\n' + + 'bool c3 = a < b;' + '\n' + + 'bool c4 = a <= b;' + '\n' + + 'bool c5 = a > b;' + '\n' + + 'bool c6 = a >= b;' + ) + + c_src28 = ( + 'bool c1 = true == true;' + '\n' + + 'bool c2 = true == false;' + '\n' + + 'bool c3 = false == false;' + '\n' + + + 'bool c4 = true != true;' + '\n' + + 'bool c5 = true != false;' + '\n' + + 'bool c6 = false != false;' + ) + + c_src29 = ( + 'bool c1 = true && true;' + '\n' + + 'bool c2 = true && false;' + '\n' + + 'bool c3 = false && false;' + '\n' + + + 'bool c4 = true || true;' + '\n' + + 'bool c5 = true || false;' + '\n' + + 'bool c6 = false || false;' + ) + + c_src30 = ( + 'bool a = false;' + '\n' + + + 'bool c1 = a && true;' + '\n' + + 'bool c2 = false && a;' + '\n' + + + 'bool c3 = true || a;' + '\n' + + 'bool c4 = a || false;' + ) + + c_src31 = ( + 'int a = 1;' + '\n' + + + 'bool c1 = a && 1;' + '\n' + + 'bool c2 = a && 0;' + '\n' + + + 'bool c3 = a || 1;' + '\n' + + 'bool c4 = 0 || a;' + ) + + c_src32 = ( + 'int a = 1, b = 0;' + '\n' + + 'bool c = false, d = true;'+ '\n' + + + 'bool c1 = a && b;' + '\n' + + 'bool c2 = a && c;' + '\n' + + 'bool c3 = c && d;' + '\n' + + + 'bool c4 = a || b;' + '\n' + + 'bool c5 = a || c;' + '\n' + + 'bool c6 = c || d;' + ) + + c_src_raise1 = ( + "char a = 'b';" + ) + + c_src_raise2 = ( + 'int a[] = {10, 20};' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + res5 = SymPyExpression(c_src5, 'c').return_expr() + res6 = SymPyExpression(c_src6, 'c').return_expr() + res7 = SymPyExpression(c_src7, 'c').return_expr() + res8 = SymPyExpression(c_src8, 'c').return_expr() + res9 = SymPyExpression(c_src9, 'c').return_expr() + res10 = SymPyExpression(c_src10, 'c').return_expr() + res11 = SymPyExpression(c_src11, 'c').return_expr() + res12 = SymPyExpression(c_src12, 'c').return_expr() + res13 = SymPyExpression(c_src13, 'c').return_expr() + res14 = SymPyExpression(c_src14, 'c').return_expr() + res15 = SymPyExpression(c_src15, 'c').return_expr() + res16 = SymPyExpression(c_src16, 'c').return_expr() + res17 = SymPyExpression(c_src17, 'c').return_expr() + res18 = SymPyExpression(c_src18, 'c').return_expr() + res19 = SymPyExpression(c_src19, 'c').return_expr() + res20 = SymPyExpression(c_src20, 'c').return_expr() + res21 = SymPyExpression(c_src21, 'c').return_expr() + res22 = SymPyExpression(c_src22, 'c').return_expr() + res23 = SymPyExpression(c_src23, 'c').return_expr() + res24 = SymPyExpression(c_src24, 'c').return_expr() + res25 = SymPyExpression(c_src25, 'c').return_expr() + res26 = SymPyExpression(c_src26, 'c').return_expr() + res27 = SymPyExpression(c_src27, 'c').return_expr() + res28 = SymPyExpression(c_src28, 'c').return_expr() + res29 = SymPyExpression(c_src29, 'c').return_expr() + res30 = SymPyExpression(c_src30, 'c').return_expr() + res31 = SymPyExpression(c_src31, 'c').return_expr() + res32 = SymPyExpression(c_src32, 'c').return_expr() + + assert res1[0] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ) + + assert res1[1] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Symbol('b') + ) + ) + + assert res2[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res2[1] == Declaration(Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Add( + Symbol('a'), + Integer(1) + ) + ) + ) + + assert res3[0] == Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('12.5', precision=53) + ) + ) + + assert res3[1] == Declaration( + Variable(Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Mul( + Float('20.0', precision=53), + Symbol('a') + ) + ) + ) + + assert res4[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(83) + ) + ) + + assert res5[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(-4836) + ) + ) + + assert res6[0] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res6[1] == Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ) + + assert res6[2] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Add( + Symbol('b'), + Mul( + Integer(4), + Symbol('c') + ) + ) + ) + ) + + assert res7[0] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res7[1] == Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Add( + Symbol('b'), + Integer(2) + ) + ) + ) + + assert res7[2] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Mul( + Integer(10), + Pow( + Symbol('b'), + Integer(2) + ), + Symbol('c') + ) + ) + ) + + assert res8[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ), + Declaration( + Variable(Symbol('temp'), + type=IntBaseType(String('intc')), + value=Symbol('a') + ) + ), + Assignment( + Variable(Symbol('a')), + Symbol('b') + ), + Assignment( + Variable(Symbol('b')), + Symbol('temp') + ) + ) + ) + + assert res9[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res9[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res9[2] == Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Symbol('a') + ) + ) + + assert res9[3] == Declaration( + Variable(Symbol('d'), + type=IntBaseType(String('intc')), + value=Add( + Symbol('a'), + Symbol('b'), + Symbol('c') + ) + ) + ) + + assert res9[4] == Declaration( + Variable(Symbol('e'), + type=IntBaseType(String('intc')), + value=Add( + Pow( + Symbol('a'), + Integer(3) + ), + Mul( + Integer(3), + Pow( + Symbol('a'), + Integer(2) + ), + Symbol('b') + ), + Mul( + Integer(3), + Symbol('a'), + Pow( + Symbol('b'), + Integer(2) + ) + ), + Pow( + Symbol('b'), + Integer(3) + ) + ) + ) + ) + + assert res9[5] == Declaration( + Variable(Symbol('f'), + type=IntBaseType(String('intc')), + value=Mul( + Add( + Symbol('a'), + Symbol('b'), + Mul( + Integer(-1), + Symbol('c') + ) + ), + Add( + Symbol('a'), + Symbol('b'), + Symbol('c') + ) + ) + ) + ) + + assert res9[6] == Declaration( + Variable(Symbol('g'), + type=IntBaseType(String('intc')), + value=Mul( + Symbol('a'), + Add( + Symbol('b'), + Mul( + Integer(-1), + Symbol('c') + ) + ), + Pow( + Add( + Symbol('a'), + Symbol('b'), + Symbol('c'), + Symbol('d') + ), + Integer(2) + ) + ) + ) + ) + + assert res10[0] == Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('10.0', precision=53) + ) + ) + + assert res10[1] == Declaration( + Variable(Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.5', precision=53) + ) + ) + + assert res10[2] == Declaration( + Variable(Symbol('c'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Add( + Pow( + Symbol('a'), + Integer(2) + ), + Mul( + Integer(2), + Symbol('a'), + Symbol('b') + ), + Pow( + Symbol('b'), + Integer(2) + ) + ) + ) + ) + + assert res11[0] == Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('4.0', precision=53) + ) + ) + + assert res12[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(25) + ) + ) + + assert res13[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(-95) + ) + ) + + assert res14[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(-300) + ) + ) + + assert res15[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(4) + ) + ) + + assert res15[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res15[2] == Declaration( + Variable(Symbol('c'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Mul( + Pow( + Symbol('a'), + Integer(-1) + ), + Symbol('b') + ) + ) + ) + + assert res16[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res16[1] == Declaration( + Variable(Symbol('d'), + type=IntBaseType(String('intc')), + value=Integer(5) + ) + ) + + assert res16[2] == Declaration( + Variable(Symbol('n'), + type=IntBaseType(String('intc')), + value=Integer(10) + ) + ) + + assert res16[3] == Declaration( + Variable(Symbol('s'), + type=IntBaseType(String('intc')), + value=Mul( + Rational(1, 2), + Symbol('a'), + Add( + Mul( + Integer(2), + Symbol('a') + ), + Mul( + Symbol('d'), + Add( + Symbol('n'), + Integer(-1) + ) + ) + ) + ) + ) + ) + + assert res17[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res18[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res18[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Mod( + Symbol('a'), + Integer(3) + ) + ) + ) + + assert res19[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ) + assert res19[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ) + + assert res19[2] == Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Mod( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res20[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ) + + assert res20[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ) + + assert res20[2] == Declaration( + Variable(Symbol('mod'), + type=IntBaseType(String('intc')), + value=Integer(1000000007) + ) + ) + + assert res20[3] == Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Mod( + Add( + Symbol('a'), + Mul( + Integer(100), + Pow( + Symbol('a'), + Integer(-1) + ), + Symbol('b') + ) + ), + Symbol('mod') + ) + ) + ) + + assert res21[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ) + + assert res21[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ) + + assert res21[2] == Declaration( + Variable(Symbol('mod'), + type=IntBaseType(String('intc')), + value=Integer(1000000007) + ) + ) + + assert res21[3] == Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Mod( + Mul( + Add( + Symbol('a'), + Mul( + Integer(-1), + Symbol('b') + ) + ), + Add( + Symbol('a'), + Symbol('b') + ) + ), + Symbol('mod') + ) + ) + ) + + assert res22[0] == Declaration( + Variable(Symbol('a'), + type=Type(String('bool')), + value=false + ) + ) + + assert res22[1] == Declaration( + Variable(Symbol('b'), + type=Type(String('bool')), + value=true + ) + ) + + assert res23[0] == Declaration( + Variable(Symbol('a'), + type=Type(String('bool')), + value=true + ) + ) + + assert res23[1] == Declaration( + Variable(Symbol('b'), + type=Type(String('bool')), + value=true + ) + ) + + assert res23[2] == Declaration( + Variable(Symbol('c'), + type=Type(String('bool')), + value=false + ) + ) + + assert res23[3] == Declaration( + Variable(Symbol('d'), + type=Type(String('bool')), + value=false + ) + ) + + assert res24[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res24[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res24[2] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=Equality( + Symbol('a'), + Integer(1) + ) + ) + ) + + assert res24[3] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=Equality( + Symbol('b'), + Integer(2) + ) + ) + ) + + assert res24[4] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=Unequality( + Integer(1), + Symbol('a') + ) + ) + ) + + assert res24[5] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=Unequality( + Integer(1), + Symbol('b') + ) + ) + ) + + assert res24[6] == Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')), + value=StrictLessThan(Symbol('a'), + Integer(0) + ) + ) + ) + + assert res24[7] == Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')), + value=LessThan( + Symbol('b'), + Integer(10) + ) + ) + ) + + assert res24[8] == Declaration( + Variable(Symbol('c7'), + type=Type(String('bool')), + value=StrictGreaterThan( + Symbol('a'), + Integer(0) + ) + ) + ) + + assert res24[9] == Declaration( + Variable(Symbol('c8'), + type=Type(String('bool')), + value=GreaterThan( + Symbol('b'), + Integer(11) + ) + ) + ) + + assert res25[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ) + + assert res25[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(4) + ) + ) + + assert res25[2] == Declaration(Variable(Symbol('c1'), + type=Type(String('bool')), + value=Equality( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res25[3] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=Unequality( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res25[4] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=StrictLessThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res25[5] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=LessThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res25[6] == Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')), + value=StrictGreaterThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res25[7] == Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')), + value=GreaterThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res26[0] == Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.25', precision=53) + ) + ) + + assert res26[1] == Declaration( + Variable(Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.5', precision=53) + ) + ) + + assert res26[2] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=Equality( + Symbol('a'), + Float('1.25', precision=53) + ) + ) + ) + + assert res26[3] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=Equality( + Symbol('b'), + Float('2.54', precision=53) + ) + ) + ) + + assert res26[4] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=Unequality( + Float('1.2', precision=53), + Symbol('a') + ) + ) + ) + + assert res26[5] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=Unequality( + Float('1.5', precision=53), + Symbol('b') + ) + ) + ) + + assert res27[0] == Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.25', precision=53) + ) + ) + + assert res27[1] == Declaration( + Variable(Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.5', precision=53) + ) + ) + + assert res27[2] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=Equality( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res27[3] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=Unequality( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res27[4] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=StrictLessThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res27[5] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=LessThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res27[6] == Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')), + value=StrictGreaterThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res27[7] == Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')), + value=GreaterThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res28[0] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=true + ) + ) + + assert res28[1] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=false + ) + ) + + assert res28[2] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=true + ) + ) + + assert res28[3] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=false + ) + ) + + assert res28[4] == Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')), + value=true + ) + ) + + assert res28[5] == Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')), + value=false + ) + ) + + assert res29[0] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=true + ) + ) + + assert res29[1] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=false + ) + ) + + assert res29[2] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=false + ) + ) + + assert res29[3] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=true + ) + ) + + assert res29[4] == Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')), + value=true + ) + ) + + assert res29[5] == Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')), + value=false + ) + ) + + assert res30[0] == Declaration( + Variable(Symbol('a'), + type=Type(String('bool')), + value=false + ) + ) + + assert res30[1] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=Symbol('a') + ) + ) + + assert res30[2] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=false + ) + ) + + assert res30[3] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=true + ) + ) + + assert res30[4] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=Symbol('a') + ) + ) + + assert res31[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res31[1] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=Symbol('a') + ) + ) + + assert res31[2] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=false + ) + ) + + assert res31[3] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=true + ) + ) + + assert res31[4] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=Symbol('a') + ) + ) + + assert res32[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res32[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(0) + ) + ) + + assert res32[2] == Declaration( + Variable(Symbol('c'), + type=Type(String('bool')), + value=false + ) + ) + + assert res32[3] == Declaration( + Variable(Symbol('d'), + type=Type(String('bool')), + value=true + ) + ) + + assert res32[4] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=And( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res32[5] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=And( + Symbol('a'), + Symbol('c') + ) + ) + ) + + assert res32[6] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=And( + Symbol('c'), + Symbol('d') + ) + ) + ) + + assert res32[7] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=Or( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res32[8] == Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')), + value=Or( + Symbol('a'), + Symbol('c') + ) + ) + ) + + assert res32[9] == Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')), + value=Or( + Symbol('c'), + Symbol('d') + ) + ) + ) + + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise1, 'c')) + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise2, 'c')) + + + def test_paren_expr(): + c_src1 = ( + 'int a = (1);' + 'int b = (1 + 2 * 3);' + ) + + c_src2 = ( + 'int a = 1, b = 2, c = 3;' + 'int d = (a);' + 'int e = (a + 1);' + 'int f = (a + b * c - d / e);' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + + assert res1[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res1[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(7) + ) + ) + + assert res2[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res2[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res2[2] == Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ) + + assert res2[3] == Declaration( + Variable(Symbol('d'), + type=IntBaseType(String('intc')), + value=Symbol('a') + ) + ) + + assert res2[4] == Declaration( + Variable(Symbol('e'), + type=IntBaseType(String('intc')), + value=Add( + Symbol('a'), + Integer(1) + ) + ) + ) + + assert res2[5] == Declaration( + Variable(Symbol('f'), + type=IntBaseType(String('intc')), + value=Add( + Symbol('a'), + Mul( + Symbol('b'), + Symbol('c') + ), + Mul( + Integer(-1), + Symbol('d'), + Pow( + Symbol('e'), + Integer(-1) + ) + ) + ) + ) + ) + + + def test_unary_operators(): + c_src1 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 10;' + '\n' + + 'int b = 20;' + '\n' + + '++a;' + '\n' + + '--b;' + '\n' + + 'a++;' + '\n' + + 'b--;' + '\n' + + '}' + ) + + c_src2 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 10;' + '\n' + + 'int b = -100;' + '\n' + + 'int c = +19;' + '\n' + + 'int d = ++a;' + '\n' + + 'int e = --b;' + '\n' + + 'int f = a++;' + '\n' + + 'int g = b--;' + '\n' + + 'bool h = !false;' + '\n' + + 'bool i = !d;' + '\n' + + 'bool j = !0;' + '\n' + + 'bool k = !10.0;' + '\n' + + '}' + ) + + c_src_raise1 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 10;' + '\n' + + 'int b = ~a;' + '\n' + + '}' + ) + + c_src_raise2 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 10;' + '\n' + + 'int b = *&a;' + '\n' + + '}' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + + assert res1[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(10) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(20) + ) + ), + PreIncrement(Symbol('a')), + PreDecrement(Symbol('b')), + PostIncrement(Symbol('a')), + PostDecrement(Symbol('b')) + ) + ) + + assert res2[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(10) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(-100) + ) + ), + Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Integer(19) + ) + ), + Declaration( + Variable(Symbol('d'), + type=IntBaseType(String('intc')), + value=PreIncrement(Symbol('a')) + ) + ), + Declaration( + Variable(Symbol('e'), + type=IntBaseType(String('intc')), + value=PreDecrement(Symbol('b')) + ) + ), + Declaration( + Variable(Symbol('f'), + type=IntBaseType(String('intc')), + value=PostIncrement(Symbol('a')) + ) + ), + Declaration( + Variable(Symbol('g'), + type=IntBaseType(String('intc')), + value=PostDecrement(Symbol('b')) + ) + ), + Declaration( + Variable(Symbol('h'), + type=Type(String('bool')), + value=true + ) + ), + Declaration( + Variable(Symbol('i'), + type=Type(String('bool')), + value=Not(Symbol('d')) + ) + ), + Declaration( + Variable(Symbol('j'), + type=Type(String('bool')), + value=true + ) + ), + Declaration( + Variable(Symbol('k'), + type=Type(String('bool')), + value=false + ) + ) + ) + ) + + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise1, 'c')) + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise2, 'c')) + + + def test_compound_assignment_operator(): + c_src = ( + 'void func()'+ + '{' + '\n' + + 'int a = 100;' + '\n' + + 'a += 10;' + '\n' + + 'a -= 10;' + '\n' + + 'a *= 10;' + '\n' + + 'a /= 10;' + '\n' + + 'a %= 10;' + '\n' + + '}' + ) + + res = SymPyExpression(c_src, 'c').return_expr() + + assert res[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ), + AddAugmentedAssignment( + Variable(Symbol('a')), + Integer(10) + ), + SubAugmentedAssignment( + Variable(Symbol('a')), + Integer(10) + ), + MulAugmentedAssignment( + Variable(Symbol('a')), + Integer(10) + ), + DivAugmentedAssignment( + Variable(Symbol('a')), + Integer(10) + ), + ModAugmentedAssignment( + Variable(Symbol('a')), + Integer(10) + ) + ) + ) + + @XFAIL # this is expected to fail because of a bug in the C parser. + def test_while_stmt(): + c_src1 = ( + 'void func()'+ + '{' + '\n' + + 'int i = 0;' + '\n' + + 'while(i < 10)' + '\n' + + '{' + '\n' + + 'i++;' + '\n' + + '}' + '}' + ) + + c_src2 = ( + 'void func()'+ + '{' + '\n' + + 'int i = 0;' + '\n' + + 'while(i < 10)' + '\n' + + 'i++;' + '\n' + + '}' + ) + + c_src3 = ( + 'void func()'+ + '{' + '\n' + + 'int i = 10;' + '\n' + + 'int cnt = 0;' + '\n' + + 'while(i > 0)' + '\n' + + '{' + '\n' + + 'i--;' + '\n' + + 'cnt++;' + '\n' + + '}' + '\n' + + '}' + ) + + c_src4 = ( + 'int digit_sum(int n)'+ + '{' + '\n' + + 'int sum = 0;' + '\n' + + 'while(n > 0)' + '\n' + + '{' + '\n' + + 'sum += (n % 10);' + '\n' + + 'n /= 10;' + '\n' + + '}' + '\n' + + 'return sum;' + '\n' + + '}' + ) + + c_src5 = ( + 'void func()'+ + '{' + '\n' + + 'while(1);' + '\n' + + '}' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + res5 = SymPyExpression(c_src5, 'c').return_expr() + + assert res1[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('i'), + type=IntBaseType(String('intc')), + value=Integer(0) + ) + ), + While( + StrictLessThan( + Symbol('i'), + Integer(10) + ), + body=CodeBlock( + PostIncrement( + Symbol('i') + ) + ) + ) + ) + ) + + assert res2[0] == res1[0] + + assert res3[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable( + Symbol('i'), + type=IntBaseType(String('intc')), + value=Integer(10) + ) + ), + Declaration( + Variable( + Symbol('cnt'), + type=IntBaseType(String('intc')), + value=Integer(0) + ) + ), + While( + StrictGreaterThan( + Symbol('i'), + Integer(0) + ), + body=CodeBlock( + PostDecrement( + Symbol('i') + ), + PostIncrement( + Symbol('cnt') + ) + ) + ) + ) + ) + + assert res4[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('digit_sum'), + parameters=( + Variable( + Symbol('n'), + type=IntBaseType(String('intc')) + ), + ), + body=CodeBlock( + Declaration( + Variable( + Symbol('sum'), + type=IntBaseType(String('intc')), + value=Integer(0) + ) + ), + While( + StrictGreaterThan( + Symbol('n'), + Integer(0) + ), + body=CodeBlock( + AddAugmentedAssignment( + Variable( + Symbol('sum') + ), + Mod( + Symbol('n'), + Integer(10) + ) + ), + DivAugmentedAssignment( + Variable( + Symbol('n') + ), + Integer(10) + ) + ) + ), + Return('sum') + ) + ) + + assert res5[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + While( + Integer(1), + body=CodeBlock( + NoneToken() + ) + ) + ) + ) + + +else: + def test_raise(): + from sympy.parsing.c.c_parser import CCodeConverter + raises(ImportError, lambda: CCodeConverter()) + raises(ImportError, lambda: SymPyExpression(' ', mode = 'c')) diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_custom_latex.py b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_custom_latex.py new file mode 100644 index 0000000000000000000000000000000000000000..f5eff1c9ec79528c7f9e3a06cf9e2f84c86091ee --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_custom_latex.py @@ -0,0 +1,69 @@ +import os +import tempfile +from pathlib import Path + +import sympy +from sympy.testing.pytest import raises +from sympy.parsing.latex.lark import LarkLaTeXParser, TransformToSymPyExpr, parse_latex_lark +from sympy.external import import_module + +lark = import_module("lark") + +# disable tests if lark is not present +disabled = lark is None + +grammar_file = os.path.join(os.path.dirname(__file__), "../latex/lark/grammar/latex.lark") + +modification1 = """ +%override DIV_SYMBOL: DIV +%override MUL_SYMBOL: MUL | CMD_TIMES +""" + +modification2 = r""" +%override number: /\d+(,\d*)?/ +""" + +def init_custom_parser(modification, transformer=None): + latex_grammar = Path(grammar_file).read_text(encoding="utf-8") + latex_grammar += modification + + with tempfile.NamedTemporaryFile() as f: + f.write(bytes(latex_grammar, encoding="utf8")) + f.flush() + + parser = LarkLaTeXParser(grammar_file=f.name, transformer=transformer) + + return parser + +def test_custom1(): + # Removes the parser's ability to understand \cdot and \div. + + parser = init_custom_parser(modification1) + + with raises(lark.exceptions.UnexpectedCharacters): + parser.doparse(r"a \cdot b") + parser.doparse(r"x \div y") + +class CustomTransformer(TransformToSymPyExpr): + def number(self, tokens): + if "," in tokens[0]: + # The Float constructor expects a dot as the decimal separator + return sympy.core.numbers.Float(tokens[0].replace(",", ".")) + else: + return sympy.core.numbers.Integer(tokens[0]) + +def test_custom2(): + # Makes the parser parse commas as the decimal separator instead of dots + + parser = init_custom_parser(modification2, CustomTransformer) + + with raises(lark.exceptions.UnexpectedCharacters): + # Asserting that the default parser cannot parse numbers which have commas as + # the decimal separator + parse_latex_lark("100,1") + parse_latex_lark("0,009") + + parser.doparse("100,1") + parser.doparse("0,009") + parser.doparse("2,71828") + parser.doparse("3,14159") diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_fortran_parser.py b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_fortran_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..9bcd54533ef231dd0a116910453dff0e993bc727 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_fortran_parser.py @@ -0,0 +1,406 @@ +from sympy.testing.pytest import raises +from sympy.parsing.sym_expr import SymPyExpression +from sympy.external import import_module + +lfortran = import_module('lfortran') + +if lfortran: + from sympy.codegen.ast import (Variable, IntBaseType, FloatBaseType, String, + Return, FunctionDefinition, Assignment, + Declaration, CodeBlock) + from sympy.core import Integer, Float, Add + from sympy.core.symbol import Symbol + + + expr1 = SymPyExpression() + expr2 = SymPyExpression() + src = """\ + integer :: a, b, c, d + real :: p, q, r, s + """ + + + def test_sym_expr(): + src1 = ( + src + + """\ + d = a + b -c + """ + ) + expr3 = SymPyExpression(src,'f') + expr4 = SymPyExpression(src1,'f') + ls1 = expr3.return_expr() + ls2 = expr4.return_expr() + for i in range(0, 7): + assert isinstance(ls1[i], Declaration) + assert isinstance(ls2[i], Declaration) + assert isinstance(ls2[8], Assignment) + assert ls1[0] == Declaration( + Variable( + Symbol('a'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls1[1] == Declaration( + Variable( + Symbol('b'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls1[2] == Declaration( + Variable( + Symbol('c'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls1[3] == Declaration( + Variable( + Symbol('d'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls1[4] == Declaration( + Variable( + Symbol('p'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + assert ls1[5] == Declaration( + Variable( + Symbol('q'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + assert ls1[6] == Declaration( + Variable( + Symbol('r'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + assert ls1[7] == Declaration( + Variable( + Symbol('s'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + assert ls2[8] == Assignment( + Variable(Symbol('d')), + Symbol('a') + Symbol('b') - Symbol('c') + ) + + def test_assignment(): + src1 = ( + src + + """\ + a = b + c = d + p = q + r = s + """ + ) + expr1.convert_to_expr(src1, 'f') + ls1 = expr1.return_expr() + for iter in range(0, 12): + if iter < 8: + assert isinstance(ls1[iter], Declaration) + else: + assert isinstance(ls1[iter], Assignment) + assert ls1[8] == Assignment( + Variable(Symbol('a')), + Variable(Symbol('b')) + ) + assert ls1[9] == Assignment( + Variable(Symbol('c')), + Variable(Symbol('d')) + ) + assert ls1[10] == Assignment( + Variable(Symbol('p')), + Variable(Symbol('q')) + ) + assert ls1[11] == Assignment( + Variable(Symbol('r')), + Variable(Symbol('s')) + ) + + + def test_binop_add(): + src1 = ( + src + + """\ + c = a + b + d = a + c + s = p + q + r + """ + ) + expr1.convert_to_expr(src1, 'f') + ls1 = expr1.return_expr() + for iter in range(8, 11): + assert isinstance(ls1[iter], Assignment) + assert ls1[8] == Assignment( + Variable(Symbol('c')), + Symbol('a') + Symbol('b') + ) + assert ls1[9] == Assignment( + Variable(Symbol('d')), + Symbol('a') + Symbol('c') + ) + assert ls1[10] == Assignment( + Variable(Symbol('s')), + Symbol('p') + Symbol('q') + Symbol('r') + ) + + + def test_binop_sub(): + src1 = ( + src + + """\ + c = a - b + d = a - c + s = p - q - r + """ + ) + expr1.convert_to_expr(src1, 'f') + ls1 = expr1.return_expr() + for iter in range(8, 11): + assert isinstance(ls1[iter], Assignment) + assert ls1[8] == Assignment( + Variable(Symbol('c')), + Symbol('a') - Symbol('b') + ) + assert ls1[9] == Assignment( + Variable(Symbol('d')), + Symbol('a') - Symbol('c') + ) + assert ls1[10] == Assignment( + Variable(Symbol('s')), + Symbol('p') - Symbol('q') - Symbol('r') + ) + + + def test_binop_mul(): + src1 = ( + src + + """\ + c = a * b + d = a * c + s = p * q * r + """ + ) + expr1.convert_to_expr(src1, 'f') + ls1 = expr1.return_expr() + for iter in range(8, 11): + assert isinstance(ls1[iter], Assignment) + assert ls1[8] == Assignment( + Variable(Symbol('c')), + Symbol('a') * Symbol('b') + ) + assert ls1[9] == Assignment( + Variable(Symbol('d')), + Symbol('a') * Symbol('c') + ) + assert ls1[10] == Assignment( + Variable(Symbol('s')), + Symbol('p') * Symbol('q') * Symbol('r') + ) + + + def test_binop_div(): + src1 = ( + src + + """\ + c = a / b + d = a / c + s = p / q + r = q / p + """ + ) + expr1.convert_to_expr(src1, 'f') + ls1 = expr1.return_expr() + for iter in range(8, 12): + assert isinstance(ls1[iter], Assignment) + assert ls1[8] == Assignment( + Variable(Symbol('c')), + Symbol('a') / Symbol('b') + ) + assert ls1[9] == Assignment( + Variable(Symbol('d')), + Symbol('a') / Symbol('c') + ) + assert ls1[10] == Assignment( + Variable(Symbol('s')), + Symbol('p') / Symbol('q') + ) + assert ls1[11] == Assignment( + Variable(Symbol('r')), + Symbol('q') / Symbol('p') + ) + + def test_mul_binop(): + src1 = ( + src + + """\ + d = a + b - c + c = a * b + d + s = p * q / r + r = p * s + q / p + """ + ) + expr1.convert_to_expr(src1, 'f') + ls1 = expr1.return_expr() + for iter in range(8, 12): + assert isinstance(ls1[iter], Assignment) + assert ls1[8] == Assignment( + Variable(Symbol('d')), + Symbol('a') + Symbol('b') - Symbol('c') + ) + assert ls1[9] == Assignment( + Variable(Symbol('c')), + Symbol('a') * Symbol('b') + Symbol('d') + ) + assert ls1[10] == Assignment( + Variable(Symbol('s')), + Symbol('p') * Symbol('q') / Symbol('r') + ) + assert ls1[11] == Assignment( + Variable(Symbol('r')), + Symbol('p') * Symbol('s') + Symbol('q') / Symbol('p') + ) + + + def test_function(): + src1 = """\ + integer function f(a,b) + integer :: x, y + f = x + y + end function + """ + expr1.convert_to_expr(src1, 'f') + for iter in expr1.return_expr(): + assert isinstance(iter, FunctionDefinition) + assert iter == FunctionDefinition( + IntBaseType(String('integer')), + name=String('f'), + parameters=( + Variable(Symbol('a')), + Variable(Symbol('b')) + ), + body=CodeBlock( + Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ), + Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ), + Declaration( + Variable( + Symbol('f'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ), + Declaration( + Variable( + Symbol('x'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ), + Declaration( + Variable( + Symbol('y'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ), + Assignment( + Variable(Symbol('f')), + Add(Symbol('x'), Symbol('y')) + ), + Return(Variable(Symbol('f'))) + ) + ) + + + def test_var(): + expr1.convert_to_expr(src, 'f') + ls = expr1.return_expr() + for iter in expr1.return_expr(): + assert isinstance(iter, Declaration) + assert ls[0] == Declaration( + Variable( + Symbol('a'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls[1] == Declaration( + Variable( + Symbol('b'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls[2] == Declaration( + Variable( + Symbol('c'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls[3] == Declaration( + Variable( + Symbol('d'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls[4] == Declaration( + Variable( + Symbol('p'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + assert ls[5] == Declaration( + Variable( + Symbol('q'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + assert ls[6] == Declaration( + Variable( + Symbol('r'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + assert ls[7] == Declaration( + Variable( + Symbol('s'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + +else: + def test_raise(): + from sympy.parsing.fortran.fortran_parser import ASR2PyVisitor + raises(ImportError, lambda: ASR2PyVisitor()) + raises(ImportError, lambda: SymPyExpression(' ', mode = 'f')) diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_implicit_multiplication_application.py b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_implicit_multiplication_application.py new file mode 100644 index 0000000000000000000000000000000000000000..56df361e77b0c0f94bdb53b03e0dc30a8a10899f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_implicit_multiplication_application.py @@ -0,0 +1,195 @@ +import sympy +from sympy.parsing.sympy_parser import ( + parse_expr, + standard_transformations, + convert_xor, + implicit_multiplication_application, + implicit_multiplication, + implicit_application, + function_exponentiation, + split_symbols, + split_symbols_custom, + _token_splittable +) +from sympy.testing.pytest import raises + + +def test_implicit_multiplication(): + cases = { + '5x': '5*x', + 'abc': 'a*b*c', + '3sin(x)': '3*sin(x)', + '(x+1)(x+2)': '(x+1)*(x+2)', + '(5 x**2)sin(x)': '(5*x**2)*sin(x)', + '2 sin(x) cos(x)': '2*sin(x)*cos(x)', + 'pi x': 'pi*x', + 'x pi': 'x*pi', + 'E x': 'E*x', + 'EulerGamma y': 'EulerGamma*y', + 'E pi': 'E*pi', + 'pi (x + 2)': 'pi*(x+2)', + '(x + 2) pi': '(x+2)*pi', + 'pi sin(x)': 'pi*sin(x)', + } + transformations = standard_transformations + (convert_xor,) + transformations2 = transformations + (split_symbols, + implicit_multiplication) + for case in cases: + implicit = parse_expr(case, transformations=transformations2) + normal = parse_expr(cases[case], transformations=transformations) + assert(implicit == normal) + + application = ['sin x', 'cos 2*x', 'sin cos x'] + for case in application: + raises(SyntaxError, + lambda: parse_expr(case, transformations=transformations2)) + raises(TypeError, + lambda: parse_expr('sin**2(x)', transformations=transformations2)) + + +def test_implicit_application(): + cases = { + 'factorial': 'factorial', + 'sin x': 'sin(x)', + 'tan y**3': 'tan(y**3)', + 'cos 2*x': 'cos(2*x)', + '(cot)': 'cot', + 'sin cos tan x': 'sin(cos(tan(x)))' + } + transformations = standard_transformations + (convert_xor,) + transformations2 = transformations + (implicit_application,) + for case in cases: + implicit = parse_expr(case, transformations=transformations2) + normal = parse_expr(cases[case], transformations=transformations) + assert(implicit == normal), (implicit, normal) + + multiplication = ['x y', 'x sin x', '2x'] + for case in multiplication: + raises(SyntaxError, + lambda: parse_expr(case, transformations=transformations2)) + raises(TypeError, + lambda: parse_expr('sin**2(x)', transformations=transformations2)) + + +def test_function_exponentiation(): + cases = { + 'sin**2(x)': 'sin(x)**2', + 'exp^y(z)': 'exp(z)^y', + 'sin**2(E^(x))': 'sin(E^(x))**2' + } + transformations = standard_transformations + (convert_xor,) + transformations2 = transformations + (function_exponentiation,) + for case in cases: + implicit = parse_expr(case, transformations=transformations2) + normal = parse_expr(cases[case], transformations=transformations) + assert(implicit == normal) + + other_implicit = ['x y', 'x sin x', '2x', 'sin x', + 'cos 2*x', 'sin cos x'] + for case in other_implicit: + raises(SyntaxError, + lambda: parse_expr(case, transformations=transformations2)) + + assert parse_expr('x**2', local_dict={ 'x': sympy.Symbol('x') }, + transformations=transformations2) == parse_expr('x**2') + + +def test_symbol_splitting(): + # By default Greek letter names should not be split (lambda is a keyword + # so skip it) + transformations = standard_transformations + (split_symbols,) + greek_letters = ('alpha', 'beta', 'gamma', 'delta', 'epsilon', 'zeta', + 'eta', 'theta', 'iota', 'kappa', 'mu', 'nu', 'xi', + 'omicron', 'pi', 'rho', 'sigma', 'tau', 'upsilon', + 'phi', 'chi', 'psi', 'omega') + + for letter in greek_letters: + assert(parse_expr(letter, transformations=transformations) == + parse_expr(letter)) + + # Make sure symbol splitting resolves names + transformations += (implicit_multiplication,) + local_dict = { 'e': sympy.E } + cases = { + 'xe': 'E*x', + 'Iy': 'I*y', + 'ee': 'E*E', + } + for case, expected in cases.items(): + assert(parse_expr(case, local_dict=local_dict, + transformations=transformations) == + parse_expr(expected)) + + # Make sure custom splitting works + def can_split(symbol): + if symbol not in ('unsplittable', 'names'): + return _token_splittable(symbol) + return False + transformations = standard_transformations + transformations += (split_symbols_custom(can_split), + implicit_multiplication) + + assert(parse_expr('unsplittable', transformations=transformations) == + parse_expr('unsplittable')) + assert(parse_expr('names', transformations=transformations) == + parse_expr('names')) + assert(parse_expr('xy', transformations=transformations) == + parse_expr('x*y')) + for letter in greek_letters: + assert(parse_expr(letter, transformations=transformations) == + parse_expr(letter)) + + +def test_all_implicit_steps(): + cases = { + '2x': '2*x', # implicit multiplication + 'x y': 'x*y', + 'xy': 'x*y', + 'sin x': 'sin(x)', # add parentheses + '2sin x': '2*sin(x)', + 'x y z': 'x*y*z', + 'sin(2 * 3x)': 'sin(2 * 3 * x)', + 'sin(x) (1 + cos(x))': 'sin(x) * (1 + cos(x))', + '(x + 2) sin(x)': '(x + 2) * sin(x)', + '(x + 2) sin x': '(x + 2) * sin(x)', + 'sin(sin x)': 'sin(sin(x))', + 'sin x!': 'sin(factorial(x))', + 'sin x!!': 'sin(factorial2(x))', + 'factorial': 'factorial', # don't apply a bare function + 'x sin x': 'x * sin(x)', # both application and multiplication + 'xy sin x': 'x * y * sin(x)', + '(x+2)(x+3)': '(x + 2) * (x+3)', + 'x**2 + 2xy + y**2': 'x**2 + 2 * x * y + y**2', # split the xy + 'pi': 'pi', # don't mess with constants + 'None': 'None', + 'ln sin x': 'ln(sin(x))', # multiple implicit function applications + 'sin x**2': 'sin(x**2)', # implicit application to an exponential + 'alpha': 'Symbol("alpha")', # don't split Greek letters/subscripts + 'x_2': 'Symbol("x_2")', + 'sin^2 x**2': 'sin(x**2)**2', # function raised to a power + 'sin**3(x)': 'sin(x)**3', + '(factorial)': 'factorial', + 'tan 3x': 'tan(3*x)', + 'sin^2(3*E^(x))': 'sin(3*E**(x))**2', + 'sin**2(E^(3x))': 'sin(E**(3*x))**2', + 'sin^2 (3x*E^(x))': 'sin(3*x*E^x)**2', + 'pi sin x': 'pi*sin(x)', + } + transformations = standard_transformations + (convert_xor,) + transformations2 = transformations + (implicit_multiplication_application,) + for case in cases: + implicit = parse_expr(case, transformations=transformations2) + normal = parse_expr(cases[case], transformations=transformations) + assert(implicit == normal) + + +def test_no_methods_implicit_multiplication(): + # Issue 21020 + u = sympy.Symbol('u') + transformations = standard_transformations + \ + (implicit_multiplication,) + expr = parse_expr('x.is_polynomial(x)', transformations=transformations) + assert expr == True + expr = parse_expr('(exp(x) / (1 + exp(2x))).subs(exp(x), u)', + transformations=transformations) + assert expr == u/(u**2 + 1) diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_latex.py b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_latex.py new file mode 100644 index 0000000000000000000000000000000000000000..49a48966eacaa1cd7a242dcd0e7699c992bb1268 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_latex.py @@ -0,0 +1,358 @@ +from sympy.testing.pytest import raises, XFAIL +from sympy.external import import_module + +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.function import (Derivative, Function) +from sympy.core.mul import Mul +from sympy.core.numbers import (E, oo) +from sympy.core.power import Pow +from sympy.core.relational import (GreaterThan, LessThan, StrictGreaterThan, StrictLessThan, Unequality) +from sympy.core.symbol import Symbol +from sympy.functions.combinatorial.factorials import (binomial, factorial) +from sympy.functions.elementary.complexes import (Abs, conjugate) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.integers import (ceiling, floor) +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import (asin, cos, csc, sec, sin, tan) +from sympy.integrals.integrals import Integral +from sympy.series.limits import Limit + +from sympy.core.relational import Eq, Ne, Lt, Le, Gt, Ge +from sympy.physics.quantum.state import Bra, Ket +from sympy.abc import x, y, z, a, b, c, t, k, n +antlr4 = import_module("antlr4") + +# disable tests if antlr4-python3-runtime is not present +disabled = antlr4 is None + +theta = Symbol('theta') +f = Function('f') + + +# shorthand definitions +def _Add(a, b): + return Add(a, b, evaluate=False) + + +def _Mul(a, b): + return Mul(a, b, evaluate=False) + + +def _Pow(a, b): + return Pow(a, b, evaluate=False) + + +def _Sqrt(a): + return sqrt(a, evaluate=False) + + +def _Conjugate(a): + return conjugate(a, evaluate=False) + + +def _Abs(a): + return Abs(a, evaluate=False) + + +def _factorial(a): + return factorial(a, evaluate=False) + + +def _exp(a): + return exp(a, evaluate=False) + + +def _log(a, b): + return log(a, b, evaluate=False) + + +def _binomial(n, k): + return binomial(n, k, evaluate=False) + + +def test_import(): + from sympy.parsing.latex._build_latex_antlr import ( + build_parser, + check_antlr_version, + dir_latex_antlr + ) + # XXX: It would be better to come up with a test for these... + del build_parser, check_antlr_version, dir_latex_antlr + + +# These LaTeX strings should parse to the corresponding SymPy expression +GOOD_PAIRS = [ + (r"0", 0), + (r"1", 1), + (r"-3.14", -3.14), + (r"(-7.13)(1.5)", _Mul(-7.13, 1.5)), + (r"x", x), + (r"2x", 2*x), + (r"x^2", x**2), + (r"x^\frac{1}{2}", _Pow(x, _Pow(2, -1))), + (r"x^{3 + 1}", x**_Add(3, 1)), + (r"-c", -c), + (r"a \cdot b", a * b), + (r"a / b", a / b), + (r"a \div b", a / b), + (r"a + b", a + b), + (r"a + b - a", _Add(a+b, -a)), + (r"a^2 + b^2 = c^2", Eq(a**2 + b**2, c**2)), + (r"(x + y) z", _Mul(_Add(x, y), z)), + (r"a'b+ab'", _Add(_Mul(Symbol("a'"), b), _Mul(a, Symbol("b'")))), + (r"y''_1", Symbol("y_{1}''")), + (r"y_1''", Symbol("y_{1}''")), + (r"\left(x + y\right) z", _Mul(_Add(x, y), z)), + (r"\left( x + y\right ) z", _Mul(_Add(x, y), z)), + (r"\left( x + y\right ) z", _Mul(_Add(x, y), z)), + (r"\left[x + y\right] z", _Mul(_Add(x, y), z)), + (r"\left\{x + y\right\} z", _Mul(_Add(x, y), z)), + (r"1+1", _Add(1, 1)), + (r"0+1", _Add(0, 1)), + (r"1*2", _Mul(1, 2)), + (r"0*1", _Mul(0, 1)), + (r"1 \times 2 ", _Mul(1, 2)), + (r"x = y", Eq(x, y)), + (r"x \neq y", Ne(x, y)), + (r"x < y", Lt(x, y)), + (r"x > y", Gt(x, y)), + (r"x \leq y", Le(x, y)), + (r"x \geq y", Ge(x, y)), + (r"x \le y", Le(x, y)), + (r"x \ge y", Ge(x, y)), + (r"\lfloor x \rfloor", floor(x)), + (r"\lceil x \rceil", ceiling(x)), + (r"\langle x |", Bra('x')), + (r"| x \rangle", Ket('x')), + (r"\sin \theta", sin(theta)), + (r"\sin(\theta)", sin(theta)), + (r"\sin^{-1} a", asin(a)), + (r"\sin a \cos b", _Mul(sin(a), cos(b))), + (r"\sin \cos \theta", sin(cos(theta))), + (r"\sin(\cos \theta)", sin(cos(theta))), + (r"\frac{a}{b}", a / b), + (r"\dfrac{a}{b}", a / b), + (r"\tfrac{a}{b}", a / b), + (r"\frac12", _Pow(2, -1)), + (r"\frac12y", _Mul(_Pow(2, -1), y)), + (r"\frac1234", _Mul(_Pow(2, -1), 34)), + (r"\frac2{3}", _Mul(2, _Pow(3, -1))), + (r"\frac{\sin{x}}2", _Mul(sin(x), _Pow(2, -1))), + (r"\frac{a + b}{c}", _Mul(a + b, _Pow(c, -1))), + (r"\frac{7}{3}", _Mul(7, _Pow(3, -1))), + (r"(\csc x)(\sec y)", csc(x)*sec(y)), + (r"\lim_{x \to 3} a", Limit(a, x, 3, dir='+-')), + (r"\lim_{x \rightarrow 3} a", Limit(a, x, 3, dir='+-')), + (r"\lim_{x \Rightarrow 3} a", Limit(a, x, 3, dir='+-')), + (r"\lim_{x \longrightarrow 3} a", Limit(a, x, 3, dir='+-')), + (r"\lim_{x \Longrightarrow 3} a", Limit(a, x, 3, dir='+-')), + (r"\lim_{x \to 3^{+}} a", Limit(a, x, 3, dir='+')), + (r"\lim_{x \to 3^{-}} a", Limit(a, x, 3, dir='-')), + (r"\lim_{x \to 3^+} a", Limit(a, x, 3, dir='+')), + (r"\lim_{x \to 3^-} a", Limit(a, x, 3, dir='-')), + (r"\infty", oo), + (r"\lim_{x \to \infty} \frac{1}{x}", Limit(_Pow(x, -1), x, oo)), + (r"\frac{d}{dx} x", Derivative(x, x)), + (r"\frac{d}{dt} x", Derivative(x, t)), + (r"f(x)", f(x)), + (r"f(x, y)", f(x, y)), + (r"f(x, y, z)", f(x, y, z)), + (r"f'_1(x)", Function("f_{1}'")(x)), + (r"f_{1}''(x+y)", Function("f_{1}''")(x+y)), + (r"\frac{d f(x)}{dx}", Derivative(f(x), x)), + (r"\frac{d\theta(x)}{dx}", Derivative(Function('theta')(x), x)), + (r"x \neq y", Unequality(x, y)), + (r"|x|", _Abs(x)), + (r"||x||", _Abs(Abs(x))), + (r"|x||y|", _Abs(x)*_Abs(y)), + (r"||x||y||", _Abs(_Abs(x)*_Abs(y))), + (r"\pi^{|xy|}", Symbol('pi')**_Abs(x*y)), + (r"\int x dx", Integral(x, x)), + (r"\int x d\theta", Integral(x, theta)), + (r"\int (x^2 - y)dx", Integral(x**2 - y, x)), + (r"\int x + a dx", Integral(_Add(x, a), x)), + (r"\int da", Integral(1, a)), + (r"\int_0^7 dx", Integral(1, (x, 0, 7))), + (r"\int\limits_{0}^{1} x dx", Integral(x, (x, 0, 1))), + (r"\int_a^b x dx", Integral(x, (x, a, b))), + (r"\int^b_a x dx", Integral(x, (x, a, b))), + (r"\int_{a}^b x dx", Integral(x, (x, a, b))), + (r"\int^{b}_a x dx", Integral(x, (x, a, b))), + (r"\int_{a}^{b} x dx", Integral(x, (x, a, b))), + (r"\int^{b}_{a} x dx", Integral(x, (x, a, b))), + (r"\int_{f(a)}^{f(b)} f(z) dz", Integral(f(z), (z, f(a), f(b)))), + (r"\int (x+a)", Integral(_Add(x, a), x)), + (r"\int a + b + c dx", Integral(_Add(_Add(a, b), c), x)), + (r"\int \frac{dz}{z}", Integral(Pow(z, -1), z)), + (r"\int \frac{3 dz}{z}", Integral(3*Pow(z, -1), z)), + (r"\int \frac{1}{x} dx", Integral(Pow(x, -1), x)), + (r"\int \frac{1}{a} + \frac{1}{b} dx", + Integral(_Add(_Pow(a, -1), Pow(b, -1)), x)), + (r"\int \frac{3 \cdot d\theta}{\theta}", + Integral(3*_Pow(theta, -1), theta)), + (r"\int \frac{1}{x} + 1 dx", Integral(_Add(_Pow(x, -1), 1), x)), + (r"x_0", Symbol('x_{0}')), + (r"x_{1}", Symbol('x_{1}')), + (r"x_a", Symbol('x_{a}')), + (r"x_{b}", Symbol('x_{b}')), + (r"h_\theta", Symbol('h_{theta}')), + (r"h_{\theta}", Symbol('h_{theta}')), + (r"h_{\theta}(x_0, x_1)", + Function('h_{theta}')(Symbol('x_{0}'), Symbol('x_{1}'))), + (r"x!", _factorial(x)), + (r"100!", _factorial(100)), + (r"\theta!", _factorial(theta)), + (r"(x + 1)!", _factorial(_Add(x, 1))), + (r"(x!)!", _factorial(_factorial(x))), + (r"x!!!", _factorial(_factorial(_factorial(x)))), + (r"5!7!", _Mul(_factorial(5), _factorial(7))), + (r"\sqrt{x}", sqrt(x)), + (r"\sqrt{x + b}", sqrt(_Add(x, b))), + (r"\sqrt[3]{\sin x}", root(sin(x), 3)), + (r"\sqrt[y]{\sin x}", root(sin(x), y)), + (r"\sqrt[\theta]{\sin x}", root(sin(x), theta)), + (r"\sqrt{\frac{12}{6}}", _Sqrt(_Mul(12, _Pow(6, -1)))), + (r"\overline{z}", _Conjugate(z)), + (r"\overline{\overline{z}}", _Conjugate(_Conjugate(z))), + (r"\overline{x + y}", _Conjugate(_Add(x, y))), + (r"\overline{x} + \overline{y}", _Conjugate(x) + _Conjugate(y)), + (r"x < y", StrictLessThan(x, y)), + (r"x \leq y", LessThan(x, y)), + (r"x > y", StrictGreaterThan(x, y)), + (r"x \geq y", GreaterThan(x, y)), + (r"\mathit{x}", Symbol('x')), + (r"\mathit{test}", Symbol('test')), + (r"\mathit{TEST}", Symbol('TEST')), + (r"\mathit{HELLO world}", Symbol('HELLO world')), + (r"\sum_{k = 1}^{3} c", Sum(c, (k, 1, 3))), + (r"\sum_{k = 1}^3 c", Sum(c, (k, 1, 3))), + (r"\sum^{3}_{k = 1} c", Sum(c, (k, 1, 3))), + (r"\sum^3_{k = 1} c", Sum(c, (k, 1, 3))), + (r"\sum_{k = 1}^{10} k^2", Sum(k**2, (k, 1, 10))), + (r"\sum_{n = 0}^{\infty} \frac{1}{n!}", + Sum(_Pow(_factorial(n), -1), (n, 0, oo))), + (r"\prod_{a = b}^{c} x", Product(x, (a, b, c))), + (r"\prod_{a = b}^c x", Product(x, (a, b, c))), + (r"\prod^{c}_{a = b} x", Product(x, (a, b, c))), + (r"\prod^c_{a = b} x", Product(x, (a, b, c))), + (r"\exp x", _exp(x)), + (r"\exp(x)", _exp(x)), + (r"\lg x", _log(x, 10)), + (r"\ln x", _log(x, E)), + (r"\ln xy", _log(x*y, E)), + (r"\log x", _log(x, E)), + (r"\log xy", _log(x*y, E)), + (r"\log_{2} x", _log(x, 2)), + (r"\log_{a} x", _log(x, a)), + (r"\log_{11} x", _log(x, 11)), + (r"\log_{a^2} x", _log(x, _Pow(a, 2))), + (r"[x]", x), + (r"[a + b]", _Add(a, b)), + (r"\frac{d}{dx} [ \tan x ]", Derivative(tan(x), x)), + (r"\binom{n}{k}", _binomial(n, k)), + (r"\tbinom{n}{k}", _binomial(n, k)), + (r"\dbinom{n}{k}", _binomial(n, k)), + (r"\binom{n}{0}", _binomial(n, 0)), + (r"x^\binom{n}{k}", _Pow(x, _binomial(n, k))), + (r"a \, b", _Mul(a, b)), + (r"a \thinspace b", _Mul(a, b)), + (r"a \: b", _Mul(a, b)), + (r"a \medspace b", _Mul(a, b)), + (r"a \; b", _Mul(a, b)), + (r"a \thickspace b", _Mul(a, b)), + (r"a \quad b", _Mul(a, b)), + (r"a \qquad b", _Mul(a, b)), + (r"a \! b", _Mul(a, b)), + (r"a \negthinspace b", _Mul(a, b)), + (r"a \negmedspace b", _Mul(a, b)), + (r"a \negthickspace b", _Mul(a, b)), + (r"\int x \, dx", Integral(x, x)), + (r"\log_2 x", _log(x, 2)), + (r"\log_a x", _log(x, a)), + (r"5^0 - 4^0", _Add(_Pow(5, 0), _Mul(-1, _Pow(4, 0)))), + (r"3x - 1", _Add(_Mul(3, x), -1)) +] + + +def test_parseable(): + from sympy.parsing.latex import parse_latex + for latex_str, sympy_expr in GOOD_PAIRS: + assert parse_latex(latex_str) == sympy_expr, latex_str + +# These bad LaTeX strings should raise a LaTeXParsingError when parsed +BAD_STRINGS = [ + r"(", + r")", + r"\frac{d}{dx}", + r"(\frac{d}{dx})", + r"\sqrt{}", + r"\sqrt", + r"\overline{}", + r"\overline", + r"{", + r"}", + r"\mathit{x + y}", + r"\mathit{21}", + r"\frac{2}{}", + r"\frac{}{2}", + r"\int", + r"!", + r"!0", + r"_", + r"^", + r"|", + r"||x|", + r"()", + r"((((((((((((((((()))))))))))))))))", + r"-", + r"\frac{d}{dx} + \frac{d}{dt}", + r"f(x,,y)", + r"f(x,y,", + r"\sin^x", + r"\cos^2", + r"@", + r"#", + r"$", + r"%", + r"&", + r"*", + r"" "\\", + r"~", + r"\frac{(2 + x}{1 - x)}", +] + +def test_not_parseable(): + from sympy.parsing.latex import parse_latex, LaTeXParsingError + for latex_str in BAD_STRINGS: + with raises(LaTeXParsingError): + parse_latex(latex_str) + +# At time of migration from latex2sympy, should fail but doesn't +FAILING_BAD_STRINGS = [ + r"\cos 1 \cos", + r"f(,", + r"f()", + r"a \div \div b", + r"a \cdot \cdot b", + r"a // b", + r"a +", + r"1.1.1", + r"1 +", + r"a / b /", +] + +@XFAIL +def test_failing_not_parseable(): + from sympy.parsing.latex import parse_latex, LaTeXParsingError + for latex_str in FAILING_BAD_STRINGS: + with raises(LaTeXParsingError): + parse_latex(latex_str) + +# In strict mode, FAILING_BAD_STRINGS would fail +def test_strict_mode(): + from sympy.parsing.latex import parse_latex, LaTeXParsingError + for latex_str in FAILING_BAD_STRINGS: + with raises(LaTeXParsingError): + parse_latex(latex_str, strict=True) diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_latex_deps.py b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_latex_deps.py new file mode 100644 index 0000000000000000000000000000000000000000..7df44c2b19e34024db6e898f7c4eac962dcaa1c9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_latex_deps.py @@ -0,0 +1,16 @@ +from sympy.external import import_module +from sympy.testing.pytest import ignore_warnings, raises + +antlr4 = import_module("antlr4", warn_not_installed=False) + +# disable tests if antlr4-python3-runtime is not present +if antlr4: + disabled = True + + +def test_no_import(): + from sympy.parsing.latex import parse_latex + + with ignore_warnings(UserWarning): + with raises(ImportError): + parse_latex('1 + 1') diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_latex_lark.py b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_latex_lark.py new file mode 100644 index 0000000000000000000000000000000000000000..dd1f72a66c788ac41d923005ea988664d05a16c1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_latex_lark.py @@ -0,0 +1,872 @@ +from sympy.testing.pytest import XFAIL +from sympy.parsing.latex.lark import parse_latex_lark +from sympy.external import import_module + +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.function import Derivative, Function +from sympy.core.numbers import E, oo, Rational +from sympy.core.power import Pow +from sympy.core.parameters import evaluate +from sympy.core.relational import GreaterThan, LessThan, StrictGreaterThan, StrictLessThan, Unequality +from sympy.core.symbol import Symbol +from sympy.functions.combinatorial.factorials import binomial, factorial +from sympy.functions.elementary.complexes import Abs, conjugate +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.integers import ceiling, floor +from sympy.functions.elementary.miscellaneous import root, sqrt, Min, Max +from sympy.functions.elementary.trigonometric import asin, cos, csc, sec, sin, tan +from sympy.integrals.integrals import Integral +from sympy.series.limits import Limit +from sympy import Matrix, MatAdd, MatMul, Transpose, Trace +from sympy import I + +from sympy.core.relational import Eq, Ne, Lt, Le, Gt, Ge +from sympy.physics.quantum import Bra, Ket, InnerProduct +from sympy.abc import x, y, z, a, b, c, d, t, k, n + +from .test_latex import theta, f, _Add, _Mul, _Pow, _Sqrt, _Conjugate, _Abs, _factorial, _exp, _binomial + +lark = import_module("lark") + +# disable tests if lark is not present +disabled = lark is None + +# shorthand definitions that are only needed for the Lark LaTeX parser +def _Min(*args): + return Min(*args, evaluate=False) + + +def _Max(*args): + return Max(*args, evaluate=False) + + +def _log(a, b=E): + if b == E: + return log(a, evaluate=False) + else: + return log(a, b, evaluate=False) + + +def _MatAdd(a, b): + return MatAdd(a, b, evaluate=False) + + +def _MatMul(a, b): + return MatMul(a, b, evaluate=False) + + +# These LaTeX strings should parse to the corresponding SymPy expression +SYMBOL_EXPRESSION_PAIRS = [ + (r"x_0", Symbol('x_{0}')), + (r"x_{1}", Symbol('x_{1}')), + (r"x_a", Symbol('x_{a}')), + (r"x_{b}", Symbol('x_{b}')), + (r"h_\theta", Symbol('h_{theta}')), + (r"h_{\theta}", Symbol('h_{theta}')), + (r"y''_1", Symbol("y''_{1}")), + (r"y_1''", Symbol("y_{1}''")), + (r"\mathit{x}", Symbol('x')), + (r"\mathit{test}", Symbol('test')), + (r"\mathit{TEST}", Symbol('TEST')), + (r"\mathit{HELLO world}", Symbol('HELLO world')), + (r"a'", Symbol("a'")), + (r"a''", Symbol("a''")), + (r"\alpha'", Symbol("alpha'")), + (r"\alpha''", Symbol("alpha''")), + (r"a_b", Symbol("a_{b}")), + (r"a_b'", Symbol("a_{b}'")), + (r"a'_b", Symbol("a'_{b}")), + (r"a'_b'", Symbol("a'_{b}'")), + (r"a_{b'}", Symbol("a_{b'}")), + (r"a_{b'}'", Symbol("a_{b'}'")), + (r"a'_{b'}", Symbol("a'_{b'}")), + (r"a'_{b'}'", Symbol("a'_{b'}'")), + (r"\mathit{foo}'", Symbol("foo'")), + (r"\mathit{foo'}", Symbol("foo'")), + (r"\mathit{foo'}'", Symbol("foo''")), + (r"a_b''", Symbol("a_{b}''")), + (r"a''_b", Symbol("a''_{b}")), + (r"a''_b'''", Symbol("a''_{b}'''")), + (r"a_{b''}", Symbol("a_{b''}")), + (r"a_{b''}''", Symbol("a_{b''}''")), + (r"a''_{b''}", Symbol("a''_{b''}")), + (r"a''_{b''}'''", Symbol("a''_{b''}'''")), + (r"\mathit{foo}''", Symbol("foo''")), + (r"\mathit{foo''}", Symbol("foo''")), + (r"\mathit{foo''}'''", Symbol("foo'''''")), + (r"a_\alpha", Symbol("a_{alpha}")), + (r"a_\alpha'", Symbol("a_{alpha}'")), + (r"a'_\alpha", Symbol("a'_{alpha}")), + (r"a'_\alpha'", Symbol("a'_{alpha}'")), + (r"a_{\alpha'}", Symbol("a_{alpha'}")), + (r"a_{\alpha'}'", Symbol("a_{alpha'}'")), + (r"a'_{\alpha'}", Symbol("a'_{alpha'}")), + (r"a'_{\alpha'}'", Symbol("a'_{alpha'}'")), + (r"a_\alpha''", Symbol("a_{alpha}''")), + (r"a''_\alpha", Symbol("a''_{alpha}")), + (r"a''_\alpha'''", Symbol("a''_{alpha}'''")), + (r"a_{\alpha''}", Symbol("a_{alpha''}")), + (r"a_{\alpha''}''", Symbol("a_{alpha''}''")), + (r"a''_{\alpha''}", Symbol("a''_{alpha''}")), + (r"a''_{\alpha''}'''", Symbol("a''_{alpha''}'''")), + (r"\alpha_b", Symbol("alpha_{b}")), + (r"\alpha_b'", Symbol("alpha_{b}'")), + (r"\alpha'_b", Symbol("alpha'_{b}")), + (r"\alpha'_b'", Symbol("alpha'_{b}'")), + (r"\alpha_{b'}", Symbol("alpha_{b'}")), + (r"\alpha_{b'}'", Symbol("alpha_{b'}'")), + (r"\alpha'_{b'}", Symbol("alpha'_{b'}")), + (r"\alpha'_{b'}'", Symbol("alpha'_{b'}'")), + (r"\alpha_b''", Symbol("alpha_{b}''")), + (r"\alpha''_b", Symbol("alpha''_{b}")), + (r"\alpha''_b'''", Symbol("alpha''_{b}'''")), + (r"\alpha_{b''}", Symbol("alpha_{b''}")), + (r"\alpha_{b''}''", Symbol("alpha_{b''}''")), + (r"\alpha''_{b''}", Symbol("alpha''_{b''}")), + (r"\alpha''_{b''}'''", Symbol("alpha''_{b''}'''")), + (r"\alpha_\beta", Symbol("alpha_{beta}")), + (r"\alpha_{\beta}", Symbol("alpha_{beta}")), + (r"\alpha_{\beta'}", Symbol("alpha_{beta'}")), + (r"\alpha_{\beta''}", Symbol("alpha_{beta''}")), + (r"\alpha'_\beta", Symbol("alpha'_{beta}")), + (r"\alpha'_{\beta}", Symbol("alpha'_{beta}")), + (r"\alpha'_{\beta'}", Symbol("alpha'_{beta'}")), + (r"\alpha'_{\beta''}", Symbol("alpha'_{beta''}")), + (r"\alpha''_\beta", Symbol("alpha''_{beta}")), + (r"\alpha''_{\beta}", Symbol("alpha''_{beta}")), + (r"\alpha''_{\beta'}", Symbol("alpha''_{beta'}")), + (r"\alpha''_{\beta''}", Symbol("alpha''_{beta''}")), + (r"\alpha_\beta'", Symbol("alpha_{beta}'")), + (r"\alpha_{\beta}'", Symbol("alpha_{beta}'")), + (r"\alpha_{\beta'}'", Symbol("alpha_{beta'}'")), + (r"\alpha_{\beta''}'", Symbol("alpha_{beta''}'")), + (r"\alpha'_\beta'", Symbol("alpha'_{beta}'")), + (r"\alpha'_{\beta}'", Symbol("alpha'_{beta}'")), + (r"\alpha'_{\beta'}'", Symbol("alpha'_{beta'}'")), + (r"\alpha'_{\beta''}'", Symbol("alpha'_{beta''}'")), + (r"\alpha''_\beta'", Symbol("alpha''_{beta}'")), + (r"\alpha''_{\beta}'", Symbol("alpha''_{beta}'")), + (r"\alpha''_{\beta'}'", Symbol("alpha''_{beta'}'")), + (r"\alpha''_{\beta''}'", Symbol("alpha''_{beta''}'")), + (r"\alpha_\beta''", Symbol("alpha_{beta}''")), + (r"\alpha_{\beta}''", Symbol("alpha_{beta}''")), + (r"\alpha_{\beta'}''", Symbol("alpha_{beta'}''")), + (r"\alpha_{\beta''}''", Symbol("alpha_{beta''}''")), + (r"\alpha'_\beta''", Symbol("alpha'_{beta}''")), + (r"\alpha'_{\beta}''", Symbol("alpha'_{beta}''")), + (r"\alpha'_{\beta'}''", Symbol("alpha'_{beta'}''")), + (r"\alpha'_{\beta''}''", Symbol("alpha'_{beta''}''")), + (r"\alpha''_\beta''", Symbol("alpha''_{beta}''")), + (r"\alpha''_{\beta}''", Symbol("alpha''_{beta}''")), + (r"\alpha''_{\beta'}''", Symbol("alpha''_{beta'}''")), + (r"\alpha''_{\beta''}''", Symbol("alpha''_{beta''}''")) + +] + +UNEVALUATED_SIMPLE_EXPRESSION_PAIRS = [ + (r"0", 0), + (r"1", 1), + (r"-3.14", -3.14), + (r"(-7.13)(1.5)", _Mul(-7.13, 1.5)), + (r"1+1", _Add(1, 1)), + (r"0+1", _Add(0, 1)), + (r"1*2", _Mul(1, 2)), + (r"0*1", _Mul(0, 1)), + (r"x", x), + (r"2x", 2 * x), + (r"3x - 1", _Add(_Mul(3, x), -1)), + (r"-c", -c), + (r"\infty", oo), + (r"a \cdot b", a * b), + (r"1 \times 2 ", _Mul(1, 2)), + (r"a / b", a / b), + (r"a \div b", a / b), + (r"a + b", a + b), + (r"a + b - a", _Add(a + b, -a)), + (r"(x + y) z", _Mul(_Add(x, y), z)), + (r"a'b+ab'", _Add(_Mul(Symbol("a'"), b), _Mul(a, Symbol("b'")))) +] + +EVALUATED_SIMPLE_EXPRESSION_PAIRS = [ + (r"(-7.13)(1.5)", -10.695), + (r"1+1", 2), + (r"0+1", 1), + (r"1*2", 2), + (r"0*1", 0), + (r"2x", 2 * x), + (r"3x - 1", 3 * x - 1), + (r"-c", -c), + (r"a \cdot b", a * b), + (r"1 \times 2 ", 2), + (r"a / b", a / b), + (r"a \div b", a / b), + (r"a + b", a + b), + (r"a + b - a", b), + (r"(x + y) z", (x + y) * z), +] + +UNEVALUATED_FRACTION_EXPRESSION_PAIRS = [ + (r"\frac{a}{b}", a / b), + (r"\dfrac{a}{b}", a / b), + (r"\tfrac{a}{b}", a / b), + (r"\frac12", _Mul(1, _Pow(2, -1))), + (r"\frac12y", _Mul(_Mul(1, _Pow(2, -1)), y)), + (r"\frac1234", _Mul(_Mul(1, _Pow(2, -1)), 34)), + (r"\frac2{3}", _Mul(2, _Pow(3, -1))), + (r"\frac{a + b}{c}", _Mul(a + b, _Pow(c, -1))), + (r"\frac{7}{3}", _Mul(7, _Pow(3, -1))) +] + +EVALUATED_FRACTION_EXPRESSION_PAIRS = [ + (r"\frac{a}{b}", a / b), + (r"\dfrac{a}{b}", a / b), + (r"\tfrac{a}{b}", a / b), + (r"\frac12", Rational(1, 2)), + (r"\frac12y", y / 2), + (r"\frac1234", 17), + (r"\frac2{3}", Rational(2, 3)), + (r"\frac{a + b}{c}", (a + b) / c), + (r"\frac{7}{3}", Rational(7, 3)) +] + +RELATION_EXPRESSION_PAIRS = [ + (r"x = y", Eq(x, y)), + (r"x \neq y", Ne(x, y)), + (r"x < y", Lt(x, y)), + (r"x > y", Gt(x, y)), + (r"x \leq y", Le(x, y)), + (r"x \geq y", Ge(x, y)), + (r"x \le y", Le(x, y)), + (r"x \ge y", Ge(x, y)), + (r"x < y", StrictLessThan(x, y)), + (r"x \leq y", LessThan(x, y)), + (r"x > y", StrictGreaterThan(x, y)), + (r"x \geq y", GreaterThan(x, y)), + (r"x \neq y", Unequality(x, y)), # same as 2nd one in the list + (r"a^2 + b^2 = c^2", Eq(a**2 + b**2, c**2)) +] + +UNEVALUATED_POWER_EXPRESSION_PAIRS = [ + (r"x^2", x ** 2), + (r"x^\frac{1}{2}", _Pow(x, _Mul(1, _Pow(2, -1)))), + (r"x^{3 + 1}", x ** _Add(3, 1)), + (r"\pi^{|xy|}", Symbol('pi') ** _Abs(x * y)), + (r"5^0 - 4^0", _Add(_Pow(5, 0), _Mul(-1, _Pow(4, 0)))) +] + +EVALUATED_POWER_EXPRESSION_PAIRS = [ + (r"x^2", x ** 2), + (r"x^\frac{1}{2}", sqrt(x)), + (r"x^{3 + 1}", x ** 4), + (r"\pi^{|xy|}", Symbol('pi') ** _Abs(x * y)), + (r"5^0 - 4^0", 0) +] + +UNEVALUATED_INTEGRAL_EXPRESSION_PAIRS = [ + (r"\int x dx", Integral(_Mul(1, x), x)), + (r"\int x \, dx", Integral(_Mul(1, x), x)), + (r"\int x d\theta", Integral(_Mul(1, x), theta)), + (r"\int (x^2 - y)dx", Integral(_Mul(1, x ** 2 - y), x)), + (r"\int x + a dx", Integral(_Mul(1, _Add(x, a)), x)), + (r"\int da", Integral(_Mul(1, 1), a)), + (r"\int_0^7 dx", Integral(_Mul(1, 1), (x, 0, 7))), + (r"\int\limits_{0}^{1} x dx", Integral(_Mul(1, x), (x, 0, 1))), + (r"\int_a^b x dx", Integral(_Mul(1, x), (x, a, b))), + (r"\int^b_a x dx", Integral(_Mul(1, x), (x, a, b))), + (r"\int_{a}^b x dx", Integral(_Mul(1, x), (x, a, b))), + (r"\int^{b}_a x dx", Integral(_Mul(1, x), (x, a, b))), + (r"\int_{a}^{b} x dx", Integral(_Mul(1, x), (x, a, b))), + (r"\int^{b}_{a} x dx", Integral(_Mul(1, x), (x, a, b))), + (r"\int_{f(a)}^{f(b)} f(z) dz", Integral(f(z), (z, f(a), f(b)))), + (r"\int a + b + c dx", Integral(_Mul(1, _Add(_Add(a, b), c)), x)), + (r"\int \frac{dz}{z}", Integral(_Mul(1, _Mul(1, Pow(z, -1))), z)), + (r"\int \frac{3 dz}{z}", Integral(_Mul(1, _Mul(3, _Pow(z, -1))), z)), + (r"\int \frac{1}{x} dx", Integral(_Mul(1, _Mul(1, Pow(x, -1))), x)), + (r"\int \frac{1}{a} + \frac{1}{b} dx", + Integral(_Mul(1, _Add(_Mul(1, _Pow(a, -1)), _Mul(1, Pow(b, -1)))), x)), + (r"\int \frac{1}{x} + 1 dx", Integral(_Mul(1, _Add(_Mul(1, _Pow(x, -1)), 1)), x)) +] + +EVALUATED_INTEGRAL_EXPRESSION_PAIRS = [ + (r"\int x dx", Integral(x, x)), + (r"\int x \, dx", Integral(x, x)), + (r"\int x d\theta", Integral(x, theta)), + (r"\int (x^2 - y)dx", Integral(x ** 2 - y, x)), + (r"\int x + a dx", Integral(x + a, x)), + (r"\int da", Integral(1, a)), + (r"\int_0^7 dx", Integral(1, (x, 0, 7))), + (r"\int\limits_{0}^{1} x dx", Integral(x, (x, 0, 1))), + (r"\int_a^b x dx", Integral(x, (x, a, b))), + (r"\int^b_a x dx", Integral(x, (x, a, b))), + (r"\int_{a}^b x dx", Integral(x, (x, a, b))), + (r"\int^{b}_a x dx", Integral(x, (x, a, b))), + (r"\int_{a}^{b} x dx", Integral(x, (x, a, b))), + (r"\int^{b}_{a} x dx", Integral(x, (x, a, b))), + (r"\int_{f(a)}^{f(b)} f(z) dz", Integral(f(z), (z, f(a), f(b)))), + (r"\int a + b + c dx", Integral(a + b + c, x)), + (r"\int \frac{dz}{z}", Integral(Pow(z, -1), z)), + (r"\int \frac{3 dz}{z}", Integral(3 * Pow(z, -1), z)), + (r"\int \frac{1}{x} dx", Integral(1 / x, x)), + (r"\int \frac{1}{a} + \frac{1}{b} dx", Integral(1 / a + 1 / b, x)), + (r"\int \frac{1}{a} - \frac{1}{b} dx", Integral(1 / a - 1 / b, x)), + (r"\int \frac{1}{x} + 1 dx", Integral(1 / x + 1, x)) +] + +DERIVATIVE_EXPRESSION_PAIRS = [ + (r"\frac{d}{dx} x", Derivative(x, x)), + (r"\frac{d}{dt} x", Derivative(x, t)), + (r"\frac{d}{dx} ( \tan x )", Derivative(tan(x), x)), + (r"\frac{d f(x)}{dx}", Derivative(f(x), x)), + (r"\frac{d\theta(x)}{dx}", Derivative(Function('theta')(x), x)) +] + +TRIGONOMETRIC_EXPRESSION_PAIRS = [ + (r"\sin \theta", sin(theta)), + (r"\sin(\theta)", sin(theta)), + (r"\sin^{-1} a", asin(a)), + (r"\sin a \cos b", _Mul(sin(a), cos(b))), + (r"\sin \cos \theta", sin(cos(theta))), + (r"\sin(\cos \theta)", sin(cos(theta))), + (r"(\csc x)(\sec y)", csc(x) * sec(y)), + (r"\frac{\sin{x}}2", _Mul(sin(x), _Pow(2, -1))) +] + +UNEVALUATED_LIMIT_EXPRESSION_PAIRS = [ + (r"\lim_{x \to 3} a", Limit(a, x, 3, dir="+-")), + (r"\lim_{x \rightarrow 3} a", Limit(a, x, 3, dir="+-")), + (r"\lim_{x \Rightarrow 3} a", Limit(a, x, 3, dir="+-")), + (r"\lim_{x \longrightarrow 3} a", Limit(a, x, 3, dir="+-")), + (r"\lim_{x \Longrightarrow 3} a", Limit(a, x, 3, dir="+-")), + (r"\lim_{x \to 3^{+}} a", Limit(a, x, 3, dir="+")), + (r"\lim_{x \to 3^{-}} a", Limit(a, x, 3, dir="-")), + (r"\lim_{x \to 3^+} a", Limit(a, x, 3, dir="+")), + (r"\lim_{x \to 3^-} a", Limit(a, x, 3, dir="-")), + (r"\lim_{x \to \infty} \frac{1}{x}", Limit(_Mul(1, _Pow(x, -1)), x, oo)) +] + +EVALUATED_LIMIT_EXPRESSION_PAIRS = [ + (r"\lim_{x \to \infty} \frac{1}{x}", Limit(1 / x, x, oo)) +] + +UNEVALUATED_SQRT_EXPRESSION_PAIRS = [ + (r"\sqrt{x}", sqrt(x)), + (r"\sqrt{x + b}", sqrt(_Add(x, b))), + (r"\sqrt[3]{\sin x}", _Pow(sin(x), _Pow(3, -1))), + # the above test needed to be handled differently than the ones below because root + # acts differently if its second argument is a number + (r"\sqrt[y]{\sin x}", root(sin(x), y)), + (r"\sqrt[\theta]{\sin x}", root(sin(x), theta)), + (r"\sqrt{\frac{12}{6}}", _Sqrt(_Mul(12, _Pow(6, -1)))) +] + +EVALUATED_SQRT_EXPRESSION_PAIRS = [ + (r"\sqrt{x}", sqrt(x)), + (r"\sqrt{x + b}", sqrt(x + b)), + (r"\sqrt[3]{\sin x}", root(sin(x), 3)), + (r"\sqrt[y]{\sin x}", root(sin(x), y)), + (r"\sqrt[\theta]{\sin x}", root(sin(x), theta)), + (r"\sqrt{\frac{12}{6}}", sqrt(2)) +] + +UNEVALUATED_FACTORIAL_EXPRESSION_PAIRS = [ + (r"x!", _factorial(x)), + (r"100!", _factorial(100)), + (r"\theta!", _factorial(theta)), + (r"(x + 1)!", _factorial(_Add(x, 1))), + (r"(x!)!", _factorial(_factorial(x))), + (r"x!!!", _factorial(_factorial(_factorial(x)))), + (r"5!7!", _Mul(_factorial(5), _factorial(7))) +] + +EVALUATED_FACTORIAL_EXPRESSION_PAIRS = [ + (r"x!", factorial(x)), + (r"100!", factorial(100)), + (r"\theta!", factorial(theta)), + (r"(x + 1)!", factorial(x + 1)), + (r"(x!)!", factorial(factorial(x))), + (r"x!!!", factorial(factorial(factorial(x)))), + (r"5!7!", factorial(5) * factorial(7)), + (r"24! \times 24!", factorial(24) * factorial(24)) +] + +UNEVALUATED_SUM_EXPRESSION_PAIRS = [ + (r"\sum_{k = 1}^{3} c", Sum(_Mul(1, c), (k, 1, 3))), + (r"\sum_{k = 1}^3 c", Sum(_Mul(1, c), (k, 1, 3))), + (r"\sum^{3}_{k = 1} c", Sum(_Mul(1, c), (k, 1, 3))), + (r"\sum^3_{k = 1} c", Sum(_Mul(1, c), (k, 1, 3))), + (r"\sum_{k = 1}^{10} k^2", Sum(_Mul(1, k ** 2), (k, 1, 10))), + (r"\sum_{n = 0}^{\infty} \frac{1}{n!}", + Sum(_Mul(1, _Mul(1, _Pow(_factorial(n), -1))), (n, 0, oo))) +] + +EVALUATED_SUM_EXPRESSION_PAIRS = [ + (r"\sum_{k = 1}^{3} c", Sum(c, (k, 1, 3))), + (r"\sum_{k = 1}^3 c", Sum(c, (k, 1, 3))), + (r"\sum^{3}_{k = 1} c", Sum(c, (k, 1, 3))), + (r"\sum^3_{k = 1} c", Sum(c, (k, 1, 3))), + (r"\sum_{k = 1}^{10} k^2", Sum(k ** 2, (k, 1, 10))), + (r"\sum_{n = 0}^{\infty} \frac{1}{n!}", Sum(1 / factorial(n), (n, 0, oo))) +] + +UNEVALUATED_PRODUCT_EXPRESSION_PAIRS = [ + (r"\prod_{a = b}^{c} x", Product(x, (a, b, c))), + (r"\prod_{a = b}^c x", Product(x, (a, b, c))), + (r"\prod^{c}_{a = b} x", Product(x, (a, b, c))), + (r"\prod^c_{a = b} x", Product(x, (a, b, c))) +] + +APPLIED_FUNCTION_EXPRESSION_PAIRS = [ + (r"f(x)", f(x)), + (r"f(x, y)", f(x, y)), + (r"f(x, y, z)", f(x, y, z)), + (r"f'_1(x)", Function("f_{1}'")(x)), + (r"f_{1}''(x+y)", Function("f_{1}''")(x + y)), + (r"h_{\theta}(x_0, x_1)", + Function('h_{theta}')(Symbol('x_{0}'), Symbol('x_{1}'))) +] + +UNEVALUATED_COMMON_FUNCTION_EXPRESSION_PAIRS = [ + (r"|x|", _Abs(x)), + (r"||x||", _Abs(Abs(x))), + (r"|x||y|", _Abs(x) * _Abs(y)), + (r"||x||y||", _Abs(_Abs(x) * _Abs(y))), + (r"\lfloor x \rfloor", floor(x)), + (r"\lceil x \rceil", ceiling(x)), + (r"\exp x", _exp(x)), + (r"\exp(x)", _exp(x)), + (r"\lg x", _log(x, 10)), + (r"\ln x", _log(x)), + (r"\ln xy", _log(x * y)), + (r"\log x", _log(x)), + (r"\log xy", _log(x * y)), + (r"\log_{2} x", _log(x, 2)), + (r"\log_{a} x", _log(x, a)), + (r"\log_{11} x", _log(x, 11)), + (r"\log_{a^2} x", _log(x, _Pow(a, 2))), + (r"\log_2 x", _log(x, 2)), + (r"\log_a x", _log(x, a)), + (r"\overline{z}", _Conjugate(z)), + (r"\overline{\overline{z}}", _Conjugate(_Conjugate(z))), + (r"\overline{x + y}", _Conjugate(_Add(x, y))), + (r"\overline{x} + \overline{y}", _Conjugate(x) + _Conjugate(y)), + (r"\min(a, b)", _Min(a, b)), + (r"\min(a, b, c - d, xy)", _Min(a, b, c - d, x * y)), + (r"\max(a, b)", _Max(a, b)), + (r"\max(a, b, c - d, xy)", _Max(a, b, c - d, x * y)), + # physics things don't have an `evaluate=False` variant + (r"\langle x |", Bra('x')), + (r"| x \rangle", Ket('x')), + (r"\langle x | y \rangle", InnerProduct(Bra('x'), Ket('y'))), +] + +EVALUATED_COMMON_FUNCTION_EXPRESSION_PAIRS = [ + (r"|x|", Abs(x)), + (r"||x||", Abs(Abs(x))), + (r"|x||y|", Abs(x) * Abs(y)), + (r"||x||y||", Abs(Abs(x) * Abs(y))), + (r"\lfloor x \rfloor", floor(x)), + (r"\lceil x \rceil", ceiling(x)), + (r"\exp x", exp(x)), + (r"\exp(x)", exp(x)), + (r"\lg x", log(x, 10)), + (r"\ln x", log(x)), + (r"\ln xy", log(x * y)), + (r"\log x", log(x)), + (r"\log xy", log(x * y)), + (r"\log_{2} x", log(x, 2)), + (r"\log_{a} x", log(x, a)), + (r"\log_{11} x", log(x, 11)), + (r"\log_{a^2} x", log(x, _Pow(a, 2))), + (r"\log_2 x", log(x, 2)), + (r"\log_a x", log(x, a)), + (r"\overline{z}", conjugate(z)), + (r"\overline{\overline{z}}", conjugate(conjugate(z))), + (r"\overline{x + y}", conjugate(x + y)), + (r"\overline{x} + \overline{y}", conjugate(x) + conjugate(y)), + (r"\min(a, b)", Min(a, b)), + (r"\min(a, b, c - d, xy)", Min(a, b, c - d, x * y)), + (r"\max(a, b)", Max(a, b)), + (r"\max(a, b, c - d, xy)", Max(a, b, c - d, x * y)), + (r"\langle x |", Bra('x')), + (r"| x \rangle", Ket('x')), + (r"\langle x | y \rangle", InnerProduct(Bra('x'), Ket('y'))), +] + +SPACING_RELATED_EXPRESSION_PAIRS = [ + (r"a \, b", _Mul(a, b)), + (r"a \thinspace b", _Mul(a, b)), + (r"a \: b", _Mul(a, b)), + (r"a \medspace b", _Mul(a, b)), + (r"a \; b", _Mul(a, b)), + (r"a \thickspace b", _Mul(a, b)), + (r"a \quad b", _Mul(a, b)), + (r"a \qquad b", _Mul(a, b)), + (r"a \! b", _Mul(a, b)), + (r"a \negthinspace b", _Mul(a, b)), + (r"a \negmedspace b", _Mul(a, b)), + (r"a \negthickspace b", _Mul(a, b)) +] + +UNEVALUATED_BINOMIAL_EXPRESSION_PAIRS = [ + (r"\binom{n}{k}", _binomial(n, k)), + (r"\tbinom{n}{k}", _binomial(n, k)), + (r"\dbinom{n}{k}", _binomial(n, k)), + (r"\binom{n}{0}", _binomial(n, 0)), + (r"x^\binom{n}{k}", _Pow(x, _binomial(n, k))) +] + +EVALUATED_BINOMIAL_EXPRESSION_PAIRS = [ + (r"\binom{n}{k}", binomial(n, k)), + (r"\tbinom{n}{k}", binomial(n, k)), + (r"\dbinom{n}{k}", binomial(n, k)), + (r"\binom{n}{0}", binomial(n, 0)), + (r"x^\binom{n}{k}", x ** binomial(n, k)) +] + +MISCELLANEOUS_EXPRESSION_PAIRS = [ + (r"\left(x + y\right) z", _Mul(_Add(x, y), z)), + (r"\left( x + y\right ) z", _Mul(_Add(x, y), z)), + (r"\left( x + y\right ) z", _Mul(_Add(x, y), z)), +] + +UNEVALUATED_LITERAL_COMPLEX_NUMBER_EXPRESSION_PAIRS = [ + (r"\imaginaryunit^2", _Pow(I, 2)), + (r"|\imaginaryunit|", _Abs(I)), + (r"\overline{\imaginaryunit}", _Conjugate(I)), + (r"\imaginaryunit+\imaginaryunit", _Add(I, I)), + (r"\imaginaryunit-\imaginaryunit", _Add(I, -I)), + (r"\imaginaryunit*\imaginaryunit", _Mul(I, I)), + (r"\imaginaryunit/\imaginaryunit", _Mul(I, _Pow(I, -1))), + (r"(1+\imaginaryunit)/|1+\imaginaryunit|", _Mul(_Add(1, I), _Pow(_Abs(_Add(1, I)), -1))) +] + +UNEVALUATED_MATRIX_EXPRESSION_PAIRS = [ + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}", + Matrix([[a, b], [x, y]])), + (r"\begin{pmatrix}a & b \\x & y\\\end{pmatrix}", + Matrix([[a, b], [x, y]])), + (r"\begin{bmatrix}a & b \\x & y\end{bmatrix}", + Matrix([[a, b], [x, y]])), + (r"\left(\begin{matrix}a & b \\x & y\end{matrix}\right)", + Matrix([[a, b], [x, y]])), + (r"\left[\begin{matrix}a & b \\x & y\end{matrix}\right]", + Matrix([[a, b], [x, y]])), + (r"\left[\begin{array}{cc}a & b \\x & y\end{array}\right]", + Matrix([[a, b], [x, y]])), + (r"\left(\begin{array}{cc}a & b \\x & y\end{array}\right)", + Matrix([[a, b], [x, y]])), + (r"\left( { \begin{array}{cc}a & b \\x & y\end{array} } \right)", + Matrix([[a, b], [x, y]])), + (r"+\begin{pmatrix}a & b \\x & y\end{pmatrix}", + Matrix([[a, b], [x, y]])), + ((r"\begin{pmatrix}x & y \\a & b\end{pmatrix}+" + r"\begin{pmatrix}a & b \\x & y\end{pmatrix}"), + _MatAdd(Matrix([[x, y], [a, b]]), Matrix([[a, b], [x, y]]))), + (r"-\begin{pmatrix}a & b \\x & y\end{pmatrix}", + _MatMul(-1, Matrix([[a, b], [x, y]]))), + ((r"\begin{pmatrix}x & y \\a & b\end{pmatrix}-" + r"\begin{pmatrix}a & b \\x & y\end{pmatrix}"), + _MatAdd(Matrix([[x, y], [a, b]]), _MatMul(-1, Matrix([[a, b], [x, y]])))), + ((r"\begin{pmatrix}a & b & c \\x & y & z \\a & b & c \end{pmatrix}*" + r"\begin{pmatrix}x & y & z \\a & b & c \\a & b & c \end{pmatrix}*" + r"\begin{pmatrix}a & b & c \\x & y & z \\x & y & z \end{pmatrix}"), + _MatMul(_MatMul(Matrix([[a, b, c], [x, y, z], [a, b, c]]), + Matrix([[x, y, z], [a, b, c], [a, b, c]])), + Matrix([[a, b, c], [x, y, z], [x, y, z]]))), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}/2", + _MatMul(Matrix([[a, b], [x, y]]), _Pow(2, -1))), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}^2", + _Pow(Matrix([[a, b], [x, y]]), 2)), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}^{-1}", + _Pow(Matrix([[a, b], [x, y]]), -1)), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}^T", + Transpose(Matrix([[a, b], [x, y]]))), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}^{T}", + Transpose(Matrix([[a, b], [x, y]]))), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}^\mathit{T}", + Transpose(Matrix([[a, b], [x, y]]))), + (r"\begin{pmatrix}1 & 2 \\3 & 4\end{pmatrix}^T", + Transpose(Matrix([[1, 2], [3, 4]]))), + ((r"(\begin{pmatrix}1 & 2 \\3 & 4\end{pmatrix}+" + r"\begin{pmatrix}1 & 2 \\3 & 4\end{pmatrix}^T)*" + r"\begin{bmatrix}1\\0\end{bmatrix}"), + _MatMul(_MatAdd(Matrix([[1, 2], [3, 4]]), + Transpose(Matrix([[1, 2], [3, 4]]))), + Matrix([[1], [0]]))), + ((r"(\begin{pmatrix}a & b \\x & y\end{pmatrix}+" + r"\begin{pmatrix}x & y \\a & b\end{pmatrix})^2"), + _Pow(_MatAdd(Matrix([[a, b], [x, y]]), + Matrix([[x, y], [a, b]])), 2)), + ((r"(\begin{pmatrix}a & b \\x & y\end{pmatrix}+" + r"\begin{pmatrix}x & y \\a & b\end{pmatrix})^T"), + Transpose(_MatAdd(Matrix([[a, b], [x, y]]), + Matrix([[x, y], [a, b]])))), + (r"\overline{\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}}", + _Conjugate(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))) +] + +EVALUATED_MATRIX_EXPRESSION_PAIRS = [ + (r"\det\left(\left[ { \begin{array}{cc}a&b\\x&y\end{array} } \right]\right)", + Matrix([[a, b], [x, y]]).det()), + (r"\det \begin{pmatrix}1&2\\3&4\end{pmatrix}", -2), + (r"\det{\begin{pmatrix}1&2\\3&4\end{pmatrix}}", -2), + (r"\det(\begin{pmatrix}1&2\\3&4\end{pmatrix})", -2), + (r"\det\left(\begin{pmatrix}1&2\\3&4\end{pmatrix}\right)", -2), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}/\begin{vmatrix}a & b \\x & y\end{vmatrix}", + _MatMul(Matrix([[a, b], [x, y]]), _Pow(Matrix([[a, b], [x, y]]).det(), -1))), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}/|\begin{matrix}a & b \\x & y\end{matrix}|", + _MatMul(Matrix([[a, b], [x, y]]), _Pow(Matrix([[a, b], [x, y]]).det(), -1))), + (r"\frac{\begin{pmatrix}a & b \\x & y\end{pmatrix}}{| { \begin{matrix}a & b \\x & y\end{matrix} } |}", + _MatMul(Matrix([[a, b], [x, y]]), _Pow(Matrix([[a, b], [x, y]]).det(), -1))), + (r"\overline{\begin{pmatrix}\imaginaryunit & 1+\imaginaryunit \\-\imaginaryunit & 4\end{pmatrix}}", + Matrix([[-I, 1-I], [I, 4]])), + (r"\begin{pmatrix}\imaginaryunit & 1+\imaginaryunit \\-\imaginaryunit & 4\end{pmatrix}^H", + Matrix([[-I, I], [1-I, 4]])), + (r"\trace(\begin{pmatrix}\imaginaryunit & 1+\imaginaryunit \\-\imaginaryunit & 4\end{pmatrix})", + Trace(Matrix([[I, 1+I], [-I, 4]]))), + (r"\adjugate(\begin{pmatrix}1 & 2 \\3 & 4\end{pmatrix})", + Matrix([[4, -2], [-3, 1]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^\ast", + Matrix([[-2*I, 6], [4, 8]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{\ast}", + Matrix([[-2*I, 6], [4, 8]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{\ast\ast}", + Matrix([[2*I, 4], [6, 8]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{\ast\ast\ast}", + Matrix([[-2*I, 6], [4, 8]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{*}", + Matrix([[-2*I, 6], [4, 8]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{**}", + Matrix([[2*I, 4], [6, 8]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{***}", + Matrix([[-2*I, 6], [4, 8]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^\prime", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{\prime}", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{\prime\prime}", + _MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]]))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{\prime\prime\prime}", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{'}", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{''}", + _MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]]))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{'''}", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})'", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})''", + _MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]]))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})'''", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"\det(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})", + (_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]]))).det()), + (r"\trace(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})", + Trace(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"\adjugate(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})", + (Matrix([[8, -4], [-6, 2*I]]))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^T", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^H", + (Matrix([[-2*I, 6], [4, 8]]))) +] + + +def test_symbol_expressions(): + expected_failures = {6, 7} + for i, (latex_str, sympy_expr) in enumerate(SYMBOL_EXPRESSION_PAIRS): + if i in expected_failures: + continue + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_simple_expressions(): + expected_failures = {20} + for i, (latex_str, sympy_expr) in enumerate(UNEVALUATED_SIMPLE_EXPRESSION_PAIRS): + if i in expected_failures: + continue + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for i, (latex_str, sympy_expr) in enumerate(EVALUATED_SIMPLE_EXPRESSION_PAIRS): + if i in expected_failures: + continue + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_fraction_expressions(): + for latex_str, sympy_expr in UNEVALUATED_FRACTION_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for latex_str, sympy_expr in EVALUATED_FRACTION_EXPRESSION_PAIRS: + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_relation_expressions(): + for latex_str, sympy_expr in RELATION_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + +def test_power_expressions(): + expected_failures = {3} + for i, (latex_str, sympy_expr) in enumerate(UNEVALUATED_POWER_EXPRESSION_PAIRS): + if i in expected_failures: + continue + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for i, (latex_str, sympy_expr) in enumerate(EVALUATED_POWER_EXPRESSION_PAIRS): + if i in expected_failures: + continue + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_integral_expressions(): + expected_failures = {14} + for i, (latex_str, sympy_expr) in enumerate(UNEVALUATED_INTEGRAL_EXPRESSION_PAIRS): + if i in expected_failures: + continue + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, i + + for i, (latex_str, sympy_expr) in enumerate(EVALUATED_INTEGRAL_EXPRESSION_PAIRS): + if i in expected_failures: + continue + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_derivative_expressions(): + expected_failures = {3, 4} + for i, (latex_str, sympy_expr) in enumerate(DERIVATIVE_EXPRESSION_PAIRS): + if i in expected_failures: + continue + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for i, (latex_str, sympy_expr) in enumerate(DERIVATIVE_EXPRESSION_PAIRS): + if i in expected_failures: + continue + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_trigonometric_expressions(): + expected_failures = {3} + for i, (latex_str, sympy_expr) in enumerate(TRIGONOMETRIC_EXPRESSION_PAIRS): + if i in expected_failures: + continue + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_limit_expressions(): + for latex_str, sympy_expr in UNEVALUATED_LIMIT_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_square_root_expressions(): + for latex_str, sympy_expr in UNEVALUATED_SQRT_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for latex_str, sympy_expr in EVALUATED_SQRT_EXPRESSION_PAIRS: + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_factorial_expressions(): + for latex_str, sympy_expr in UNEVALUATED_FACTORIAL_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for latex_str, sympy_expr in EVALUATED_FACTORIAL_EXPRESSION_PAIRS: + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_sum_expressions(): + for latex_str, sympy_expr in UNEVALUATED_SUM_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for latex_str, sympy_expr in EVALUATED_SUM_EXPRESSION_PAIRS: + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_product_expressions(): + for latex_str, sympy_expr in UNEVALUATED_PRODUCT_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + +@XFAIL +def test_applied_function_expressions(): + expected_failures = {0, 3, 4} # 0 is ambiguous, and the others require not-yet-added features + # not sure why 1, and 2 are failing + for i, (latex_str, sympy_expr) in enumerate(APPLIED_FUNCTION_EXPRESSION_PAIRS): + if i in expected_failures: + continue + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_common_function_expressions(): + for latex_str, sympy_expr in UNEVALUATED_COMMON_FUNCTION_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for latex_str, sympy_expr in EVALUATED_COMMON_FUNCTION_EXPRESSION_PAIRS: + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +# unhandled bug causing these to fail +@XFAIL +def test_spacing(): + for latex_str, sympy_expr in SPACING_RELATED_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_binomial_expressions(): + for latex_str, sympy_expr in UNEVALUATED_BINOMIAL_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for latex_str, sympy_expr in EVALUATED_BINOMIAL_EXPRESSION_PAIRS: + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_miscellaneous_expressions(): + for latex_str, sympy_expr in MISCELLANEOUS_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_literal_complex_number_expressions(): + for latex_str, sympy_expr in UNEVALUATED_LITERAL_COMPLEX_NUMBER_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_matrix_expressions(): + for latex_str, sympy_expr in UNEVALUATED_MATRIX_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for latex_str, sympy_expr in EVALUATED_MATRIX_EXPRESSION_PAIRS: + assert parse_latex_lark(latex_str) == sympy_expr, latex_str diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_mathematica.py b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_mathematica.py new file mode 100644 index 0000000000000000000000000000000000000000..df193b6d61f9c82778d8e0a40b893cbe6cb8f06a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_mathematica.py @@ -0,0 +1,280 @@ +from sympy import sin, Function, symbols, Dummy, Lambda, cos +from sympy.parsing.mathematica import parse_mathematica, MathematicaParser +from sympy.core.sympify import sympify +from sympy.abc import n, w, x, y, z +from sympy.testing.pytest import raises + + +def test_mathematica(): + d = { + '- 6x': '-6*x', + 'Sin[x]^2': 'sin(x)**2', + '2(x-1)': '2*(x-1)', + '3y+8': '3*y+8', + 'ArcSin[2x+9(4-x)^2]/x': 'asin(2*x+9*(4-x)**2)/x', + 'x+y': 'x+y', + '355/113': '355/113', + '2.718281828': '2.718281828', + 'Cos(1/2 * π)': 'Cos(π/2)', + 'Sin[12]': 'sin(12)', + 'Exp[Log[4]]': 'exp(log(4))', + '(x+1)(x+3)': '(x+1)*(x+3)', + 'Cos[ArcCos[3.6]]': 'cos(acos(3.6))', + 'Cos[x]==Sin[y]': 'Eq(cos(x), sin(y))', + '2*Sin[x+y]': '2*sin(x+y)', + 'Sin[x]+Cos[y]': 'sin(x)+cos(y)', + 'Sin[Cos[x]]': 'sin(cos(x))', + '2*Sqrt[x+y]': '2*sqrt(x+y)', # Test case from the issue 4259 + '+Sqrt[2]': 'sqrt(2)', + '-Sqrt[2]': '-sqrt(2)', + '-1/Sqrt[2]': '-1/sqrt(2)', + '-(1/Sqrt[3])': '-(1/sqrt(3))', + '1/(2*Sqrt[5])': '1/(2*sqrt(5))', + 'Mod[5,3]': 'Mod(5,3)', + '-Mod[5,3]': '-Mod(5,3)', + '(x+1)y': '(x+1)*y', + 'x(y+1)': 'x*(y+1)', + 'Sin[x]Cos[y]': 'sin(x)*cos(y)', + 'Sin[x]^2Cos[y]^2': 'sin(x)**2*cos(y)**2', + 'Cos[x]^2(1 - Cos[y]^2)': 'cos(x)**2*(1-cos(y)**2)', + 'x y': 'x*y', + 'x y': 'x*y', + '2 x': '2*x', + 'x 8': 'x*8', + '2 8': '2*8', + '4.x': '4.*x', + '4. 3': '4.*3', + '4. 3.': '4.*3.', + '1 2 3': '1*2*3', + ' - 2 * Sqrt[ 2 3 * ( 1 + 5 ) ] ': '-2*sqrt(2*3*(1+5))', + 'Log[2,4]': 'log(4,2)', + 'Log[Log[2,4],4]': 'log(4,log(4,2))', + 'Exp[Sqrt[2]^2Log[2, 8]]': 'exp(sqrt(2)**2*log(8,2))', + 'ArcSin[Cos[0]]': 'asin(cos(0))', + 'Log2[16]': 'log(16,2)', + 'Max[1,-2,3,-4]': 'Max(1,-2,3,-4)', + 'Min[1,-2,3]': 'Min(1,-2,3)', + 'Exp[I Pi/2]': 'exp(I*pi/2)', + 'ArcTan[x,y]': 'atan2(y,x)', + 'Pochhammer[x,y]': 'rf(x,y)', + 'ExpIntegralEi[x]': 'Ei(x)', + 'SinIntegral[x]': 'Si(x)', + 'CosIntegral[x]': 'Ci(x)', + 'AiryAi[x]': 'airyai(x)', + 'AiryAiPrime[5]': 'airyaiprime(5)', + 'AiryBi[x]': 'airybi(x)', + 'AiryBiPrime[7]': 'airybiprime(7)', + 'LogIntegral[4]': ' li(4)', + 'PrimePi[7]': 'primepi(7)', + 'Prime[5]': 'prime(5)', + 'PrimeQ[5]': 'isprime(5)', + 'Rational[2,19]': 'Rational(2,19)', # test case for issue 25716 + } + + for e in d: + assert parse_mathematica(e) == sympify(d[e]) + + # The parsed form of this expression should not evaluate the Lambda object: + assert parse_mathematica("Sin[#]^2 + Cos[#]^2 &[x]") == sin(x)**2 + cos(x)**2 + + d1, d2, d3 = symbols("d1:4", cls=Dummy) + assert parse_mathematica("Sin[#] + Cos[#3] &").dummy_eq(Lambda((d1, d2, d3), sin(d1) + cos(d3))) + assert parse_mathematica("Sin[#^2] &").dummy_eq(Lambda(d1, sin(d1**2))) + assert parse_mathematica("Function[x, x^3]") == Lambda(x, x**3) + assert parse_mathematica("Function[{x, y}, x^2 + y^2]") == Lambda((x, y), x**2 + y**2) + + +def test_parser_mathematica_tokenizer(): + parser = MathematicaParser() + + chain = lambda expr: parser._from_tokens_to_fullformlist(parser._from_mathematica_to_tokens(expr)) + + # Basic patterns + assert chain("x") == "x" + assert chain("42") == "42" + assert chain(".2") == ".2" + assert chain("+x") == "x" + assert chain("-1") == "-1" + assert chain("- 3") == "-3" + assert chain("α") == "α" + assert chain("+Sin[x]") == ["Sin", "x"] + assert chain("-Sin[x]") == ["Times", "-1", ["Sin", "x"]] + assert chain("x(a+1)") == ["Times", "x", ["Plus", "a", "1"]] + assert chain("(x)") == "x" + assert chain("(+x)") == "x" + assert chain("-a") == ["Times", "-1", "a"] + assert chain("(-x)") == ["Times", "-1", "x"] + assert chain("(x + y)") == ["Plus", "x", "y"] + assert chain("3 + 4") == ["Plus", "3", "4"] + assert chain("a - 3") == ["Plus", "a", "-3"] + assert chain("a - b") == ["Plus", "a", ["Times", "-1", "b"]] + assert chain("7 * 8") == ["Times", "7", "8"] + assert chain("a + b*c") == ["Plus", "a", ["Times", "b", "c"]] + assert chain("a + b* c* d + 2 * e") == ["Plus", "a", ["Times", "b", "c", "d"], ["Times", "2", "e"]] + assert chain("a / b") == ["Times", "a", ["Power", "b", "-1"]] + + # Missing asterisk (*) patterns: + assert chain("x y") == ["Times", "x", "y"] + assert chain("3 4") == ["Times", "3", "4"] + assert chain("a[b] c") == ["Times", ["a", "b"], "c"] + assert chain("(x) (y)") == ["Times", "x", "y"] + assert chain("3 (a)") == ["Times", "3", "a"] + assert chain("(a) b") == ["Times", "a", "b"] + assert chain("4.2") == "4.2" + assert chain("4 2") == ["Times", "4", "2"] + assert chain("4 2") == ["Times", "4", "2"] + assert chain("3 . 4") == ["Dot", "3", "4"] + assert chain("4. 2") == ["Times", "4.", "2"] + assert chain("x.y") == ["Dot", "x", "y"] + assert chain("4.y") == ["Times", "4.", "y"] + assert chain("4 .y") == ["Dot", "4", "y"] + assert chain("x.4") == ["Times", "x", ".4"] + assert chain("x0.3") == ["Times", "x0", ".3"] + assert chain("x. 4") == ["Dot", "x", "4"] + + # Comments + assert chain("a (* +b *) + c") == ["Plus", "a", "c"] + assert chain("a (* + b *) + (**)c (* +d *) + e") == ["Plus", "a", "c", "e"] + assert chain("""a + (* + + b + *) c + (* d + *) e + """) == ["Plus", "a", "c", "e"] + + # Operators couples + and -, * and / are mutually associative: + # (i.e. expression gets flattened when mixing these operators) + assert chain("a*b/c") == ["Times", "a", "b", ["Power", "c", "-1"]] + assert chain("a/b*c") == ["Times", "a", ["Power", "b", "-1"], "c"] + assert chain("a+b-c") == ["Plus", "a", "b", ["Times", "-1", "c"]] + assert chain("a-b+c") == ["Plus", "a", ["Times", "-1", "b"], "c"] + assert chain("-a + b -c ") == ["Plus", ["Times", "-1", "a"], "b", ["Times", "-1", "c"]] + assert chain("a/b/c*d") == ["Times", "a", ["Power", "b", "-1"], ["Power", "c", "-1"], "d"] + assert chain("a/b/c") == ["Times", "a", ["Power", "b", "-1"], ["Power", "c", "-1"]] + assert chain("a-b-c") == ["Plus", "a", ["Times", "-1", "b"], ["Times", "-1", "c"]] + assert chain("1/a") == ["Times", "1", ["Power", "a", "-1"]] + assert chain("1/a/b") == ["Times", "1", ["Power", "a", "-1"], ["Power", "b", "-1"]] + assert chain("-1/a*b") == ["Times", "-1", ["Power", "a", "-1"], "b"] + + # Enclosures of various kinds, i.e. ( ) [ ] [[ ]] { } + assert chain("(a + b) + c") == ["Plus", ["Plus", "a", "b"], "c"] + assert chain(" a + (b + c) + d ") == ["Plus", "a", ["Plus", "b", "c"], "d"] + assert chain("a * (b + c)") == ["Times", "a", ["Plus", "b", "c"]] + assert chain("a b (c d)") == ["Times", "a", "b", ["Times", "c", "d"]] + assert chain("{a, b, 2, c}") == ["List", "a", "b", "2", "c"] + assert chain("{a, {b, c}}") == ["List", "a", ["List", "b", "c"]] + assert chain("{{a}}") == ["List", ["List", "a"]] + assert chain("a[b, c]") == ["a", "b", "c"] + assert chain("a[[b, c]]") == ["Part", "a", "b", "c"] + assert chain("a[b[c]]") == ["a", ["b", "c"]] + assert chain("a[[b, c[[d, {e,f}]]]]") == ["Part", "a", "b", ["Part", "c", "d", ["List", "e", "f"]]] + assert chain("a[b[[c,d]]]") == ["a", ["Part", "b", "c", "d"]] + assert chain("a[[b[c]]]") == ["Part", "a", ["b", "c"]] + assert chain("a[[b[[c]]]]") == ["Part", "a", ["Part", "b", "c"]] + assert chain("a[[b[c[[d]]]]]") == ["Part", "a", ["b", ["Part", "c", "d"]]] + assert chain("a[b[[c[d]]]]") == ["a", ["Part", "b", ["c", "d"]]] + assert chain("x[[a+1, b+2, c+3]]") == ["Part", "x", ["Plus", "a", "1"], ["Plus", "b", "2"], ["Plus", "c", "3"]] + assert chain("x[a+1, b+2, c+3]") == ["x", ["Plus", "a", "1"], ["Plus", "b", "2"], ["Plus", "c", "3"]] + assert chain("{a+1, b+2, c+3}") == ["List", ["Plus", "a", "1"], ["Plus", "b", "2"], ["Plus", "c", "3"]] + + # Flat operator: + assert chain("a*b*c*d*e") == ["Times", "a", "b", "c", "d", "e"] + assert chain("a +b + c+ d+e") == ["Plus", "a", "b", "c", "d", "e"] + + # Right priority operator: + assert chain("a^b") == ["Power", "a", "b"] + assert chain("a^b^c") == ["Power", "a", ["Power", "b", "c"]] + assert chain("a^b^c^d") == ["Power", "a", ["Power", "b", ["Power", "c", "d"]]] + + # Left priority operator: + assert chain("a/.b") == ["ReplaceAll", "a", "b"] + assert chain("a/.b/.c/.d") == ["ReplaceAll", ["ReplaceAll", ["ReplaceAll", "a", "b"], "c"], "d"] + + assert chain("a//b") == ["a", "b"] + assert chain("a//b//c") == [["a", "b"], "c"] + assert chain("a//b//c//d") == [[["a", "b"], "c"], "d"] + + # Compound expressions + assert chain("a;b") == ["CompoundExpression", "a", "b"] + assert chain("a;") == ["CompoundExpression", "a", "Null"] + assert chain("a;b;") == ["CompoundExpression", "a", "b", "Null"] + assert chain("a[b;c]") == ["a", ["CompoundExpression", "b", "c"]] + assert chain("a[b,c;d,e]") == ["a", "b", ["CompoundExpression", "c", "d"], "e"] + assert chain("a[b,c;,d]") == ["a", "b", ["CompoundExpression", "c", "Null"], "d"] + + # New lines + assert chain("a\nb\n") == ["CompoundExpression", "a", "b"] + assert chain("a\n\nb\n (c \nd) \n") == ["CompoundExpression", "a", "b", ["Times", "c", "d"]] + assert chain("\na; b\nc") == ["CompoundExpression", "a", "b", "c"] + assert chain("a + \nb\n") == ["Plus", "a", "b"] + assert chain("a\nb; c; d\n e; (f \n g); h + \n i") == ["CompoundExpression", "a", "b", "c", "d", "e", ["Times", "f", "g"], ["Plus", "h", "i"]] + assert chain("\n{\na\nb; c; d\n e (f \n g); h + \n i\n\n}\n") == ["List", ["CompoundExpression", ["Times", "a", "b"], "c", ["Times", "d", "e", ["Times", "f", "g"]], ["Plus", "h", "i"]]] + + # Patterns + assert chain("y_") == ["Pattern", "y", ["Blank"]] + assert chain("y_.") == ["Optional", ["Pattern", "y", ["Blank"]]] + assert chain("y__") == ["Pattern", "y", ["BlankSequence"]] + assert chain("y___") == ["Pattern", "y", ["BlankNullSequence"]] + assert chain("a[b_.,c_]") == ["a", ["Optional", ["Pattern", "b", ["Blank"]]], ["Pattern", "c", ["Blank"]]] + assert chain("b_. c") == ["Times", ["Optional", ["Pattern", "b", ["Blank"]]], "c"] + + # Slots for lambda functions + assert chain("#") == ["Slot", "1"] + assert chain("#3") == ["Slot", "3"] + assert chain("#n") == ["Slot", "n"] + assert chain("##") == ["SlotSequence", "1"] + assert chain("##a") == ["SlotSequence", "a"] + + # Lambda functions + assert chain("x&") == ["Function", "x"] + assert chain("#&") == ["Function", ["Slot", "1"]] + assert chain("#+3&") == ["Function", ["Plus", ["Slot", "1"], "3"]] + assert chain("#1 + #2&") == ["Function", ["Plus", ["Slot", "1"], ["Slot", "2"]]] + assert chain("# + #&") == ["Function", ["Plus", ["Slot", "1"], ["Slot", "1"]]] + assert chain("#&[x]") == [["Function", ["Slot", "1"]], "x"] + assert chain("#1 + #2 & [x, y]") == [["Function", ["Plus", ["Slot", "1"], ["Slot", "2"]]], "x", "y"] + assert chain("#1^2#2^3&") == ["Function", ["Times", ["Power", ["Slot", "1"], "2"], ["Power", ["Slot", "2"], "3"]]] + + # Strings inside Mathematica expressions: + assert chain('"abc"') == ["_Str", "abc"] + assert chain('"a\\"b"') == ["_Str", 'a"b'] + # This expression does not make sense mathematically, it's just testing the parser: + assert chain('x + "abc" ^ 3') == ["Plus", "x", ["Power", ["_Str", "abc"], "3"]] + assert chain('"a (* b *) c"') == ["_Str", "a (* b *) c"] + assert chain('"a" (* b *) ') == ["_Str", "a"] + assert chain('"a [ b] "') == ["_Str", "a [ b] "] + raises(SyntaxError, lambda: chain('"')) + raises(SyntaxError, lambda: chain('"\\"')) + raises(SyntaxError, lambda: chain('"abc')) + raises(SyntaxError, lambda: chain('"abc\\"def')) + + # Invalid expressions: + raises(SyntaxError, lambda: chain("(,")) + raises(SyntaxError, lambda: chain("()")) + raises(SyntaxError, lambda: chain("a (* b")) + + +def test_parser_mathematica_exp_alt(): + parser = MathematicaParser() + + convert_chain2 = lambda expr: parser._from_fullformlist_to_fullformsympy(parser._from_fullform_to_fullformlist(expr)) + convert_chain3 = lambda expr: parser._from_fullformsympy_to_sympy(convert_chain2(expr)) + + Sin, Times, Plus, Power = symbols("Sin Times Plus Power", cls=Function) + + full_form1 = "Sin[Times[x, y]]" + full_form2 = "Plus[Times[x, y], z]" + full_form3 = "Sin[Times[x, Plus[y, z], Power[w, n]]]]" + full_form4 = "Rational[Rational[x, y], z]" + + assert parser._from_fullform_to_fullformlist(full_form1) == ["Sin", ["Times", "x", "y"]] + assert parser._from_fullform_to_fullformlist(full_form2) == ["Plus", ["Times", "x", "y"], "z"] + assert parser._from_fullform_to_fullformlist(full_form3) == ["Sin", ["Times", "x", ["Plus", "y", "z"], ["Power", "w", "n"]]] + assert parser._from_fullform_to_fullformlist(full_form4) == ["Rational", ["Rational", "x", "y"], "z"] + + assert convert_chain2(full_form1) == Sin(Times(x, y)) + assert convert_chain2(full_form2) == Plus(Times(x, y), z) + assert convert_chain2(full_form3) == Sin(Times(x, Plus(y, z), Power(w, n))) + + assert convert_chain3(full_form1) == sin(x*y) + assert convert_chain3(full_form2) == x*y + z + assert convert_chain3(full_form3) == sin(x*(y + z)*w**n) diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_maxima.py b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_maxima.py new file mode 100644 index 0000000000000000000000000000000000000000..c0bc1db8f1385ed52e8c677a1bcc759f5118d01e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_maxima.py @@ -0,0 +1,50 @@ +from sympy.parsing.maxima import parse_maxima +from sympy.core.numbers import (E, Rational, oo) +from sympy.core.symbol import Symbol +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.abc import x + +n = Symbol('n', integer=True) + + +def test_parser(): + assert Abs(parse_maxima('float(1/3)') - 0.333333333) < 10**(-5) + assert parse_maxima('13^26') == 91733330193268616658399616009 + assert parse_maxima('sin(%pi/2) + cos(%pi/3)') == Rational(3, 2) + assert parse_maxima('log(%e)') == 1 + + +def test_injection(): + parse_maxima('c: x+1', globals=globals()) + # c created by parse_maxima + assert c == x + 1 # noqa:F821 + + parse_maxima('g: sqrt(81)', globals=globals()) + # g created by parse_maxima + assert g == 9 # noqa:F821 + + +def test_maxima_functions(): + assert parse_maxima('expand( (x+1)^2)') == x**2 + 2*x + 1 + assert parse_maxima('factor( x**2 + 2*x + 1)') == (x + 1)**2 + assert parse_maxima('2*cos(x)^2 + sin(x)^2') == 2*cos(x)**2 + sin(x)**2 + assert parse_maxima('trigexpand(sin(2*x)+cos(2*x))') == \ + -1 + 2*cos(x)**2 + 2*cos(x)*sin(x) + assert parse_maxima('solve(x^2-4,x)') == [-2, 2] + assert parse_maxima('limit((1+1/x)^x,x,inf)') == E + assert parse_maxima('limit(sqrt(-x)/x,x,0,minus)') is -oo + assert parse_maxima('diff(x^x, x)') == x**x*(1 + log(x)) + assert parse_maxima('sum(k, k, 1, n)', name_dict={ + "n": Symbol('n', integer=True), + "k": Symbol('k', integer=True) + }) == (n**2 + n)/2 + assert parse_maxima('product(k, k, 1, n)', name_dict={ + "n": Symbol('n', integer=True), + "k": Symbol('k', integer=True) + }) == factorial(n) + assert parse_maxima('ratsimp((x^2-1)/(x+1))') == x - 1 + assert Abs( parse_maxima( + 'float(sec(%pi/3) + csc(%pi/3))') - 3.154700538379252) < 10**(-5) diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_sym_expr.py b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_sym_expr.py new file mode 100644 index 0000000000000000000000000000000000000000..99912805db381b96e7f41a348fe6f90d71adf781 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_sym_expr.py @@ -0,0 +1,209 @@ +from sympy.parsing.sym_expr import SymPyExpression +from sympy.testing.pytest import raises +from sympy.external import import_module + +lfortran = import_module('lfortran') +cin = import_module('clang.cindex', import_kwargs = {'fromlist': ['cindex']}) + +if lfortran and cin: + from sympy.codegen.ast import (Variable, IntBaseType, FloatBaseType, String, + Declaration, FloatType) + from sympy.core import Integer, Float + from sympy.core.symbol import Symbol + + expr1 = SymPyExpression() + src = """\ + integer :: a, b, c, d + real :: p, q, r, s + """ + + def test_c_parse(): + src1 = """\ + int a, b = 4; + float c, d = 2.4; + """ + expr1.convert_to_expr(src1, 'c') + ls = expr1.return_expr() + + assert ls[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ) + assert ls[1] == Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(4) + ) + ) + assert ls[2] == Declaration( + Variable( + Symbol('c'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ) + assert ls[3] == Declaration( + Variable( + Symbol('d'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.3999999999999999', precision=53) + ) + ) + + + def test_fortran_parse(): + expr = SymPyExpression(src, 'f') + ls = expr.return_expr() + + assert ls[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ) + assert ls[1] == Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ) + assert ls[2] == Declaration( + Variable( + Symbol('c'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ) + assert ls[3] == Declaration( + Variable( + Symbol('d'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ) + assert ls[4] == Declaration( + Variable( + Symbol('p'), + type=FloatBaseType(String('real')), + value=Float('0.0', precision=53) + ) + ) + assert ls[5] == Declaration( + Variable( + Symbol('q'), + type=FloatBaseType(String('real')), + value=Float('0.0', precision=53) + ) + ) + assert ls[6] == Declaration( + Variable( + Symbol('r'), + type=FloatBaseType(String('real')), + value=Float('0.0', precision=53) + ) + ) + assert ls[7] == Declaration( + Variable( + Symbol('s'), + type=FloatBaseType(String('real')), + value=Float('0.0', precision=53) + ) + ) + + + def test_convert_py(): + src1 = ( + src + + """\ + a = b + c + s = p * q / r + """ + ) + expr1.convert_to_expr(src1, 'f') + exp_py = expr1.convert_to_python() + assert exp_py == [ + 'a = 0', + 'b = 0', + 'c = 0', + 'd = 0', + 'p = 0.0', + 'q = 0.0', + 'r = 0.0', + 's = 0.0', + 'a = b + c', + 's = p*q/r' + ] + + + def test_convert_fort(): + src1 = ( + src + + """\ + a = b + c + s = p * q / r + """ + ) + expr1.convert_to_expr(src1, 'f') + exp_fort = expr1.convert_to_fortran() + assert exp_fort == [ + ' integer*4 a', + ' integer*4 b', + ' integer*4 c', + ' integer*4 d', + ' real*8 p', + ' real*8 q', + ' real*8 r', + ' real*8 s', + ' a = b + c', + ' s = p*q/r' + ] + + + def test_convert_c(): + src1 = ( + src + + """\ + a = b + c + s = p * q / r + """ + ) + expr1.convert_to_expr(src1, 'f') + exp_c = expr1.convert_to_c() + assert exp_c == [ + 'int a = 0', + 'int b = 0', + 'int c = 0', + 'int d = 0', + 'double p = 0.0', + 'double q = 0.0', + 'double r = 0.0', + 'double s = 0.0', + 'a = b + c;', + 's = p*q/r;' + ] + + + def test_exceptions(): + src = 'int a;' + raises(ValueError, lambda: SymPyExpression(src)) + raises(ValueError, lambda: SymPyExpression(mode = 'c')) + raises(NotImplementedError, lambda: SymPyExpression(src, mode = 'd')) + +elif not lfortran and not cin: + def test_raise(): + raises(ImportError, lambda: SymPyExpression('int a;', 'c')) + raises(ImportError, lambda: SymPyExpression('integer :: a', 'f')) diff --git a/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_sympy_parser.py b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_sympy_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..43ecccbe262ffb4093248d891aa7423c8f62c628 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/parsing/tests/test_sympy_parser.py @@ -0,0 +1,371 @@ +# -*- coding: utf-8 -*- + + +import builtins +import types + +from sympy.assumptions import Q +from sympy.core import Symbol, Function, Float, Rational, Integer, I, Mul, Pow, Eq, Lt, Le, Gt, Ge, Ne +from sympy.functions import exp, factorial, factorial2, sin, Min, Max +from sympy.logic import And +from sympy.series import Limit +from sympy.testing.pytest import raises + +from sympy.parsing.sympy_parser import ( + parse_expr, standard_transformations, rationalize, TokenError, + split_symbols, implicit_multiplication, convert_equals_signs, + convert_xor, function_exponentiation, lambda_notation, auto_symbol, + repeated_decimals, implicit_multiplication_application, + auto_number, factorial_notation, implicit_application, + _transformation, T + ) + + +def test_sympy_parser(): + x = Symbol('x') + inputs = { + '2*x': 2 * x, + '3.00': Float(3), + '22/7': Rational(22, 7), + '2+3j': 2 + 3*I, + 'exp(x)': exp(x), + 'x!': factorial(x), + 'x!!': factorial2(x), + '(x + 1)! - 1': factorial(x + 1) - 1, + '3.[3]': Rational(10, 3), + '.0[3]': Rational(1, 30), + '3.2[3]': Rational(97, 30), + '1.3[12]': Rational(433, 330), + '1 + 3.[3]': Rational(13, 3), + '1 + .0[3]': Rational(31, 30), + '1 + 3.2[3]': Rational(127, 30), + '.[0011]': Rational(1, 909), + '0.1[00102] + 1': Rational(366697, 333330), + '1.[0191]': Rational(10190, 9999), + '10!': 3628800, + '-(2)': -Integer(2), + '[-1, -2, 3]': [Integer(-1), Integer(-2), Integer(3)], + 'Symbol("x").free_symbols': x.free_symbols, + "S('S(3).n(n=3)')": Float(3, 3), + 'factorint(12, visual=True)': Mul( + Pow(2, 2, evaluate=False), + Pow(3, 1, evaluate=False), + evaluate=False), + 'Limit(sin(x), x, 0, dir="-")': Limit(sin(x), x, 0, dir='-'), + 'Q.even(x)': Q.even(x), + + + } + for text, result in inputs.items(): + assert parse_expr(text) == result + + raises(TypeError, lambda: + parse_expr('x', standard_transformations)) + raises(TypeError, lambda: + parse_expr('x', transformations=lambda x,y: 1)) + raises(TypeError, lambda: + parse_expr('x', transformations=(lambda x,y: 1,))) + raises(TypeError, lambda: parse_expr('x', transformations=((),))) + raises(TypeError, lambda: parse_expr('x', {}, [], [])) + raises(TypeError, lambda: parse_expr('x', [], [], {})) + raises(TypeError, lambda: parse_expr('x', [], [], {})) + + +def test_rationalize(): + inputs = { + '0.123': Rational(123, 1000) + } + transformations = standard_transformations + (rationalize,) + for text, result in inputs.items(): + assert parse_expr(text, transformations=transformations) == result + + +def test_factorial_fail(): + inputs = ['x!!!', 'x!!!!', '(!)'] + + + for text in inputs: + try: + parse_expr(text) + assert False + except TokenError: + assert True + + +def test_repeated_fail(): + inputs = ['1[1]', '.1e1[1]', '0x1[1]', '1.1j[1]', '1.1[1 + 1]', + '0.1[[1]]', '0x1.1[1]'] + + + # All are valid Python, so only raise TypeError for invalid indexing + for text in inputs: + raises(TypeError, lambda: parse_expr(text)) + + + inputs = ['0.1[', '0.1[1', '0.1[]'] + for text in inputs: + raises((TokenError, SyntaxError), lambda: parse_expr(text)) + + +def test_repeated_dot_only(): + assert parse_expr('.[1]') == Rational(1, 9) + assert parse_expr('1 + .[1]') == Rational(10, 9) + + +def test_local_dict(): + local_dict = { + 'my_function': lambda x: x + 2 + } + inputs = { + 'my_function(2)': Integer(4) + } + for text, result in inputs.items(): + assert parse_expr(text, local_dict=local_dict) == result + + +def test_local_dict_split_implmult(): + t = standard_transformations + (split_symbols, implicit_multiplication,) + w = Symbol('w', real=True) + y = Symbol('y') + assert parse_expr('yx', local_dict={'x':w}, transformations=t) == y*w + + +def test_local_dict_symbol_to_fcn(): + x = Symbol('x') + d = {'foo': Function('bar')} + assert parse_expr('foo(x)', local_dict=d) == d['foo'](x) + d = {'foo': Symbol('baz')} + raises(TypeError, lambda: parse_expr('foo(x)', local_dict=d)) + + +def test_global_dict(): + global_dict = { + 'Symbol': Symbol + } + inputs = { + 'Q & S': And(Symbol('Q'), Symbol('S')) + } + for text, result in inputs.items(): + assert parse_expr(text, global_dict=global_dict) == result + + +def test_no_globals(): + + # Replicate creating the default global_dict: + default_globals = {} + exec('from sympy import *', default_globals) + builtins_dict = vars(builtins) + for name, obj in builtins_dict.items(): + if isinstance(obj, types.BuiltinFunctionType): + default_globals[name] = obj + default_globals['max'] = Max + default_globals['min'] = Min + + # Need to include Symbol or parse_expr will not work: + default_globals.pop('Symbol') + global_dict = {'Symbol':Symbol} + + for name in default_globals: + obj = parse_expr(name, global_dict=global_dict) + assert obj == Symbol(name) + + +def test_issue_2515(): + raises(TokenError, lambda: parse_expr('(()')) + raises(TokenError, lambda: parse_expr('"""')) + + +def test_issue_7663(): + x = Symbol('x') + e = '2*(x+1)' + assert parse_expr(e, evaluate=False) == parse_expr(e, evaluate=False) + assert parse_expr(e, evaluate=False).equals(2*(x+1)) + +def test_recursive_evaluate_false_10560(): + inputs = { + '4*-3' : '4*-3', + '-4*3' : '(-4)*3', + "-2*x*y": '(-2)*x*y', + "x*-4*x": "x*(-4)*x" + } + for text, result in inputs.items(): + assert parse_expr(text, evaluate=False) == parse_expr(result, evaluate=False) + + +def test_function_evaluate_false(): + inputs = [ + 'Abs(0)', 'im(0)', 're(0)', 'sign(0)', 'arg(0)', 'conjugate(0)', + 'acos(0)', 'acot(0)', 'acsc(0)', 'asec(0)', 'asin(0)', 'atan(0)', + 'acosh(0)', 'acoth(0)', 'acsch(0)', 'asech(0)', 'asinh(0)', 'atanh(0)', + 'cos(0)', 'cot(0)', 'csc(0)', 'sec(0)', 'sin(0)', 'tan(0)', + 'cosh(0)', 'coth(0)', 'csch(0)', 'sech(0)', 'sinh(0)', 'tanh(0)', + 'exp(0)', 'log(0)', 'sqrt(0)', + ] + for case in inputs: + expr = parse_expr(case, evaluate=False) + assert case == str(expr) != str(expr.doit()) + assert str(parse_expr('ln(0)', evaluate=False)) == 'log(0)' + assert str(parse_expr('cbrt(0)', evaluate=False)) == '0**(1/3)' + + +def test_issue_10773(): + inputs = { + '-10/5': '(-10)/5', + '-10/-5' : '(-10)/(-5)', + } + for text, result in inputs.items(): + assert parse_expr(text, evaluate=False) == parse_expr(result, evaluate=False) + + +def test_split_symbols(): + transformations = standard_transformations + \ + (split_symbols, implicit_multiplication,) + x = Symbol('x') + y = Symbol('y') + xy = Symbol('xy') + + + assert parse_expr("xy") == xy + assert parse_expr("xy", transformations=transformations) == x*y + + +def test_split_symbols_function(): + transformations = standard_transformations + \ + (split_symbols, implicit_multiplication,) + x = Symbol('x') + y = Symbol('y') + a = Symbol('a') + f = Function('f') + + + assert parse_expr("ay(x+1)", transformations=transformations) == a*y*(x+1) + assert parse_expr("af(x+1)", transformations=transformations, + local_dict={'f':f}) == a*f(x+1) + + +def test_functional_exponent(): + t = standard_transformations + (convert_xor, function_exponentiation) + x = Symbol('x') + y = Symbol('y') + a = Symbol('a') + yfcn = Function('y') + assert parse_expr("sin^2(x)", transformations=t) == (sin(x))**2 + assert parse_expr("sin^y(x)", transformations=t) == (sin(x))**y + assert parse_expr("exp^y(x)", transformations=t) == (exp(x))**y + assert parse_expr("E^y(x)", transformations=t) == exp(yfcn(x)) + assert parse_expr("a^y(x)", transformations=t) == a**(yfcn(x)) + + +def test_match_parentheses_implicit_multiplication(): + transformations = standard_transformations + \ + (implicit_multiplication,) + raises(TokenError, lambda: parse_expr('(1,2),(3,4]',transformations=transformations)) + + +def test_convert_equals_signs(): + transformations = standard_transformations + \ + (convert_equals_signs, ) + x = Symbol('x') + y = Symbol('y') + assert parse_expr("1*2=x", transformations=transformations) == Eq(2, x) + assert parse_expr("y = x", transformations=transformations) == Eq(y, x) + assert parse_expr("(2*y = x) = False", + transformations=transformations) == Eq(Eq(2*y, x), False) + + +def test_parse_function_issue_3539(): + x = Symbol('x') + f = Function('f') + assert parse_expr('f(x)') == f(x) + +def test_issue_24288(): + assert parse_expr("1 < 2", evaluate=False) == Lt(1, 2, evaluate=False) + assert parse_expr("1 <= 2", evaluate=False) == Le(1, 2, evaluate=False) + assert parse_expr("1 > 2", evaluate=False) == Gt(1, 2, evaluate=False) + assert parse_expr("1 >= 2", evaluate=False) == Ge(1, 2, evaluate=False) + assert parse_expr("1 != 2", evaluate=False) == Ne(1, 2, evaluate=False) + assert parse_expr("1 == 2", evaluate=False) == Eq(1, 2, evaluate=False) + assert parse_expr("1 < 2 < 3", evaluate=False) == And(Lt(1, 2, evaluate=False), Lt(2, 3, evaluate=False), evaluate=False) + assert parse_expr("1 <= 2 <= 3", evaluate=False) == And(Le(1, 2, evaluate=False), Le(2, 3, evaluate=False), evaluate=False) + assert parse_expr("1 < 2 <= 3 < 4", evaluate=False) == \ + And(Lt(1, 2, evaluate=False), Le(2, 3, evaluate=False), Lt(3, 4, evaluate=False), evaluate=False) + # Valid Python relational operators that SymPy does not decide how to handle them yet + raises(ValueError, lambda: parse_expr("1 in 2", evaluate=False)) + raises(ValueError, lambda: parse_expr("1 is 2", evaluate=False)) + raises(ValueError, lambda: parse_expr("1 not in 2", evaluate=False)) + raises(ValueError, lambda: parse_expr("1 is not 2", evaluate=False)) + +def test_split_symbols_numeric(): + transformations = ( + standard_transformations + + (implicit_multiplication_application,)) + + n = Symbol('n') + expr1 = parse_expr('2**n * 3**n') + expr2 = parse_expr('2**n3**n', transformations=transformations) + assert expr1 == expr2 == 2**n*3**n + + expr1 = parse_expr('n12n34', transformations=transformations) + assert expr1 == n*12*n*34 + + +def test_unicode_names(): + assert parse_expr('α') == Symbol('α') + + +def test_python3_features(): + assert parse_expr("123_456") == 123456 + assert parse_expr("1.2[3_4]") == parse_expr("1.2[34]") == Rational(611, 495) + assert parse_expr("1.2[012_012]") == parse_expr("1.2[012012]") == Rational(400, 333) + assert parse_expr('.[3_4]') == parse_expr('.[34]') == Rational(34, 99) + assert parse_expr('.1[3_4]') == parse_expr('.1[34]') == Rational(133, 990) + assert parse_expr('123_123.123_123[3_4]') == parse_expr('123123.123123[34]') == Rational(12189189189211, 99000000) + + +def test_issue_19501(): + x = Symbol('x') + eq = parse_expr('E**x(1+x)', local_dict={'x': x}, transformations=( + standard_transformations + + (implicit_multiplication_application,))) + assert eq.free_symbols == {x} + + +def test_parsing_definitions(): + from sympy.abc import x + assert len(_transformation) == 12 # if this changes, extend below + assert _transformation[0] == lambda_notation + assert _transformation[1] == auto_symbol + assert _transformation[2] == repeated_decimals + assert _transformation[3] == auto_number + assert _transformation[4] == factorial_notation + assert _transformation[5] == implicit_multiplication_application + assert _transformation[6] == convert_xor + assert _transformation[7] == implicit_application + assert _transformation[8] == implicit_multiplication + assert _transformation[9] == convert_equals_signs + assert _transformation[10] == function_exponentiation + assert _transformation[11] == rationalize + assert T[:5] == T[0,1,2,3,4] == standard_transformations + t = _transformation + assert T[-1, 0] == (t[len(t) - 1], t[0]) + assert T[:5, 8] == standard_transformations + (t[8],) + assert parse_expr('0.3x^2', transformations='all') == 3*x**2/10 + assert parse_expr('sin 3x', transformations='implicit') == sin(3*x) + + +def test_builtins(): + cases = [ + ('abs(x)', 'Abs(x)'), + ('max(x, y)', 'Max(x, y)'), + ('min(x, y)', 'Min(x, y)'), + ('pow(x, y)', 'Pow(x, y)'), + ] + for built_in_func_call, sympy_func_call in cases: + assert parse_expr(built_in_func_call) == parse_expr(sympy_func_call) + assert str(parse_expr('pow(38, -1, 97)')) == '23' + + +def test_issue_22822(): + raises(ValueError, lambda: parse_expr('x', {'': 1})) + data = {'some_parameter': None} + assert parse_expr('some_parameter is None', data) is True diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/__init__.py b/.venv/lib/python3.13/site-packages/sympy/physics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..60989896ae8b3f69efc7d2350add8f6f19d85669 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/__init__.py @@ -0,0 +1,12 @@ +""" +A module that helps solving problems in physics. +""" + +from . import units +from .matrices import mgamma, msigma, minkowski_tensor, mdft + +__all__ = [ + 'units', + + 'mgamma', 'msigma', 'minkowski_tensor', 'mdft', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/hep/__init__.py b/.venv/lib/python3.13/site-packages/sympy/physics/hep/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/hep/gamma_matrices.py b/.venv/lib/python3.13/site-packages/sympy/physics/hep/gamma_matrices.py new file mode 100644 index 0000000000000000000000000000000000000000..40c3d0754438902f304d01c2df354dd09f9ea257 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/hep/gamma_matrices.py @@ -0,0 +1,716 @@ +""" + Module to handle gamma matrices expressed as tensor objects. + + Examples + ======== + + >>> from sympy.physics.hep.gamma_matrices import GammaMatrix as G, LorentzIndex + >>> from sympy.tensor.tensor import tensor_indices + >>> i = tensor_indices('i', LorentzIndex) + >>> G(i) + GammaMatrix(i) + + Note that there is already an instance of GammaMatrixHead in four dimensions: + GammaMatrix, which is simply declare as + + >>> from sympy.physics.hep.gamma_matrices import GammaMatrix + >>> from sympy.tensor.tensor import tensor_indices + >>> i = tensor_indices('i', LorentzIndex) + >>> GammaMatrix(i) + GammaMatrix(i) + + To access the metric tensor + + >>> LorentzIndex.metric + metric(LorentzIndex,LorentzIndex) + +""" +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.matrices.dense import eye +from sympy.matrices.expressions.trace import trace +from sympy.tensor.tensor import TensorIndexType, TensorIndex,\ + TensMul, TensAdd, tensor_mul, Tensor, TensorHead, TensorSymmetry + + +# DiracSpinorIndex = TensorIndexType('DiracSpinorIndex', dim=4, dummy_name="S") + + +LorentzIndex = TensorIndexType('LorentzIndex', dim=4, dummy_name="L") + + +GammaMatrix = TensorHead("GammaMatrix", [LorentzIndex], + TensorSymmetry.no_symmetry(1), comm=None) + + +def extract_type_tens(expression, component): + """ + Extract from a ``TensExpr`` all tensors with `component`. + + Returns two tensor expressions: + + * the first contains all ``Tensor`` of having `component`. + * the second contains all remaining. + + + """ + if isinstance(expression, Tensor): + sp = [expression] + elif isinstance(expression, TensMul): + sp = expression.args + else: + raise ValueError('wrong type') + + # Collect all gamma matrices of the same dimension + new_expr = S.One + residual_expr = S.One + for i in sp: + if isinstance(i, Tensor) and i.component == component: + new_expr *= i + else: + residual_expr *= i + return new_expr, residual_expr + + +def simplify_gamma_expression(expression): + extracted_expr, residual_expr = extract_type_tens(expression, GammaMatrix) + res_expr = _simplify_single_line(extracted_expr) + return res_expr * residual_expr + + +def simplify_gpgp(ex, sort=True): + """ + simplify products ``G(i)*p(-i)*G(j)*p(-j) -> p(i)*p(-i)`` + + Examples + ======== + + >>> from sympy.physics.hep.gamma_matrices import GammaMatrix as G, \ + LorentzIndex, simplify_gpgp + >>> from sympy.tensor.tensor import tensor_indices, tensor_heads + >>> p, q = tensor_heads('p, q', [LorentzIndex]) + >>> i0,i1,i2,i3,i4,i5 = tensor_indices('i0:6', LorentzIndex) + >>> ps = p(i0)*G(-i0) + >>> qs = q(i0)*G(-i0) + >>> simplify_gpgp(ps*qs*qs) + GammaMatrix(-L_0)*p(L_0)*q(L_1)*q(-L_1) + """ + def _simplify_gpgp(ex): + components = ex.components + a = [] + comp_map = [] + for i, comp in enumerate(components): + comp_map.extend([i]*comp.rank) + dum = [(i[0], i[1], comp_map[i[0]], comp_map[i[1]]) for i in ex.dum] + for i in range(len(components)): + if components[i] != GammaMatrix: + continue + for dx in dum: + if dx[2] == i: + p_pos1 = dx[3] + elif dx[3] == i: + p_pos1 = dx[2] + else: + continue + comp1 = components[p_pos1] + if comp1.comm == 0 and comp1.rank == 1: + a.append((i, p_pos1)) + if not a: + return ex + elim = set() + tv = [] + hit = True + coeff = S.One + ta = None + while hit: + hit = False + for i, ai in enumerate(a[:-1]): + if ai[0] in elim: + continue + if ai[0] != a[i + 1][0] - 1: + continue + if components[ai[1]] != components[a[i + 1][1]]: + continue + elim.add(ai[0]) + elim.add(ai[1]) + elim.add(a[i + 1][0]) + elim.add(a[i + 1][1]) + if not ta: + ta = ex.split() + mu = TensorIndex('mu', LorentzIndex) + hit = True + if i == 0: + coeff = ex.coeff + tx = components[ai[1]](mu)*components[ai[1]](-mu) + if len(a) == 2: + tx *= 4 # eye(4) + tv.append(tx) + break + + if tv: + a = [x for j, x in enumerate(ta) if j not in elim] + a.extend(tv) + t = tensor_mul(*a)*coeff + # t = t.replace(lambda x: x.is_Matrix, lambda x: 1) + return t + else: + return ex + + if sort: + ex = ex.sorted_components() + # this would be better off with pattern matching + while 1: + t = _simplify_gpgp(ex) + if t != ex: + ex = t + else: + return t + + +def gamma_trace(t): + """ + trace of a single line of gamma matrices + + Examples + ======== + + >>> from sympy.physics.hep.gamma_matrices import GammaMatrix as G, \ + gamma_trace, LorentzIndex + >>> from sympy.tensor.tensor import tensor_indices, tensor_heads + >>> p, q = tensor_heads('p, q', [LorentzIndex]) + >>> i0,i1,i2,i3,i4,i5 = tensor_indices('i0:6', LorentzIndex) + >>> ps = p(i0)*G(-i0) + >>> qs = q(i0)*G(-i0) + >>> gamma_trace(G(i0)*G(i1)) + 4*metric(i0, i1) + >>> gamma_trace(ps*ps) - 4*p(i0)*p(-i0) + 0 + >>> gamma_trace(ps*qs + ps*ps) - 4*p(i0)*p(-i0) - 4*p(i0)*q(-i0) + 0 + + """ + if isinstance(t, TensAdd): + res = TensAdd(*[gamma_trace(x) for x in t.args]) + return res + t = _simplify_single_line(t) + res = _trace_single_line(t) + return res + + +def _simplify_single_line(expression): + """ + Simplify single-line product of gamma matrices. + + Examples + ======== + + >>> from sympy.physics.hep.gamma_matrices import GammaMatrix as G, \ + LorentzIndex, _simplify_single_line + >>> from sympy.tensor.tensor import tensor_indices, TensorHead + >>> p = TensorHead('p', [LorentzIndex]) + >>> i0,i1 = tensor_indices('i0:2', LorentzIndex) + >>> _simplify_single_line(G(i0)*G(i1)*p(-i1)*G(-i0)) + 2*G(i0)*p(-i0) + 0 + + """ + t1, t2 = extract_type_tens(expression, GammaMatrix) + if t1 != 1: + t1 = kahane_simplify(t1) + res = t1*t2 + return res + + +def _trace_single_line(t): + """ + Evaluate the trace of a single gamma matrix line inside a ``TensExpr``. + + Notes + ===== + + If there are ``DiracSpinorIndex.auto_left`` and ``DiracSpinorIndex.auto_right`` + indices trace over them; otherwise traces are not implied (explain) + + + Examples + ======== + + >>> from sympy.physics.hep.gamma_matrices import GammaMatrix as G, \ + LorentzIndex, _trace_single_line + >>> from sympy.tensor.tensor import tensor_indices, TensorHead + >>> p = TensorHead('p', [LorentzIndex]) + >>> i0,i1,i2,i3,i4,i5 = tensor_indices('i0:6', LorentzIndex) + >>> _trace_single_line(G(i0)*G(i1)) + 4*metric(i0, i1) + >>> _trace_single_line(G(i0)*p(-i0)*G(i1)*p(-i1)) - 4*p(i0)*p(-i0) + 0 + + """ + def _trace_single_line1(t): + t = t.sorted_components() + components = t.components + ncomps = len(components) + g = LorentzIndex.metric + # gamma matirices are in a[i:j] + hit = 0 + for i in range(ncomps): + if components[i] == GammaMatrix: + hit = 1 + break + + for j in range(i + hit, ncomps): + if components[j] != GammaMatrix: + break + else: + j = ncomps + numG = j - i + if numG == 0: + tcoeff = t.coeff + return t.nocoeff if tcoeff else t + if numG % 2 == 1: + return TensMul.from_data(S.Zero, [], [], []) + elif numG > 4: + # find the open matrix indices and connect them: + a = t.split() + ind1 = a[i].get_indices()[0] + ind2 = a[i + 1].get_indices()[0] + aa = a[:i] + a[i + 2:] + t1 = tensor_mul(*aa)*g(ind1, ind2) + t1 = t1.contract_metric(g) + args = [t1] + sign = 1 + for k in range(i + 2, j): + sign = -sign + ind2 = a[k].get_indices()[0] + aa = a[:i] + a[i + 1:k] + a[k + 1:] + t2 = sign*tensor_mul(*aa)*g(ind1, ind2) + t2 = t2.contract_metric(g) + t2 = simplify_gpgp(t2, False) + args.append(t2) + t3 = TensAdd(*args) + t3 = _trace_single_line(t3) + return t3 + else: + a = t.split() + t1 = _gamma_trace1(*a[i:j]) + a2 = a[:i] + a[j:] + t2 = tensor_mul(*a2) + t3 = t1*t2 + if not t3: + return t3 + t3 = t3.contract_metric(g) + return t3 + + t = t.expand() + if isinstance(t, TensAdd): + a = [_trace_single_line1(x)*x.coeff for x in t.args] + return TensAdd(*a) + elif isinstance(t, (Tensor, TensMul)): + r = t.coeff*_trace_single_line1(t) + return r + else: + return trace(t) + + +def _gamma_trace1(*a): + gctr = 4 # FIXME specific for d=4 + g = LorentzIndex.metric + if not a: + return gctr + n = len(a) + if n%2 == 1: + #return TensMul.from_data(S.Zero, [], [], []) + return S.Zero + if n == 2: + ind0 = a[0].get_indices()[0] + ind1 = a[1].get_indices()[0] + return gctr*g(ind0, ind1) + if n == 4: + ind0 = a[0].get_indices()[0] + ind1 = a[1].get_indices()[0] + ind2 = a[2].get_indices()[0] + ind3 = a[3].get_indices()[0] + + return gctr*(g(ind0, ind1)*g(ind2, ind3) - \ + g(ind0, ind2)*g(ind1, ind3) + g(ind0, ind3)*g(ind1, ind2)) + + +def kahane_simplify(expression): + r""" + This function cancels contracted elements in a product of four + dimensional gamma matrices, resulting in an expression equal to the given + one, without the contracted gamma matrices. + + Parameters + ========== + + `expression` the tensor expression containing the gamma matrices to simplify. + + Notes + ===== + + If spinor indices are given, the matrices must be given in + the order given in the product. + + Algorithm + ========= + + The idea behind the algorithm is to use some well-known identities, + i.e., for contractions enclosing an even number of `\gamma` matrices + + `\gamma^\mu \gamma_{a_1} \cdots \gamma_{a_{2N}} \gamma_\mu = 2 (\gamma_{a_{2N}} \gamma_{a_1} \cdots \gamma_{a_{2N-1}} + \gamma_{a_{2N-1}} \cdots \gamma_{a_1} \gamma_{a_{2N}} )` + + for an odd number of `\gamma` matrices + + `\gamma^\mu \gamma_{a_1} \cdots \gamma_{a_{2N+1}} \gamma_\mu = -2 \gamma_{a_{2N+1}} \gamma_{a_{2N}} \cdots \gamma_{a_{1}}` + + Instead of repeatedly applying these identities to cancel out all contracted indices, + it is possible to recognize the links that would result from such an operation, + the problem is thus reduced to a simple rearrangement of free gamma matrices. + + Examples + ======== + + When using, always remember that the original expression coefficient + has to be handled separately + + >>> from sympy.physics.hep.gamma_matrices import GammaMatrix as G, LorentzIndex + >>> from sympy.physics.hep.gamma_matrices import kahane_simplify + >>> from sympy.tensor.tensor import tensor_indices + >>> i0, i1, i2 = tensor_indices('i0:3', LorentzIndex) + >>> ta = G(i0)*G(-i0) + >>> kahane_simplify(ta) + Matrix([ + [4, 0, 0, 0], + [0, 4, 0, 0], + [0, 0, 4, 0], + [0, 0, 0, 4]]) + >>> tb = G(i0)*G(i1)*G(-i0) + >>> kahane_simplify(tb) + -2*GammaMatrix(i1) + >>> t = G(i0)*G(-i0) + >>> kahane_simplify(t) + Matrix([ + [4, 0, 0, 0], + [0, 4, 0, 0], + [0, 0, 4, 0], + [0, 0, 0, 4]]) + >>> t = G(i0)*G(-i0) + >>> kahane_simplify(t) + Matrix([ + [4, 0, 0, 0], + [0, 4, 0, 0], + [0, 0, 4, 0], + [0, 0, 0, 4]]) + + If there are no contractions, the same expression is returned + + >>> tc = G(i0)*G(i1) + >>> kahane_simplify(tc) + GammaMatrix(i0)*GammaMatrix(i1) + + References + ========== + + [1] Algorithm for Reducing Contracted Products of gamma Matrices, + Joseph Kahane, Journal of Mathematical Physics, Vol. 9, No. 10, October 1968. + """ + + if isinstance(expression, Mul): + return expression + if isinstance(expression, TensAdd): + return TensAdd(*[kahane_simplify(arg) for arg in expression.args]) + + if isinstance(expression, Tensor): + return expression + + assert isinstance(expression, TensMul) + + gammas = expression.args + + for gamma in gammas: + assert gamma.component == GammaMatrix + + free = expression.free + # spinor_free = [_ for _ in expression.free_in_args if _[1] != 0] + + # if len(spinor_free) == 2: + # spinor_free.sort(key=lambda x: x[2]) + # assert spinor_free[0][1] == 1 and spinor_free[-1][1] == 2 + # assert spinor_free[0][2] == 0 + # elif spinor_free: + # raise ValueError('spinor indices do not match') + + dum = [] + for dum_pair in expression.dum: + if expression.index_types[dum_pair[0]] == LorentzIndex: + dum.append((dum_pair[0], dum_pair[1])) + + dum = sorted(dum) + + if len(dum) == 0: # or GammaMatrixHead: + # no contractions in `expression`, just return it. + return expression + + # find the `first_dum_pos`, i.e. the position of the first contracted + # gamma matrix, Kahane's algorithm as described in his paper requires the + # gamma matrix expression to start with a contracted gamma matrix, this is + # a workaround which ignores possible initial free indices, and re-adds + # them later. + + first_dum_pos = min(map(min, dum)) + + # for p1, p2, a1, a2 in expression.dum_in_args: + # if p1 != 0 or p2 != 0: + # # only Lorentz indices, skip Dirac indices: + # continue + # first_dum_pos = min(p1, p2) + # break + + total_number = len(free) + len(dum)*2 + number_of_contractions = len(dum) + + free_pos = [None]*total_number + for i in free: + free_pos[i[1]] = i[0] + + # `index_is_free` is a list of booleans, to identify index position + # and whether that index is free or dummy. + index_is_free = [False]*total_number + + for i, indx in enumerate(free): + index_is_free[indx[1]] = True + + # `links` is a dictionary containing the graph described in Kahane's paper, + # to every key correspond one or two values, representing the linked indices. + # All values in `links` are integers, negative numbers are used in the case + # where it is necessary to insert gamma matrices between free indices, in + # order to make Kahane's algorithm work (see paper). + links = {i: [] for i in range(first_dum_pos, total_number)} + + # `cum_sign` is a step variable to mark the sign of every index, see paper. + cum_sign = -1 + # `cum_sign_list` keeps storage for all `cum_sign` (every index). + cum_sign_list = [None]*total_number + block_free_count = 0 + + # multiply `resulting_coeff` by the coefficient parameter, the rest + # of the algorithm ignores a scalar coefficient. + resulting_coeff = S.One + + # initialize a list of lists of indices. The outer list will contain all + # additive tensor expressions, while the inner list will contain the + # free indices (rearranged according to the algorithm). + resulting_indices = [[]] + + # start to count the `connected_components`, which together with the number + # of contractions, determines a -1 or +1 factor to be multiplied. + connected_components = 1 + + # First loop: here we fill `cum_sign_list`, and draw the links + # among consecutive indices (they are stored in `links`). Links among + # non-consecutive indices will be drawn later. + for i, is_free in enumerate(index_is_free): + # if `expression` starts with free indices, they are ignored here; + # they are later added as they are to the beginning of all + # `resulting_indices` list of lists of indices. + if i < first_dum_pos: + continue + + if is_free: + block_free_count += 1 + # if previous index was free as well, draw an arch in `links`. + if block_free_count > 1: + links[i - 1].append(i) + links[i].append(i - 1) + else: + # Change the sign of the index (`cum_sign`) if the number of free + # indices preceding it is even. + cum_sign *= 1 if (block_free_count % 2) else -1 + if block_free_count == 0 and i != first_dum_pos: + # check if there are two consecutive dummy indices: + # in this case create virtual indices with negative position, + # these "virtual" indices represent the insertion of two + # gamma^0 matrices to separate consecutive dummy indices, as + # Kahane's algorithm requires dummy indices to be separated by + # free indices. The product of two gamma^0 matrices is unity, + # so the new expression being examined is the same as the + # original one. + if cum_sign == -1: + links[-1-i] = [-1-i+1] + links[-1-i+1] = [-1-i] + if (i - cum_sign) in links: + if i != first_dum_pos: + links[i].append(i - cum_sign) + if block_free_count != 0: + if i - cum_sign < len(index_is_free): + if index_is_free[i - cum_sign]: + links[i - cum_sign].append(i) + block_free_count = 0 + + cum_sign_list[i] = cum_sign + + # The previous loop has only created links between consecutive free indices, + # it is necessary to properly create links among dummy (contracted) indices, + # according to the rules described in Kahane's paper. There is only one exception + # to Kahane's rules: the negative indices, which handle the case of some + # consecutive free indices (Kahane's paper just describes dummy indices + # separated by free indices, hinting that free indices can be added without + # altering the expression result). + for i in dum: + # get the positions of the two contracted indices: + pos1 = i[0] + pos2 = i[1] + + # create Kahane's upper links, i.e. the upper arcs between dummy + # (i.e. contracted) indices: + links[pos1].append(pos2) + links[pos2].append(pos1) + + # create Kahane's lower links, this corresponds to the arcs below + # the line described in the paper: + + # first we move `pos1` and `pos2` according to the sign of the indices: + linkpos1 = pos1 + cum_sign_list[pos1] + linkpos2 = pos2 + cum_sign_list[pos2] + + # otherwise, perform some checks before creating the lower arcs: + + # make sure we are not exceeding the total number of indices: + if linkpos1 >= total_number: + continue + if linkpos2 >= total_number: + continue + + # make sure we are not below the first dummy index in `expression`: + if linkpos1 < first_dum_pos: + continue + if linkpos2 < first_dum_pos: + continue + + # check if the previous loop created "virtual" indices between dummy + # indices, in such a case relink `linkpos1` and `linkpos2`: + if (-1-linkpos1) in links: + linkpos1 = -1-linkpos1 + if (-1-linkpos2) in links: + linkpos2 = -1-linkpos2 + + # move only if not next to free index: + if linkpos1 >= 0 and not index_is_free[linkpos1]: + linkpos1 = pos1 + + if linkpos2 >=0 and not index_is_free[linkpos2]: + linkpos2 = pos2 + + # create the lower arcs: + if linkpos2 not in links[linkpos1]: + links[linkpos1].append(linkpos2) + if linkpos1 not in links[linkpos2]: + links[linkpos2].append(linkpos1) + + # This loop starts from the `first_dum_pos` index (first dummy index) + # walks through the graph deleting the visited indices from `links`, + # it adds a gamma matrix for every free index in encounters, while it + # completely ignores dummy indices and virtual indices. + pointer = first_dum_pos + previous_pointer = 0 + while True: + if pointer in links: + next_ones = links.pop(pointer) + else: + break + + if previous_pointer in next_ones: + next_ones.remove(previous_pointer) + + previous_pointer = pointer + + if next_ones: + pointer = next_ones[0] + else: + break + + if pointer == previous_pointer: + break + if pointer >=0 and free_pos[pointer] is not None: + for ri in resulting_indices: + ri.append(free_pos[pointer]) + + # The following loop removes the remaining connected components in `links`. + # If there are free indices inside a connected component, it gives a + # contribution to the resulting expression given by the factor + # `gamma_a gamma_b ... gamma_z + gamma_z ... gamma_b gamma_a`, in Kahanes's + # paper represented as {gamma_a, gamma_b, ... , gamma_z}, + # virtual indices are ignored. The variable `connected_components` is + # increased by one for every connected component this loop encounters. + + # If the connected component has virtual and dummy indices only + # (no free indices), it contributes to `resulting_indices` by a factor of two. + # The multiplication by two is a result of the + # factor {gamma^0, gamma^0} = 2 I, as it appears in Kahane's paper. + # Note: curly brackets are meant as in the paper, as a generalized + # multi-element anticommutator! + + while links: + connected_components += 1 + pointer = min(links.keys()) + previous_pointer = pointer + # the inner loop erases the visited indices from `links`, and it adds + # all free indices to `prepend_indices` list, virtual indices are + # ignored. + prepend_indices = [] + while True: + if pointer in links: + next_ones = links.pop(pointer) + else: + break + + if previous_pointer in next_ones: + if len(next_ones) > 1: + next_ones.remove(previous_pointer) + + previous_pointer = pointer + + if next_ones: + pointer = next_ones[0] + + if pointer >= first_dum_pos and free_pos[pointer] is not None: + prepend_indices.insert(0, free_pos[pointer]) + # if `prepend_indices` is void, it means there are no free indices + # in the loop (and it can be shown that there must be a virtual index), + # loops of virtual indices only contribute by a factor of two: + if len(prepend_indices) == 0: + resulting_coeff *= 2 + # otherwise, add the free indices in `prepend_indices` to + # the `resulting_indices`: + else: + expr1 = prepend_indices + expr2 = list(reversed(prepend_indices)) + resulting_indices = [expri + ri for ri in resulting_indices for expri in (expr1, expr2)] + + # sign correction, as described in Kahane's paper: + resulting_coeff *= -1 if (number_of_contractions - connected_components + 1) % 2 else 1 + # power of two factor, as described in Kahane's paper: + resulting_coeff *= 2**(number_of_contractions) + + # If `first_dum_pos` is not zero, it means that there are trailing free gamma + # matrices in front of `expression`, so multiply by them: + resulting_indices = [ free_pos[0:first_dum_pos] + ri for ri in resulting_indices ] + + resulting_expr = S.Zero + for i in resulting_indices: + temp_expr = S.One + for j in i: + temp_expr *= GammaMatrix(j) + resulting_expr += temp_expr + + t = resulting_coeff * resulting_expr + t1 = None + if isinstance(t, TensAdd): + t1 = t.args[0] + elif isinstance(t, TensMul): + t1 = t + if t1: + pass + else: + t = eye(4)*t + return t diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/hydrogen.py b/.venv/lib/python3.13/site-packages/sympy/physics/hydrogen.py new file mode 100644 index 0000000000000000000000000000000000000000..a3bac274c66a2cf97d4238d9e3951e39df820931 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/hydrogen.py @@ -0,0 +1,265 @@ +from sympy.core.numbers import Float +from sympy.core.singleton import S +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.special.polynomials import assoc_laguerre +from sympy.functions.special.spherical_harmonics import Ynm + + +def R_nl(n, l, r, Z=1): + """ + Returns the Hydrogen radial wavefunction R_{nl}. + + Parameters + ========== + + n : integer + Principal Quantum Number which is + an integer with possible values as 1, 2, 3, 4,... + l : integer + ``l`` is the Angular Momentum Quantum Number with + values ranging from 0 to ``n-1``. + r : + Radial coordinate. + Z : + Atomic number (1 for Hydrogen, 2 for Helium, ...) + + Everything is in Hartree atomic units. + + Examples + ======== + + >>> from sympy.physics.hydrogen import R_nl + >>> from sympy.abc import r, Z + >>> R_nl(1, 0, r, Z) + 2*sqrt(Z**3)*exp(-Z*r) + >>> R_nl(2, 0, r, Z) + sqrt(2)*(-Z*r + 2)*sqrt(Z**3)*exp(-Z*r/2)/4 + >>> R_nl(2, 1, r, Z) + sqrt(6)*Z*r*sqrt(Z**3)*exp(-Z*r/2)/12 + + For Hydrogen atom, you can just use the default value of Z=1: + + >>> R_nl(1, 0, r) + 2*exp(-r) + >>> R_nl(2, 0, r) + sqrt(2)*(2 - r)*exp(-r/2)/4 + >>> R_nl(3, 0, r) + 2*sqrt(3)*(2*r**2/9 - 2*r + 3)*exp(-r/3)/27 + + For Silver atom, you would use Z=47: + + >>> R_nl(1, 0, r, Z=47) + 94*sqrt(47)*exp(-47*r) + >>> R_nl(2, 0, r, Z=47) + 47*sqrt(94)*(2 - 47*r)*exp(-47*r/2)/4 + >>> R_nl(3, 0, r, Z=47) + 94*sqrt(141)*(4418*r**2/9 - 94*r + 3)*exp(-47*r/3)/27 + + The normalization of the radial wavefunction is: + + >>> from sympy import integrate, oo + >>> integrate(R_nl(1, 0, r)**2 * r**2, (r, 0, oo)) + 1 + >>> integrate(R_nl(2, 0, r)**2 * r**2, (r, 0, oo)) + 1 + >>> integrate(R_nl(2, 1, r)**2 * r**2, (r, 0, oo)) + 1 + + It holds for any atomic number: + + >>> integrate(R_nl(1, 0, r, Z=2)**2 * r**2, (r, 0, oo)) + 1 + >>> integrate(R_nl(2, 0, r, Z=3)**2 * r**2, (r, 0, oo)) + 1 + >>> integrate(R_nl(2, 1, r, Z=4)**2 * r**2, (r, 0, oo)) + 1 + + """ + # sympify arguments + n, l, r, Z = map(S, [n, l, r, Z]) + # radial quantum number + n_r = n - l - 1 + # rescaled "r" + a = 1/Z # Bohr radius + r0 = 2 * r / (n * a) + # normalization coefficient + C = sqrt((S(2)/(n*a))**3 * factorial(n_r) / (2*n*factorial(n + l))) + # This is an equivalent normalization coefficient, that can be found in + # some books. Both coefficients seem to be the same fast: + # C = S(2)/n**2 * sqrt(1/a**3 * factorial(n_r) / (factorial(n+l))) + return C * r0**l * assoc_laguerre(n_r, 2*l + 1, r0).expand() * exp(-r0/2) + + +def Psi_nlm(n, l, m, r, phi, theta, Z=1): + """ + Returns the Hydrogen wave function psi_{nlm}. It's the product of + the radial wavefunction R_{nl} and the spherical harmonic Y_{l}^{m}. + + Parameters + ========== + + n : integer + Principal Quantum Number which is + an integer with possible values as 1, 2, 3, 4,... + l : integer + ``l`` is the Angular Momentum Quantum Number with + values ranging from 0 to ``n-1``. + m : integer + ``m`` is the Magnetic Quantum Number with values + ranging from ``-l`` to ``l``. + r : + radial coordinate + phi : + azimuthal angle + theta : + polar angle + Z : + atomic number (1 for Hydrogen, 2 for Helium, ...) + + Everything is in Hartree atomic units. + + Examples + ======== + + >>> from sympy.physics.hydrogen import Psi_nlm + >>> from sympy import Symbol + >>> r=Symbol("r", positive=True) + >>> phi=Symbol("phi", real=True) + >>> theta=Symbol("theta", real=True) + >>> Z=Symbol("Z", positive=True, integer=True, nonzero=True) + >>> Psi_nlm(1,0,0,r,phi,theta,Z) + Z**(3/2)*exp(-Z*r)/sqrt(pi) + >>> Psi_nlm(2,1,1,r,phi,theta,Z) + -Z**(5/2)*r*exp(I*phi)*exp(-Z*r/2)*sin(theta)/(8*sqrt(pi)) + + Integrating the absolute square of a hydrogen wavefunction psi_{nlm} + over the whole space leads 1. + + The normalization of the hydrogen wavefunctions Psi_nlm is: + + >>> from sympy import integrate, conjugate, pi, oo, sin + >>> wf=Psi_nlm(2,1,1,r,phi,theta,Z) + >>> abs_sqrd=wf*conjugate(wf) + >>> jacobi=r**2*sin(theta) + >>> integrate(abs_sqrd*jacobi, (r,0,oo), (phi,0,2*pi), (theta,0,pi)) + 1 + """ + + # sympify arguments + n, l, m, r, phi, theta, Z = map(S, [n, l, m, r, phi, theta, Z]) + # check if values for n,l,m make physically sense + if n.is_integer and n < 1: + raise ValueError("'n' must be positive integer") + if l.is_integer and not (n > l): + raise ValueError("'n' must be greater than 'l'") + if m.is_integer and not (abs(m) <= l): + raise ValueError("|'m'| must be less or equal 'l'") + # return the hydrogen wave function + return R_nl(n, l, r, Z)*Ynm(l, m, theta, phi).expand(func=True) + + +def E_nl(n, Z=1): + """ + Returns the energy of the state (n, l) in Hartree atomic units. + + The energy does not depend on "l". + + Parameters + ========== + + n : integer + Principal Quantum Number which is + an integer with possible values as 1, 2, 3, 4,... + Z : + Atomic number (1 for Hydrogen, 2 for Helium, ...) + + Examples + ======== + + >>> from sympy.physics.hydrogen import E_nl + >>> from sympy.abc import n, Z + >>> E_nl(n, Z) + -Z**2/(2*n**2) + >>> E_nl(1) + -1/2 + >>> E_nl(2) + -1/8 + >>> E_nl(3) + -1/18 + >>> E_nl(3, 47) + -2209/18 + + """ + n, Z = S(n), S(Z) + if n.is_integer and (n < 1): + raise ValueError("'n' must be positive integer") + return -Z**2/(2*n**2) + + +def E_nl_dirac(n, l, spin_up=True, Z=1, c=Float("137.035999037")): + """ + Returns the relativistic energy of the state (n, l, spin) in Hartree atomic + units. + + The energy is calculated from the Dirac equation. The rest mass energy is + *not* included. + + Parameters + ========== + + n : integer + Principal Quantum Number which is + an integer with possible values as 1, 2, 3, 4,... + l : integer + ``l`` is the Angular Momentum Quantum Number with + values ranging from 0 to ``n-1``. + spin_up : + True if the electron spin is up (default), otherwise down + Z : + Atomic number (1 for Hydrogen, 2 for Helium, ...) + c : + Speed of light in atomic units. Default value is 137.035999037, + taken from https://arxiv.org/abs/1012.3627 + + Examples + ======== + + >>> from sympy.physics.hydrogen import E_nl_dirac + >>> E_nl_dirac(1, 0) + -0.500006656595360 + + >>> E_nl_dirac(2, 0) + -0.125002080189006 + >>> E_nl_dirac(2, 1) + -0.125000416028342 + >>> E_nl_dirac(2, 1, False) + -0.125002080189006 + + >>> E_nl_dirac(3, 0) + -0.0555562951740285 + >>> E_nl_dirac(3, 1) + -0.0555558020932949 + >>> E_nl_dirac(3, 1, False) + -0.0555562951740285 + >>> E_nl_dirac(3, 2) + -0.0555556377366884 + >>> E_nl_dirac(3, 2, False) + -0.0555558020932949 + + """ + n, l, Z, c = map(S, [n, l, Z, c]) + if not (l >= 0): + raise ValueError("'l' must be positive or zero") + if not (n > l): + raise ValueError("'n' must be greater than 'l'") + if (l == 0 and spin_up is False): + raise ValueError("Spin must be up for l==0.") + # skappa is sign*kappa, where sign contains the correct sign + if spin_up: + skappa = -l - 1 + else: + skappa = -l + beta = sqrt(skappa**2 - Z**2/c**2) + return c**2/sqrt(1 + Z**2/(n + skappa + beta)**2/c**2) - c**2 diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/matrices.py b/.venv/lib/python3.13/site-packages/sympy/physics/matrices.py new file mode 100644 index 0000000000000000000000000000000000000000..d91466220d63956053b91bd76b948ee677e7c191 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/matrices.py @@ -0,0 +1,176 @@ +"""Known matrices related to physics""" + +from sympy.core.numbers import I +from sympy.matrices.dense import MutableDenseMatrix as Matrix +from sympy.utilities.decorator import deprecated + + +def msigma(i): + r"""Returns a Pauli matrix `\sigma_i` with `i=1,2,3`. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Pauli_matrices + + Examples + ======== + + >>> from sympy.physics.matrices import msigma + >>> msigma(1) + Matrix([ + [0, 1], + [1, 0]]) + """ + if i == 1: + mat = ( + (0, 1), + (1, 0) + ) + elif i == 2: + mat = ( + (0, -I), + (I, 0) + ) + elif i == 3: + mat = ( + (1, 0), + (0, -1) + ) + else: + raise IndexError("Invalid Pauli index") + return Matrix(mat) + + +def pat_matrix(m, dx, dy, dz): + """Returns the Parallel Axis Theorem matrix to translate the inertia + matrix a distance of `(dx, dy, dz)` for a body of mass m. + + Examples + ======== + + To translate a body having a mass of 2 units a distance of 1 unit along + the `x`-axis we get: + + >>> from sympy.physics.matrices import pat_matrix + >>> pat_matrix(2, 1, 0, 0) + Matrix([ + [0, 0, 0], + [0, 2, 0], + [0, 0, 2]]) + + """ + dxdy = -dx*dy + dydz = -dy*dz + dzdx = -dz*dx + dxdx = dx**2 + dydy = dy**2 + dzdz = dz**2 + mat = ((dydy + dzdz, dxdy, dzdx), + (dxdy, dxdx + dzdz, dydz), + (dzdx, dydz, dydy + dxdx)) + return m*Matrix(mat) + + +def mgamma(mu, lower=False): + r"""Returns a Dirac gamma matrix `\gamma^\mu` in the standard + (Dirac) representation. + + Explanation + =========== + + If you want `\gamma_\mu`, use ``gamma(mu, True)``. + + We use a convention: + + `\gamma^5 = i \cdot \gamma^0 \cdot \gamma^1 \cdot \gamma^2 \cdot \gamma^3` + + `\gamma_5 = i \cdot \gamma_0 \cdot \gamma_1 \cdot \gamma_2 \cdot \gamma_3 = - \gamma^5` + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gamma_matrices + + Examples + ======== + + >>> from sympy.physics.matrices import mgamma + >>> mgamma(1) + Matrix([ + [ 0, 0, 0, 1], + [ 0, 0, 1, 0], + [ 0, -1, 0, 0], + [-1, 0, 0, 0]]) + """ + if mu not in (0, 1, 2, 3, 5): + raise IndexError("Invalid Dirac index") + if mu == 0: + mat = ( + (1, 0, 0, 0), + (0, 1, 0, 0), + (0, 0, -1, 0), + (0, 0, 0, -1) + ) + elif mu == 1: + mat = ( + (0, 0, 0, 1), + (0, 0, 1, 0), + (0, -1, 0, 0), + (-1, 0, 0, 0) + ) + elif mu == 2: + mat = ( + (0, 0, 0, -I), + (0, 0, I, 0), + (0, I, 0, 0), + (-I, 0, 0, 0) + ) + elif mu == 3: + mat = ( + (0, 0, 1, 0), + (0, 0, 0, -1), + (-1, 0, 0, 0), + (0, 1, 0, 0) + ) + elif mu == 5: + mat = ( + (0, 0, 1, 0), + (0, 0, 0, 1), + (1, 0, 0, 0), + (0, 1, 0, 0) + ) + m = Matrix(mat) + if lower: + if mu in (1, 2, 3, 5): + m = -m + return m + +#Minkowski tensor using the convention (+,-,-,-) used in the Quantum Field +#Theory +minkowski_tensor = Matrix( ( + (1, 0, 0, 0), + (0, -1, 0, 0), + (0, 0, -1, 0), + (0, 0, 0, -1) +)) + + +@deprecated( + """ + The sympy.physics.matrices.mdft method is deprecated. Use + sympy.DFT(n).as_explicit() instead. + """, + deprecated_since_version="1.9", + active_deprecations_target="deprecated-physics-mdft", +) +def mdft(n): + r""" + .. deprecated:: 1.9 + + Use DFT from sympy.matrices.expressions.fourier instead. + + To get identical behavior to ``mdft(n)``, use ``DFT(n).as_explicit()``. + """ + from sympy.matrices.expressions.fourier import DFT + return DFT(n).as_mutable() diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/__init__.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..afd8c071a2af4fd201d5b2371594b19e4a68edda --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/__init__.py @@ -0,0 +1,90 @@ +__all__ = [ + 'vector', + + 'CoordinateSym', 'ReferenceFrame', 'Dyadic', 'Vector', 'Point', 'cross', + 'dot', 'express', 'time_derivative', 'outer', 'kinematic_equations', + 'get_motion_params', 'partial_velocity', 'dynamicsymbols', 'vprint', + 'vsstrrepr', 'vsprint', 'vpprint', 'vlatex', 'init_vprinting', 'curl', + 'divergence', 'gradient', 'is_conservative', 'is_solenoidal', + 'scalar_potential', 'scalar_potential_difference', + + 'KanesMethod', + + 'RigidBody', + + 'linear_momentum', 'angular_momentum', 'kinetic_energy', 'potential_energy', + 'Lagrangian', 'mechanics_printing', 'mprint', 'msprint', 'mpprint', + 'mlatex', 'msubs', 'find_dynamicsymbols', + + 'inertia', 'inertia_of_point_mass', 'Inertia', + + 'Force', 'Torque', + + 'Particle', + + 'LagrangesMethod', + + 'Linearizer', + + 'Body', + + 'SymbolicSystem', 'System', + + 'PinJoint', 'PrismaticJoint', 'CylindricalJoint', 'PlanarJoint', + 'SphericalJoint', 'WeldJoint', + + 'JointsMethod', + + 'WrappingCylinder', 'WrappingGeometryBase', 'WrappingSphere', + + 'PathwayBase', 'LinearPathway', 'ObstacleSetPathway', 'WrappingPathway', + + 'ActuatorBase', 'ForceActuator', 'LinearDamper', 'LinearSpring', + 'TorqueActuator', 'DuffingSpring', 'CoulombKineticFriction', +] + +from sympy.physics import vector + +from sympy.physics.vector import (CoordinateSym, ReferenceFrame, Dyadic, Vector, Point, + cross, dot, express, time_derivative, outer, kinematic_equations, + get_motion_params, partial_velocity, dynamicsymbols, vprint, + vsstrrepr, vsprint, vpprint, vlatex, init_vprinting, curl, divergence, + gradient, is_conservative, is_solenoidal, scalar_potential, + scalar_potential_difference) + +from .kane import KanesMethod + +from .rigidbody import RigidBody + +from .functions import (linear_momentum, angular_momentum, kinetic_energy, + potential_energy, Lagrangian, mechanics_printing, + mprint, msprint, mpprint, mlatex, msubs, + find_dynamicsymbols) + +from .inertia import inertia, inertia_of_point_mass, Inertia + +from .loads import Force, Torque + +from .particle import Particle + +from .lagrange import LagrangesMethod + +from .linearize import Linearizer + +from .body import Body + +from .system import SymbolicSystem, System + +from .jointsmethod import JointsMethod + +from .joint import (PinJoint, PrismaticJoint, CylindricalJoint, PlanarJoint, + SphericalJoint, WeldJoint) + +from .wrapping_geometry import (WrappingCylinder, WrappingGeometryBase, + WrappingSphere) + +from .pathway import (PathwayBase, LinearPathway, ObstacleSetPathway, + WrappingPathway) + +from .actuator import (ActuatorBase, ForceActuator, LinearDamper, LinearSpring, + TorqueActuator, DuffingSpring, CoulombKineticFriction) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/actuator.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/actuator.py new file mode 100644 index 0000000000000000000000000000000000000000..625b3e55019e7545c6dfed073d388acba91a324c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/actuator.py @@ -0,0 +1,1147 @@ +"""Implementations of actuators for linked force and torque application.""" + +from abc import ABC, abstractmethod + +from sympy import S, sympify, exp, sign +from sympy.physics.mechanics.joint import PinJoint +from sympy.physics.mechanics.loads import Torque +from sympy.physics.mechanics.pathway import PathwayBase +from sympy.physics.mechanics.rigidbody import RigidBody +from sympy.physics.vector import ReferenceFrame, Vector + + +__all__ = [ + 'ActuatorBase', + 'ForceActuator', + 'LinearDamper', + 'LinearSpring', + 'TorqueActuator', + 'DuffingSpring', + 'CoulombKineticFriction', +] + + +class ActuatorBase(ABC): + """Abstract base class for all actuator classes to inherit from. + + Notes + ===== + + Instances of this class cannot be directly instantiated by users. However, + it can be used to created custom actuator types through subclassing. + + """ + + def __init__(self): + """Initializer for ``ActuatorBase``.""" + pass + + @abstractmethod + def to_loads(self): + """Loads required by the equations of motion method classes. + + Explanation + =========== + + ``KanesMethod`` requires a list of ``Point``-``Vector`` tuples to be + passed to the ``loads`` parameters of its ``kanes_equations`` method + when constructing the equations of motion. This method acts as a + utility to produce the correctly-structred pairs of points and vectors + required so that these can be easily concatenated with other items in + the list of loads and passed to ``KanesMethod.kanes_equations``. These + loads are also in the correct form to also be passed to the other + equations of motion method classes, e.g. ``LagrangesMethod``. + + """ + pass + + def __repr__(self): + """Default representation of an actuator.""" + return f'{self.__class__.__name__}()' + + +class ForceActuator(ActuatorBase): + """Force-producing actuator. + + Explanation + =========== + + A ``ForceActuator`` is an actuator that produces a (expansile) force along + its length. + + A force actuator uses a pathway instance to determine the direction and + number of forces that it applies to a system. Consider the simplest case + where a ``LinearPathway`` instance is used. This pathway is made up of two + points that can move relative to each other, and results in a pair of equal + and opposite forces acting on the endpoints. If the positive time-varying + Euclidean distance between the two points is defined, then the "extension + velocity" is the time derivative of this distance. The extension velocity + is positive when the two points are moving away from each other and + negative when moving closer to each other. The direction for the force + acting on either point is determined by constructing a unit vector directed + from the other point to this point. This establishes a sign convention such + that a positive force magnitude tends to push the points apart, this is the + meaning of "expansile" in this context. The following diagram shows the + positive force sense and the distance between the points:: + + P Q + o<--- F --->o + | | + |<--l(t)--->| + + Examples + ======== + + To construct an actuator, an expression (or symbol) must be supplied to + represent the force it can produce, alongside a pathway specifying its line + of action. Let's also create a global reference frame and spatially fix one + of the points in it while setting the other to be positioned such that it + can freely move in the frame's x direction specified by the coordinate + ``q``. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (ForceActuator, LinearPathway, + ... Point, ReferenceFrame) + >>> from sympy.physics.vector import dynamicsymbols + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> force = symbols('F') + >>> pA, pB = Point('pA'), Point('pB') + >>> pA.set_vel(N, 0) + >>> pB.set_pos(pA, q*N.x) + >>> pB.pos_from(pA) + q(t)*N.x + >>> linear_pathway = LinearPathway(pA, pB) + >>> actuator = ForceActuator(force, linear_pathway) + >>> actuator + ForceActuator(F, LinearPathway(pA, pB)) + + Parameters + ========== + + force : Expr + The scalar expression defining the (expansile) force that the actuator + produces. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of a + concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + + """ + + def __init__(self, force, pathway): + """Initializer for ``ForceActuator``. + + Parameters + ========== + + force : Expr + The scalar expression defining the (expansile) force that the + actuator produces. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of + a concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + + """ + self.force = force + self.pathway = pathway + + @property + def force(self): + """The magnitude of the force produced by the actuator.""" + return self._force + + @force.setter + def force(self, force): + if hasattr(self, '_force'): + msg = ( + f'Can\'t set attribute `force` to {repr(force)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + self._force = sympify(force, strict=True) + + @property + def pathway(self): + """The ``Pathway`` defining the actuator's line of action.""" + return self._pathway + + @pathway.setter + def pathway(self, pathway): + if hasattr(self, '_pathway'): + msg = ( + f'Can\'t set attribute `pathway` to {repr(pathway)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + if not isinstance(pathway, PathwayBase): + msg = ( + f'Value {repr(pathway)} passed to `pathway` was of type ' + f'{type(pathway)}, must be {PathwayBase}.' + ) + raise TypeError(msg) + self._pathway = pathway + + def to_loads(self): + """Loads required by the equations of motion method classes. + + Explanation + =========== + + ``KanesMethod`` requires a list of ``Point``-``Vector`` tuples to be + passed to the ``loads`` parameters of its ``kanes_equations`` method + when constructing the equations of motion. This method acts as a + utility to produce the correctly-structred pairs of points and vectors + required so that these can be easily concatenated with other items in + the list of loads and passed to ``KanesMethod.kanes_equations``. These + loads are also in the correct form to also be passed to the other + equations of motion method classes, e.g. ``LagrangesMethod``. + + Examples + ======== + + The below example shows how to generate the loads produced by a force + actuator that follows a linear pathway. In this example we'll assume + that the force actuator is being used to model a simple linear spring. + First, create a linear pathway between two points separated by the + coordinate ``q`` in the ``x`` direction of the global frame ``N``. + + >>> from sympy.physics.mechanics import (LinearPathway, Point, + ... ReferenceFrame) + >>> from sympy.physics.vector import dynamicsymbols + >>> q = dynamicsymbols('q') + >>> N = ReferenceFrame('N') + >>> pA, pB = Point('pA'), Point('pB') + >>> pB.set_pos(pA, q*N.x) + >>> pathway = LinearPathway(pA, pB) + + Now create a symbol ``k`` to describe the spring's stiffness and + instantiate a force actuator that produces a (contractile) force + proportional to both the spring's stiffness and the pathway's length. + Note that actuator classes use the sign convention that expansile + forces are positive, so for a spring to produce a contractile force the + spring force needs to be calculated as the negative for the stiffness + multiplied by the length. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import ForceActuator + >>> stiffness = symbols('k') + >>> spring_force = -stiffness*pathway.length + >>> spring = ForceActuator(spring_force, pathway) + + The forces produced by the spring can be generated in the list of loads + form that ``KanesMethod`` (and other equations of motion methods) + requires by calling the ``to_loads`` method. + + >>> spring.to_loads() + [(pA, k*q(t)*N.x), (pB, - k*q(t)*N.x)] + + A simple linear damper can be modeled in a similar way. Create another + symbol ``c`` to describe the dampers damping coefficient. This time + instantiate a force actuator that produces a force proportional to both + the damper's damping coefficient and the pathway's extension velocity. + Note that the damping force is negative as it acts in the opposite + direction to which the damper is changing in length. + + >>> damping_coefficient = symbols('c') + >>> damping_force = -damping_coefficient*pathway.extension_velocity + >>> damper = ForceActuator(damping_force, pathway) + + Again, the forces produces by the damper can be generated by calling + the ``to_loads`` method. + + >>> damper.to_loads() + [(pA, c*Derivative(q(t), t)*N.x), (pB, - c*Derivative(q(t), t)*N.x)] + + """ + return self.pathway.to_loads(self.force) + + def __repr__(self): + """Representation of a ``ForceActuator``.""" + return f'{self.__class__.__name__}({self.force}, {self.pathway})' + + +class LinearSpring(ForceActuator): + """A spring with its spring force as a linear function of its length. + + Explanation + =========== + + Note that the "linear" in the name ``LinearSpring`` refers to the fact that + the spring force is a linear function of the springs length. I.e. for a + linear spring with stiffness ``k``, distance between its ends of ``x``, and + an equilibrium length of ``0``, the spring force will be ``-k*x``, which is + a linear function in ``x``. To create a spring that follows a linear, or + straight, pathway between its two ends, a ``LinearPathway`` instance needs + to be passed to the ``pathway`` parameter. + + A ``LinearSpring`` is a subclass of ``ForceActuator`` and so follows the + same sign conventions for length, extension velocity, and the direction of + the forces it applies to its points of attachment on bodies. The sign + convention for the direction of forces is such that, for the case where a + linear spring is instantiated with a ``LinearPathway`` instance as its + pathway, they act to push the two ends of the spring away from one another. + Because springs produces a contractile force and acts to pull the two ends + together towards the equilibrium length when stretched, the scalar portion + of the forces on the endpoint are negative in order to flip the sign of the + forces on the endpoints when converted into vector quantities. The + following diagram shows the positive force sense and the distance between + the points:: + + P Q + o<--- F --->o + | | + |<--l(t)--->| + + Examples + ======== + + To construct a linear spring, an expression (or symbol) must be supplied to + represent the stiffness (spring constant) of the spring, alongside a + pathway specifying its line of action. Let's also create a global reference + frame and spatially fix one of the points in it while setting the other to + be positioned such that it can freely move in the frame's x direction + specified by the coordinate ``q``. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (LinearPathway, LinearSpring, + ... Point, ReferenceFrame) + >>> from sympy.physics.vector import dynamicsymbols + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> stiffness = symbols('k') + >>> pA, pB = Point('pA'), Point('pB') + >>> pA.set_vel(N, 0) + >>> pB.set_pos(pA, q*N.x) + >>> pB.pos_from(pA) + q(t)*N.x + >>> linear_pathway = LinearPathway(pA, pB) + >>> spring = LinearSpring(stiffness, linear_pathway) + >>> spring + LinearSpring(k, LinearPathway(pA, pB)) + + This spring will produce a force that is proportional to both its stiffness + and the pathway's length. Note that this force is negative as SymPy's sign + convention for actuators is that negative forces are contractile. + + >>> spring.force + -k*sqrt(q(t)**2) + + To create a linear spring with a non-zero equilibrium length, an expression + (or symbol) can be passed to the ``equilibrium_length`` parameter on + construction on a ``LinearSpring`` instance. Let's create a symbol ``l`` + to denote a non-zero equilibrium length and create another linear spring. + + >>> l = symbols('l') + >>> spring = LinearSpring(stiffness, linear_pathway, equilibrium_length=l) + >>> spring + LinearSpring(k, LinearPathway(pA, pB), equilibrium_length=l) + + The spring force of this new spring is again proportional to both its + stiffness and the pathway's length. However, the spring will not produce + any force when ``q(t)`` equals ``l``. Note that the force will become + expansile when ``q(t)`` is less than ``l``, as expected. + + >>> spring.force + -k*(-l + sqrt(q(t)**2)) + + Parameters + ========== + + stiffness : Expr + The spring constant. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of a + concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + equilibrium_length : Expr, optional + The length at which the spring is in equilibrium, i.e. it produces no + force. The default value is 0, i.e. the spring force is a linear + function of the pathway's length with no constant offset. + + See Also + ======== + + ForceActuator: force-producing actuator (superclass of ``LinearSpring``). + LinearPathway: straight-line pathway between a pair of points. + + """ + + def __init__(self, stiffness, pathway, equilibrium_length=S.Zero): + """Initializer for ``LinearSpring``. + + Parameters + ========== + + stiffness : Expr + The spring constant. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of + a concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + equilibrium_length : Expr, optional + The length at which the spring is in equilibrium, i.e. it produces + no force. The default value is 0, i.e. the spring force is a linear + function of the pathway's length with no constant offset. + + """ + self.stiffness = stiffness + self.pathway = pathway + self.equilibrium_length = equilibrium_length + + @property + def force(self): + """The spring force produced by the linear spring.""" + return -self.stiffness*(self.pathway.length - self.equilibrium_length) + + @force.setter + def force(self, force): + raise AttributeError('Can\'t set computed attribute `force`.') + + @property + def stiffness(self): + """The spring constant for the linear spring.""" + return self._stiffness + + @stiffness.setter + def stiffness(self, stiffness): + if hasattr(self, '_stiffness'): + msg = ( + f'Can\'t set attribute `stiffness` to {repr(stiffness)} as it ' + f'is immutable.' + ) + raise AttributeError(msg) + self._stiffness = sympify(stiffness, strict=True) + + @property + def equilibrium_length(self): + """The length of the spring at which it produces no force.""" + return self._equilibrium_length + + @equilibrium_length.setter + def equilibrium_length(self, equilibrium_length): + if hasattr(self, '_equilibrium_length'): + msg = ( + f'Can\'t set attribute `equilibrium_length` to ' + f'{repr(equilibrium_length)} as it is immutable.' + ) + raise AttributeError(msg) + self._equilibrium_length = sympify(equilibrium_length, strict=True) + + def __repr__(self): + """Representation of a ``LinearSpring``.""" + string = f'{self.__class__.__name__}({self.stiffness}, {self.pathway}' + if self.equilibrium_length == S.Zero: + string += ')' + else: + string += f', equilibrium_length={self.equilibrium_length})' + return string + + +class LinearDamper(ForceActuator): + """A damper whose force is a linear function of its extension velocity. + + Explanation + =========== + + Note that the "linear" in the name ``LinearDamper`` refers to the fact that + the damping force is a linear function of the damper's rate of change in + its length. I.e. for a linear damper with damping ``c`` and extension + velocity ``v``, the damping force will be ``-c*v``, which is a linear + function in ``v``. To create a damper that follows a linear, or straight, + pathway between its two ends, a ``LinearPathway`` instance needs to be + passed to the ``pathway`` parameter. + + A ``LinearDamper`` is a subclass of ``ForceActuator`` and so follows the + same sign conventions for length, extension velocity, and the direction of + the forces it applies to its points of attachment on bodies. The sign + convention for the direction of forces is such that, for the case where a + linear damper is instantiated with a ``LinearPathway`` instance as its + pathway, they act to push the two ends of the damper away from one another. + Because dampers produce a force that opposes the direction of change in + length, when extension velocity is positive the scalar portions of the + forces applied at the two endpoints are negative in order to flip the sign + of the forces on the endpoints wen converted into vector quantities. When + extension velocity is negative (i.e. when the damper is shortening), the + scalar portions of the fofces applied are also negative so that the signs + cancel producing forces on the endpoints that are in the same direction as + the positive sign convention for the forces at the endpoints of the pathway + (i.e. they act to push the endpoints away from one another). The following + diagram shows the positive force sense and the distance between the + points:: + + P Q + o<--- F --->o + | | + |<--l(t)--->| + + Examples + ======== + + To construct a linear damper, an expression (or symbol) must be supplied to + represent the damping coefficient of the damper (we'll use the symbol + ``c``), alongside a pathway specifying its line of action. Let's also + create a global reference frame and spatially fix one of the points in it + while setting the other to be positioned such that it can freely move in + the frame's x direction specified by the coordinate ``q``. The velocity + that the two points move away from one another can be specified by the + coordinate ``u`` where ``u`` is the first time derivative of ``q`` + (i.e., ``u = Derivative(q(t), t)``). + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (LinearDamper, LinearPathway, + ... Point, ReferenceFrame) + >>> from sympy.physics.vector import dynamicsymbols + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> damping = symbols('c') + >>> pA, pB = Point('pA'), Point('pB') + >>> pA.set_vel(N, 0) + >>> pB.set_pos(pA, q*N.x) + >>> pB.pos_from(pA) + q(t)*N.x + >>> pB.vel(N) + Derivative(q(t), t)*N.x + >>> linear_pathway = LinearPathway(pA, pB) + >>> damper = LinearDamper(damping, linear_pathway) + >>> damper + LinearDamper(c, LinearPathway(pA, pB)) + + This damper will produce a force that is proportional to both its damping + coefficient and the pathway's extension length. Note that this force is + negative as SymPy's sign convention for actuators is that negative forces + are contractile and the damping force of the damper will oppose the + direction of length change. + + >>> damper.force + -c*sqrt(q(t)**2)*Derivative(q(t), t)/q(t) + + Parameters + ========== + + damping : Expr + The damping constant. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of a + concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + + See Also + ======== + + ForceActuator: force-producing actuator (superclass of ``LinearDamper``). + LinearPathway: straight-line pathway between a pair of points. + + """ + + def __init__(self, damping, pathway): + """Initializer for ``LinearDamper``. + + Parameters + ========== + + damping : Expr + The damping constant. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of + a concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + + """ + self.damping = damping + self.pathway = pathway + + @property + def force(self): + """The damping force produced by the linear damper.""" + return -self.damping*self.pathway.extension_velocity + + @force.setter + def force(self, force): + raise AttributeError('Can\'t set computed attribute `force`.') + + @property + def damping(self): + """The damping constant for the linear damper.""" + return self._damping + + @damping.setter + def damping(self, damping): + if hasattr(self, '_damping'): + msg = ( + f'Can\'t set attribute `damping` to {repr(damping)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + self._damping = sympify(damping, strict=True) + + def __repr__(self): + """Representation of a ``LinearDamper``.""" + return f'{self.__class__.__name__}({self.damping}, {self.pathway})' + + +class TorqueActuator(ActuatorBase): + """Torque-producing actuator. + + Explanation + =========== + + A ``TorqueActuator`` is an actuator that produces a pair of equal and + opposite torques on a pair of bodies. + + Examples + ======== + + To construct a torque actuator, an expression (or symbol) must be supplied + to represent the torque it can produce, alongside a vector specifying the + axis about which the torque will act, and a pair of frames on which the + torque will act. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (ReferenceFrame, RigidBody, + ... TorqueActuator) + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> torque = symbols('T') + >>> axis = N.z + >>> parent = RigidBody('parent', frame=N) + >>> child = RigidBody('child', frame=A) + >>> bodies = (child, parent) + >>> actuator = TorqueActuator(torque, axis, *bodies) + >>> actuator + TorqueActuator(T, axis=N.z, target_frame=A, reaction_frame=N) + + Note that because torques actually act on frames, not bodies, + ``TorqueActuator`` will extract the frame associated with a ``RigidBody`` + when one is passed instead of a ``ReferenceFrame``. + + Parameters + ========== + + torque : Expr + The scalar expression defining the torque that the actuator produces. + axis : Vector + The axis about which the actuator applies torques. + target_frame : ReferenceFrame | RigidBody + The primary frame on which the actuator will apply the torque. + reaction_frame : ReferenceFrame | RigidBody | None + The secondary frame on which the actuator will apply the torque. Note + that the (equal and opposite) reaction torque is applied to this frame. + + """ + + def __init__(self, torque, axis, target_frame, reaction_frame=None): + """Initializer for ``TorqueActuator``. + + Parameters + ========== + + torque : Expr + The scalar expression defining the torque that the actuator + produces. + axis : Vector + The axis about which the actuator applies torques. + target_frame : ReferenceFrame | RigidBody + The primary frame on which the actuator will apply the torque. + reaction_frame : ReferenceFrame | RigidBody | None + The secondary frame on which the actuator will apply the torque. + Note that the (equal and opposite) reaction torque is applied to + this frame. + + """ + self.torque = torque + self.axis = axis + self.target_frame = target_frame + self.reaction_frame = reaction_frame + + @classmethod + def at_pin_joint(cls, torque, pin_joint): + """Alternate constructor to instantiate from a ``PinJoint`` instance. + + Examples + ======== + + To create a pin joint the ``PinJoint`` class requires a name, parent + body, and child body to be passed to its constructor. It is also + possible to control the joint axis using the ``joint_axis`` keyword + argument. In this example let's use the parent body's reference frame's + z-axis as the joint axis. + + >>> from sympy.physics.mechanics import (PinJoint, ReferenceFrame, + ... RigidBody, TorqueActuator) + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> parent = RigidBody('parent', frame=N) + >>> child = RigidBody('child', frame=A) + >>> pin_joint = PinJoint( + ... 'pin', + ... parent, + ... child, + ... joint_axis=N.z, + ... ) + + Let's also create a symbol ``T`` that will represent the torque applied + by the torque actuator. + + >>> from sympy import symbols + >>> torque = symbols('T') + + To create the torque actuator from the ``torque`` and ``pin_joint`` + variables previously instantiated, these can be passed to the alternate + constructor class method ``at_pin_joint`` of the ``TorqueActuator`` + class. It should be noted that a positive torque will cause a positive + displacement of the joint coordinate or that the torque is applied on + the child body with a reaction torque on the parent. + + >>> actuator = TorqueActuator.at_pin_joint(torque, pin_joint) + >>> actuator + TorqueActuator(T, axis=N.z, target_frame=A, reaction_frame=N) + + Parameters + ========== + + torque : Expr + The scalar expression defining the torque that the actuator + produces. + pin_joint : PinJoint + The pin joint, and by association the parent and child bodies, on + which the torque actuator will act. The pair of bodies acted upon + by the torque actuator are the parent and child bodies of the pin + joint, with the child acting as the reaction body. The pin joint's + axis is used as the axis about which the torque actuator will apply + its torque. + + """ + if not isinstance(pin_joint, PinJoint): + msg = ( + f'Value {repr(pin_joint)} passed to `pin_joint` was of type ' + f'{type(pin_joint)}, must be {PinJoint}.' + ) + raise TypeError(msg) + return cls( + torque, + pin_joint.joint_axis, + pin_joint.child_interframe, + pin_joint.parent_interframe, + ) + + @property + def torque(self): + """The magnitude of the torque produced by the actuator.""" + return self._torque + + @torque.setter + def torque(self, torque): + if hasattr(self, '_torque'): + msg = ( + f'Can\'t set attribute `torque` to {repr(torque)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + self._torque = sympify(torque, strict=True) + + @property + def axis(self): + """The axis about which the torque acts.""" + return self._axis + + @axis.setter + def axis(self, axis): + if hasattr(self, '_axis'): + msg = ( + f'Can\'t set attribute `axis` to {repr(axis)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + if not isinstance(axis, Vector): + msg = ( + f'Value {repr(axis)} passed to `axis` was of type ' + f'{type(axis)}, must be {Vector}.' + ) + raise TypeError(msg) + self._axis = axis + + @property + def target_frame(self): + """The primary reference frames on which the torque will act.""" + return self._target_frame + + @target_frame.setter + def target_frame(self, target_frame): + if hasattr(self, '_target_frame'): + msg = ( + f'Can\'t set attribute `target_frame` to {repr(target_frame)} ' + f'as it is immutable.' + ) + raise AttributeError(msg) + if isinstance(target_frame, RigidBody): + target_frame = target_frame.frame + elif not isinstance(target_frame, ReferenceFrame): + msg = ( + f'Value {repr(target_frame)} passed to `target_frame` was of ' + f'type {type(target_frame)}, must be {ReferenceFrame}.' + ) + raise TypeError(msg) + self._target_frame = target_frame + + @property + def reaction_frame(self): + """The primary reference frames on which the torque will act.""" + return self._reaction_frame + + @reaction_frame.setter + def reaction_frame(self, reaction_frame): + if hasattr(self, '_reaction_frame'): + msg = ( + f'Can\'t set attribute `reaction_frame` to ' + f'{repr(reaction_frame)} as it is immutable.' + ) + raise AttributeError(msg) + if isinstance(reaction_frame, RigidBody): + reaction_frame = reaction_frame.frame + elif ( + not isinstance(reaction_frame, ReferenceFrame) + and reaction_frame is not None + ): + msg = ( + f'Value {repr(reaction_frame)} passed to `reaction_frame` was ' + f'of type {type(reaction_frame)}, must be {ReferenceFrame}.' + ) + raise TypeError(msg) + self._reaction_frame = reaction_frame + + def to_loads(self): + """Loads required by the equations of motion method classes. + + Explanation + =========== + + ``KanesMethod`` requires a list of ``Point``-``Vector`` tuples to be + passed to the ``loads`` parameters of its ``kanes_equations`` method + when constructing the equations of motion. This method acts as a + utility to produce the correctly-structred pairs of points and vectors + required so that these can be easily concatenated with other items in + the list of loads and passed to ``KanesMethod.kanes_equations``. These + loads are also in the correct form to also be passed to the other + equations of motion method classes, e.g. ``LagrangesMethod``. + + Examples + ======== + + The below example shows how to generate the loads produced by a torque + actuator that acts on a pair of bodies attached by a pin joint. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (PinJoint, ReferenceFrame, + ... RigidBody, TorqueActuator) + >>> torque = symbols('T') + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> parent = RigidBody('parent', frame=N) + >>> child = RigidBody('child', frame=A) + >>> pin_joint = PinJoint( + ... 'pin', + ... parent, + ... child, + ... joint_axis=N.z, + ... ) + >>> actuator = TorqueActuator.at_pin_joint(torque, pin_joint) + + The forces produces by the damper can be generated by calling the + ``to_loads`` method. + + >>> actuator.to_loads() + [(A, T*N.z), (N, - T*N.z)] + + Alternatively, if a torque actuator is created without a reaction frame + then the loads returned by the ``to_loads`` method will contain just + the single load acting on the target frame. + + >>> actuator = TorqueActuator(torque, N.z, N) + >>> actuator.to_loads() + [(N, T*N.z)] + + """ + loads = [ + Torque(self.target_frame, self.torque*self.axis), + ] + if self.reaction_frame is not None: + loads.append(Torque(self.reaction_frame, -self.torque*self.axis)) + return loads + + def __repr__(self): + """Representation of a ``TorqueActuator``.""" + string = ( + f'{self.__class__.__name__}({self.torque}, axis={self.axis}, ' + f'target_frame={self.target_frame}' + ) + if self.reaction_frame is not None: + string += f', reaction_frame={self.reaction_frame})' + else: + string += ')' + return string + + +class DuffingSpring(ForceActuator): + """A nonlinear spring based on the Duffing equation. + + Explanation + =========== + + Here, ``DuffingSpring`` represents the force exerted by a nonlinear spring based on the Duffing equation: + F = -beta*x-alpha*x**3, where x is the displacement from the equilibrium position, beta is the linear spring constant, + and alpha is the coefficient for the nonlinear cubic term. + + Parameters + ========== + + linear_stiffness : Expr + The linear stiffness coefficient (beta). + nonlinear_stiffness : Expr + The nonlinear stiffness coefficient (alpha). + pathway : PathwayBase + The pathway that the actuator follows. + equilibrium_length : Expr, optional + The length at which the spring is in equilibrium (x). + """ + + def __init__(self, linear_stiffness, nonlinear_stiffness, pathway, equilibrium_length=S.Zero): + self.linear_stiffness = sympify(linear_stiffness, strict=True) + self.nonlinear_stiffness = sympify(nonlinear_stiffness, strict=True) + self.equilibrium_length = sympify(equilibrium_length, strict=True) + + if not isinstance(pathway, PathwayBase): + raise TypeError("pathway must be an instance of PathwayBase.") + self._pathway = pathway + + @property + def linear_stiffness(self): + return self._linear_stiffness + + @linear_stiffness.setter + def linear_stiffness(self, linear_stiffness): + if hasattr(self, '_linear_stiffness'): + msg = ( + f'Can\'t set attribute `linear_stiffness` to ' + f'{repr(linear_stiffness)} as it is immutable.' + ) + raise AttributeError(msg) + self._linear_stiffness = sympify(linear_stiffness, strict=True) + + @property + def nonlinear_stiffness(self): + return self._nonlinear_stiffness + + @nonlinear_stiffness.setter + def nonlinear_stiffness(self, nonlinear_stiffness): + if hasattr(self, '_nonlinear_stiffness'): + msg = ( + f'Can\'t set attribute `nonlinear_stiffness` to ' + f'{repr(nonlinear_stiffness)} as it is immutable.' + ) + raise AttributeError(msg) + self._nonlinear_stiffness = sympify(nonlinear_stiffness, strict=True) + + @property + def pathway(self): + return self._pathway + + @pathway.setter + def pathway(self, pathway): + if hasattr(self, '_pathway'): + msg = ( + f'Can\'t set attribute `pathway` to {repr(pathway)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + if not isinstance(pathway, PathwayBase): + msg = ( + f'Value {repr(pathway)} passed to `pathway` was of type ' + f'{type(pathway)}, must be {PathwayBase}.' + ) + raise TypeError(msg) + self._pathway = pathway + + @property + def equilibrium_length(self): + return self._equilibrium_length + + @equilibrium_length.setter + def equilibrium_length(self, equilibrium_length): + if hasattr(self, '_equilibrium_length'): + msg = ( + f'Can\'t set attribute `equilibrium_length` to ' + f'{repr(equilibrium_length)} as it is immutable.' + ) + raise AttributeError(msg) + self._equilibrium_length = sympify(equilibrium_length, strict=True) + + @property + def force(self): + """The force produced by the Duffing spring.""" + displacement = self.pathway.length - self.equilibrium_length + return -self.linear_stiffness * displacement - self.nonlinear_stiffness * displacement**3 + + @force.setter + def force(self, force): + if hasattr(self, '_force'): + msg = ( + f'Can\'t set attribute `force` to {repr(force)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + self._force = sympify(force, strict=True) + + def __repr__(self): + return (f"{self.__class__.__name__}(" + f"{self.linear_stiffness}, {self.nonlinear_stiffness}, {self.pathway}, " + f"equilibrium_length={self.equilibrium_length})") + +class CoulombKineticFriction(ForceActuator): + r"""Coulomb kinetic friction with Stribeck and viscous effects. + + Explanation + =========== + + This represents a Coulomb kinetic friction with the Stribeck and viscous effect, + described by the function: + + .. math:: + F = (\mu_k f_n + (\mu_s - \mu_k) f_n e^{-(\frac{v}{v_s})^2}) \text{sign}(v) + \sigma v + + where :math:`\mu_k` is the coefficient of kinetic friction, :math:`\mu_s` is the + coefficient of static friction, :math:`f_n` is the normal force, :math:`v` is the + relative velocity, :math:`v_s` is the Stribeck friction coefficient, and + :math:`\sigma` is the viscous friction constant. + + The default friction force is :math:`F = \mu_k f_n`. + When specified, the actuator includes: + + - Stribeck effect: :math:`(\mu_s - \mu_k) f_n e^{-(\frac{v}{v_s})^2}` + - Viscous effect: :math:`\sigma v` + + Notes + ===== + + The actuator makes the following assumptions: + + - The actuator assumes relative motion is non-zero. + - The normal force is assumed to be a non-negative scalar. + - The resultant friction force is opposite to the velocity direction. + - Each point in the pathway is fixed within separate objects that are sliding relative to each other. In other words, these two points are fixed in the mutually sliding objects. + + This actuator has been tested for straightforward motions, like a block sliding + on a surface. + + The friction force is defined to always oppose the direction of relative velocity :math:`v`. + Specifically: + + - The default Coulomb friction force :math:`\mu_k f_n \text{sign}(v)` is opposite to :math:`v`. + - The Stribeck effect :math:`(\mu_s - \mu_k) f_n e^{-(\frac{v}{v_s})^2} \text{sign}(v)` is also opposite to :math:`v`. + - The viscous friction term :math:`\sigma v` is opposite to :math:`v`. + + Examples + ======== + + The below example shows how to generate the loads produced by a Coulomb kinetic + friction actuator in a mass-spring system with friction. + + >>> import sympy as sm + >>> from sympy.physics.mechanics import (dynamicsymbols, ReferenceFrame, Point, + ... LinearPathway, CoulombKineticFriction, LinearSpring, KanesMethod, Particle) + + >>> x, v = dynamicsymbols('x, v', real=True) + >>> m, g, k, mu_k, mu_s, v_s, sigma = sm.symbols('m, g, k, mu_k, mu_s, v_s, sigma') + + >>> N = ReferenceFrame('N') + >>> O, P = Point('O'), Point('P') + >>> O.set_vel(N, 0) + >>> P.set_pos(O, x*N.x) + + >>> pathway = LinearPathway(O, P) + >>> friction = CoulombKineticFriction(mu_k, m*g, pathway, v_s=v_s, sigma=sigma, mu_s=mu_k) + >>> spring = LinearSpring(k, pathway) + >>> block = Particle('block', point=P, mass=m) + + >>> kane = KanesMethod(N, (x,), (v,), kd_eqs=(x.diff() - v,)) + >>> friction.to_loads() + [(O, (g*m*mu_k*sign(sign(x(t))*Derivative(x(t), t)) + sigma*sign(x(t))*Derivative(x(t), t))*x(t)/Abs(x(t))*N.x), (P, (-g*m*mu_k*sign(sign(x(t))*Derivative(x(t), t)) - sigma*sign(x(t))*Derivative(x(t), t))*x(t)/Abs(x(t))*N.x)] + >>> loads = friction.to_loads() + spring.to_loads() + >>> fr, frstar = kane.kanes_equations([block], loads) + >>> eom = fr + frstar + >>> eom + Matrix([[-k*x(t) - m*Derivative(v(t), t) + (-g*m*mu_k*sign(v(t)*sign(x(t))) - sigma*v(t)*sign(x(t)))*x(t)/Abs(x(t))]]) + + Parameters + ========== + + f_n : sympifiable + The normal force between the surfaces. It should always be a non-negative scalar. + mu_k : sympifiable + The coefficient of kinetic friction. + pathway : PathwayBase + The pathway that the actuator follows. + v_s : sympifiable, optional + The Stribeck friction coefficient. + sigma : sympifiable, optional + The viscous friction coefficient. + mu_s : sympifiable, optional + The coefficient of static friction. Defaults to mu_k, meaning the Stribeck effect evaluates to 0 by default. + + References + ========== + + .. [Moore2022] https://moorepants.github.io/learn-multibody-dynamics/loads.html#friction. + .. [Flores2023] Paulo Flores, Jorge Ambrosio, Hamid M. Lankarani, + "Contact-impact events with friction in multibody dynamics: Back to basics", + Mechanism and Machine Theory, vol. 184, 2023. https://doi.org/10.1016/j.mechmachtheory.2023.105305. + .. [Rogner2017] I. Rogner, "Friction modelling for robotic applications with planar motion", + Chalmers University of Technology, Department of Electrical Engineering, 2017. + + """ + + def __init__(self, mu_k, f_n, pathway, *, v_s=None, sigma=None, mu_s=None): + self._mu_k = sympify(mu_k, strict=True) if mu_k is not None else 1 + self._mu_s = sympify(mu_s, strict=True) if mu_s is not None else self._mu_k + self._f_n = sympify(f_n, strict=True) + self._sigma = sympify(sigma, strict=True) if sigma is not None else 0 + self._v_s = sympify(v_s, strict=True) if v_s is not None or v_s == 0 else 0.01 + self.pathway = pathway + + @property + def mu_k(self): + """The coefficient of kinetic friction.""" + return self._mu_k + + @property + def mu_s(self): + """The coefficient of static friction.""" + return self._mu_s + + @property + def f_n(self): + """The normal force between the surfaces.""" + return self._f_n + + @property + def sigma(self): + """The viscous friction coefficient.""" + return self._sigma + + @property + def v_s(self): + """The Stribeck friction coefficient.""" + return self._v_s + + @property + def force(self): + v = self.pathway.extension_velocity + f_c = self.mu_k * self.f_n + f_max = self.mu_s * self.f_n + stribeck_term = (f_max - f_c) * exp(-(v / self.v_s)**2) if self.v_s is not None else 0 + viscous_term = self.sigma * v if self.sigma is not None else 0 + return (f_c + stribeck_term) * -sign(v) - viscous_term + + @force.setter + def force(self, force): + raise AttributeError('Can\'t set computed attribute `force`.') + + def __repr__(self): + return (f'{self.__class__.__name__}({self._mu_k}, {self._mu_s} ' + f'{self._f_n}, {self.pathway}, {self._v_s}, ' + f'{self._sigma})') diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/body.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/body.py new file mode 100644 index 0000000000000000000000000000000000000000..efc367158bbf51e7d9929318ac9286ba5c3fb3ac --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/body.py @@ -0,0 +1,710 @@ +from sympy import Symbol +from sympy.physics.vector import Point, Vector, ReferenceFrame, Dyadic +from sympy.physics.mechanics import RigidBody, Particle, Inertia +from sympy.physics.mechanics.body_base import BodyBase +from sympy.utilities.exceptions import sympy_deprecation_warning + +__all__ = ['Body'] + + +# XXX: We use type:ignore because the classes RigidBody and Particle have +# inconsistent parallel axis methods that take different numbers of arguments. +class Body(RigidBody, Particle): # type: ignore + """ + Body is a common representation of either a RigidBody or a Particle SymPy + object depending on what is passed in during initialization. If a mass is + passed in and central_inertia is left as None, the Particle object is + created. Otherwise a RigidBody object will be created. + + .. deprecated:: 1.13 + The Body class is deprecated. Its functionality is captured by + :class:`~.RigidBody` and :class:`~.Particle`. + + Explanation + =========== + + The attributes that Body possesses will be the same as a Particle instance + or a Rigid Body instance depending on which was created. Additional + attributes are listed below. + + Attributes + ========== + + name : string + The body's name + masscenter : Point + The point which represents the center of mass of the rigid body + frame : ReferenceFrame + The reference frame which the body is fixed in + mass : Sympifyable + The body's mass + inertia : (Dyadic, Point) + The body's inertia around its center of mass. This attribute is specific + to the rigid body form of Body and is left undefined for the Particle + form + loads : iterable + This list contains information on the different loads acting on the + Body. Forces are listed as a (point, vector) tuple and torques are + listed as (reference frame, vector) tuples. + + Parameters + ========== + + name : String + Defines the name of the body. It is used as the base for defining + body specific properties. + masscenter : Point, optional + A point that represents the center of mass of the body or particle. + If no point is given, a point is generated. + mass : Sympifyable, optional + A Sympifyable object which represents the mass of the body. If no + mass is passed, one is generated. + frame : ReferenceFrame, optional + The ReferenceFrame that represents the reference frame of the body. + If no frame is given, a frame is generated. + central_inertia : Dyadic, optional + Central inertia dyadic of the body. If none is passed while creating + RigidBody, a default inertia is generated. + + Examples + ======== + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + + Default behaviour. This results in the creation of a RigidBody object for + which the mass, mass center, frame and inertia attributes are given default + values. :: + + >>> from sympy.physics.mechanics import Body + >>> with ignore_warnings(DeprecationWarning): + ... body = Body('name_of_body') + + This next example demonstrates the code required to specify all of the + values of the Body object. Note this will also create a RigidBody version of + the Body object. :: + + >>> from sympy import Symbol + >>> from sympy.physics.mechanics import ReferenceFrame, Point, inertia + >>> from sympy.physics.mechanics import Body + >>> mass = Symbol('mass') + >>> masscenter = Point('masscenter') + >>> frame = ReferenceFrame('frame') + >>> ixx = Symbol('ixx') + >>> body_inertia = inertia(frame, ixx, 0, 0) + >>> with ignore_warnings(DeprecationWarning): + ... body = Body('name_of_body', masscenter, mass, frame, body_inertia) + + The minimal code required to create a Particle version of the Body object + involves simply passing in a name and a mass. :: + + >>> from sympy import Symbol + >>> from sympy.physics.mechanics import Body + >>> mass = Symbol('mass') + >>> with ignore_warnings(DeprecationWarning): + ... body = Body('name_of_body', mass=mass) + + The Particle version of the Body object can also receive a masscenter point + and a reference frame, just not an inertia. + """ + + def __init__(self, name, masscenter=None, mass=None, frame=None, + central_inertia=None): + sympy_deprecation_warning( + """ + Support for the Body class has been removed, as its functionality is + fully captured by RigidBody and Particle. + """, + deprecated_since_version="1.13", + active_deprecations_target="deprecated-mechanics-body-class" + ) + + self._loads = [] + + if frame is None: + frame = ReferenceFrame(name + '_frame') + + if masscenter is None: + masscenter = Point(name + '_masscenter') + + if central_inertia is None and mass is None: + ixx = Symbol(name + '_ixx') + iyy = Symbol(name + '_iyy') + izz = Symbol(name + '_izz') + izx = Symbol(name + '_izx') + ixy = Symbol(name + '_ixy') + iyz = Symbol(name + '_iyz') + _inertia = Inertia.from_inertia_scalars(masscenter, frame, ixx, iyy, + izz, ixy, iyz, izx) + else: + _inertia = (central_inertia, masscenter) + + if mass is None: + _mass = Symbol(name + '_mass') + else: + _mass = mass + + masscenter.set_vel(frame, 0) + + # If user passes masscenter and mass then a particle is created + # otherwise a rigidbody. As a result a body may or may not have inertia. + # Note: BodyBase.__init__ is used to prevent problems with super() calls in + # Particle and RigidBody arising due to multiple inheritance. + if central_inertia is None and mass is not None: + BodyBase.__init__(self, name, masscenter, _mass) + self.frame = frame + self._central_inertia = Dyadic(0) + else: + BodyBase.__init__(self, name, masscenter, _mass) + self.frame = frame + self.inertia = _inertia + + def __repr__(self): + if self.is_rigidbody: + return RigidBody.__repr__(self) + return Particle.__repr__(self) + + @property + def loads(self): + return self._loads + + @property + def x(self): + """The basis Vector for the Body, in the x direction.""" + return self.frame.x + + @property + def y(self): + """The basis Vector for the Body, in the y direction.""" + return self.frame.y + + @property + def z(self): + """The basis Vector for the Body, in the z direction.""" + return self.frame.z + + @property + def inertia(self): + """The body's inertia about a point; stored as (Dyadic, Point).""" + if self.is_rigidbody: + return RigidBody.inertia.fget(self) + return (self.central_inertia, self.masscenter) + + @inertia.setter + def inertia(self, I): + RigidBody.inertia.fset(self, I) + + @property + def is_rigidbody(self): + if hasattr(self, '_inertia'): + return True + return False + + def kinetic_energy(self, frame): + """Kinetic energy of the body. + + Parameters + ========== + + frame : ReferenceFrame or Body + The Body's angular velocity and the velocity of it's mass + center are typically defined with respect to an inertial frame but + any relevant frame in which the velocities are known can be supplied. + + Examples + ======== + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy.physics.mechanics import Body, ReferenceFrame, Point + >>> from sympy import symbols + >>> m, v, r, omega = symbols('m v r omega') + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> with ignore_warnings(DeprecationWarning): + ... P = Body('P', masscenter=O, mass=m) + >>> P.masscenter.set_vel(N, v * N.y) + >>> P.kinetic_energy(N) + m*v**2/2 + + >>> N = ReferenceFrame('N') + >>> b = ReferenceFrame('b') + >>> b.set_ang_vel(N, omega * b.x) + >>> P = Point('P') + >>> P.set_vel(N, v * N.x) + >>> with ignore_warnings(DeprecationWarning): + ... B = Body('B', masscenter=P, frame=b) + >>> B.kinetic_energy(N) + B_ixx*omega**2/2 + B_mass*v**2/2 + + See Also + ======== + + sympy.physics.mechanics : Particle, RigidBody + + """ + if isinstance(frame, Body): + frame = Body.frame + if self.is_rigidbody: + return RigidBody(self.name, self.masscenter, self.frame, self.mass, + (self.central_inertia, self.masscenter)).kinetic_energy(frame) + return Particle(self.name, self.masscenter, self.mass).kinetic_energy(frame) + + def apply_force(self, force, point=None, reaction_body=None, reaction_point=None): + """Add force to the body(s). + + Explanation + =========== + + Applies the force on self or equal and opposite forces on + self and other body if both are given on the desired point on the bodies. + The force applied on other body is taken opposite of self, i.e, -force. + + Parameters + ========== + + force: Vector + The force to be applied. + point: Point, optional + The point on self on which force is applied. + By default self's masscenter. + reaction_body: Body, optional + Second body on which equal and opposite force + is to be applied. + reaction_point : Point, optional + The point on other body on which equal and opposite + force is applied. By default masscenter of other body. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy import symbols + >>> from sympy.physics.mechanics import Body, Point, dynamicsymbols + >>> m, g = symbols('m g') + >>> with ignore_warnings(DeprecationWarning): + ... B = Body('B') + >>> force1 = m*g*B.z + >>> B.apply_force(force1) #Applying force on B's masscenter + >>> B.loads + [(B_masscenter, g*m*B_frame.z)] + + We can also remove some part of force from any point on the body by + adding the opposite force to the body on that point. + + >>> f1, f2 = dynamicsymbols('f1 f2') + >>> P = Point('P') #Considering point P on body B + >>> B.apply_force(f1*B.x + f2*B.y, P) + >>> B.loads + [(B_masscenter, g*m*B_frame.z), (P, f1(t)*B_frame.x + f2(t)*B_frame.y)] + + Let's remove f1 from point P on body B. + + >>> B.apply_force(-f1*B.x, P) + >>> B.loads + [(B_masscenter, g*m*B_frame.z), (P, f2(t)*B_frame.y)] + + To further demonstrate the use of ``apply_force`` attribute, + consider two bodies connected through a spring. + + >>> from sympy.physics.mechanics import Body, dynamicsymbols + >>> with ignore_warnings(DeprecationWarning): + ... N = Body('N') #Newtonion Frame + >>> x = dynamicsymbols('x') + >>> with ignore_warnings(DeprecationWarning): + ... B1 = Body('B1') + ... B2 = Body('B2') + >>> spring_force = x*N.x + + Now let's apply equal and opposite spring force to the bodies. + + >>> P1 = Point('P1') + >>> P2 = Point('P2') + >>> B1.apply_force(spring_force, point=P1, reaction_body=B2, reaction_point=P2) + + We can check the loads(forces) applied to bodies now. + + >>> B1.loads + [(P1, x(t)*N_frame.x)] + >>> B2.loads + [(P2, - x(t)*N_frame.x)] + + Notes + ===== + + If a new force is applied to a body on a point which already has some + force applied on it, then the new force is added to the already applied + force on that point. + + """ + + if not isinstance(point, Point): + if point is None: + point = self.masscenter # masscenter + else: + raise TypeError("Force must be applied to a point on the body.") + if not isinstance(force, Vector): + raise TypeError("Force must be a vector.") + + if reaction_body is not None: + reaction_body.apply_force(-force, point=reaction_point) + + for load in self._loads: + if point in load: + force += load[1] + self._loads.remove(load) + break + + self._loads.append((point, force)) + + def apply_torque(self, torque, reaction_body=None): + """Add torque to the body(s). + + Explanation + =========== + + Applies the torque on self or equal and opposite torques on + self and other body if both are given. + The torque applied on other body is taken opposite of self, + i.e, -torque. + + Parameters + ========== + + torque: Vector + The torque to be applied. + reaction_body: Body, optional + Second body on which equal and opposite torque + is to be applied. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy import symbols + >>> from sympy.physics.mechanics import Body, dynamicsymbols + >>> t = symbols('t') + >>> with ignore_warnings(DeprecationWarning): + ... B = Body('B') + >>> torque1 = t*B.z + >>> B.apply_torque(torque1) + >>> B.loads + [(B_frame, t*B_frame.z)] + + We can also remove some part of torque from the body by + adding the opposite torque to the body. + + >>> t1, t2 = dynamicsymbols('t1 t2') + >>> B.apply_torque(t1*B.x + t2*B.y) + >>> B.loads + [(B_frame, t1(t)*B_frame.x + t2(t)*B_frame.y + t*B_frame.z)] + + Let's remove t1 from Body B. + + >>> B.apply_torque(-t1*B.x) + >>> B.loads + [(B_frame, t2(t)*B_frame.y + t*B_frame.z)] + + To further demonstrate the use, let us consider two bodies such that + a torque `T` is acting on one body, and `-T` on the other. + + >>> from sympy.physics.mechanics import Body, dynamicsymbols + >>> with ignore_warnings(DeprecationWarning): + ... N = Body('N') #Newtonion frame + ... B1 = Body('B1') + ... B2 = Body('B2') + >>> v = dynamicsymbols('v') + >>> T = v*N.y #Torque + + Now let's apply equal and opposite torque to the bodies. + + >>> B1.apply_torque(T, B2) + + We can check the loads (torques) applied to bodies now. + + >>> B1.loads + [(B1_frame, v(t)*N_frame.y)] + >>> B2.loads + [(B2_frame, - v(t)*N_frame.y)] + + Notes + ===== + + If a new torque is applied on body which already has some torque applied on it, + then the new torque is added to the previous torque about the body's frame. + + """ + + if not isinstance(torque, Vector): + raise TypeError("A Vector must be supplied to add torque.") + + if reaction_body is not None: + reaction_body.apply_torque(-torque) + + for load in self._loads: + if self.frame in load: + torque += load[1] + self._loads.remove(load) + break + self._loads.append((self.frame, torque)) + + def clear_loads(self): + """ + Clears the Body's loads list. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy.physics.mechanics import Body + >>> with ignore_warnings(DeprecationWarning): + ... B = Body('B') + >>> force = B.x + B.y + >>> B.apply_force(force) + >>> B.loads + [(B_masscenter, B_frame.x + B_frame.y)] + >>> B.clear_loads() + >>> B.loads + [] + + """ + + self._loads = [] + + def remove_load(self, about=None): + """ + Remove load about a point or frame. + + Parameters + ========== + + about : Point or ReferenceFrame, optional + The point about which force is applied, + and is to be removed. + If about is None, then the torque about + self's frame is removed. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy.physics.mechanics import Body, Point + >>> with ignore_warnings(DeprecationWarning): + ... B = Body('B') + >>> P = Point('P') + >>> f1 = B.x + >>> f2 = B.y + >>> B.apply_force(f1) + >>> B.apply_force(f2, P) + >>> B.loads + [(B_masscenter, B_frame.x), (P, B_frame.y)] + + >>> B.remove_load(P) + >>> B.loads + [(B_masscenter, B_frame.x)] + + """ + + if about is not None: + if not isinstance(about, Point): + raise TypeError('Load is applied about Point or ReferenceFrame.') + else: + about = self.frame + + for load in self._loads: + if about in load: + self._loads.remove(load) + break + + def masscenter_vel(self, body): + """ + Returns the velocity of the mass center with respect to the provided + rigid body or reference frame. + + Parameters + ========== + + body: Body or ReferenceFrame + The rigid body or reference frame to calculate the velocity in. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy.physics.mechanics import Body + >>> with ignore_warnings(DeprecationWarning): + ... A = Body('A') + ... B = Body('B') + >>> A.masscenter.set_vel(B.frame, 5*B.frame.x) + >>> A.masscenter_vel(B) + 5*B_frame.x + >>> A.masscenter_vel(B.frame) + 5*B_frame.x + + """ + + if isinstance(body, ReferenceFrame): + frame=body + elif isinstance(body, Body): + frame = body.frame + return self.masscenter.vel(frame) + + def ang_vel_in(self, body): + """ + Returns this body's angular velocity with respect to the provided + rigid body or reference frame. + + Parameters + ========== + + body: Body or ReferenceFrame + The rigid body or reference frame to calculate the angular velocity in. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy.physics.mechanics import Body, ReferenceFrame + >>> with ignore_warnings(DeprecationWarning): + ... A = Body('A') + >>> N = ReferenceFrame('N') + >>> with ignore_warnings(DeprecationWarning): + ... B = Body('B', frame=N) + >>> A.frame.set_ang_vel(N, 5*N.x) + >>> A.ang_vel_in(B) + 5*N.x + >>> A.ang_vel_in(N) + 5*N.x + + """ + + if isinstance(body, ReferenceFrame): + frame=body + elif isinstance(body, Body): + frame = body.frame + return self.frame.ang_vel_in(frame) + + def dcm(self, body): + """ + Returns the direction cosine matrix of this body relative to the + provided rigid body or reference frame. + + Parameters + ========== + + body: Body or ReferenceFrame + The rigid body or reference frame to calculate the dcm. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy.physics.mechanics import Body + >>> with ignore_warnings(DeprecationWarning): + ... A = Body('A') + ... B = Body('B') + >>> A.frame.orient_axis(B.frame, B.frame.x, 5) + >>> A.dcm(B) + Matrix([ + [1, 0, 0], + [0, cos(5), sin(5)], + [0, -sin(5), cos(5)]]) + >>> A.dcm(B.frame) + Matrix([ + [1, 0, 0], + [0, cos(5), sin(5)], + [0, -sin(5), cos(5)]]) + + """ + + if isinstance(body, ReferenceFrame): + frame=body + elif isinstance(body, Body): + frame = body.frame + return self.frame.dcm(frame) + + def parallel_axis(self, point, frame=None): + """Returns the inertia dyadic of the body with respect to another + point. + + Parameters + ========== + + point : sympy.physics.vector.Point + The point to express the inertia dyadic about. + frame : sympy.physics.vector.ReferenceFrame + The reference frame used to construct the dyadic. + + Returns + ======= + + inertia : sympy.physics.vector.Dyadic + The inertia dyadic of the rigid body expressed about the provided + point. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy.physics.mechanics import Body + >>> with ignore_warnings(DeprecationWarning): + ... A = Body('A') + >>> P = A.masscenter.locatenew('point', 3 * A.x + 5 * A.y) + >>> A.parallel_axis(P).to_matrix(A.frame) + Matrix([ + [A_ixx + 25*A_mass, A_ixy - 15*A_mass, A_izx], + [A_ixy - 15*A_mass, A_iyy + 9*A_mass, A_iyz], + [ A_izx, A_iyz, A_izz + 34*A_mass]]) + + """ + if self.is_rigidbody: + return RigidBody.parallel_axis(self, point, frame) + return Particle.parallel_axis(self, point, frame) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/body_base.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/body_base.py new file mode 100644 index 0000000000000000000000000000000000000000..d2546faf685f579d2aea10ed7f139a4beced7dd0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/body_base.py @@ -0,0 +1,94 @@ +from abc import ABC, abstractmethod +from sympy import Symbol, sympify +from sympy.physics.vector import Point + +__all__ = ['BodyBase'] + + +class BodyBase(ABC): + """Abstract class for body type objects.""" + def __init__(self, name, masscenter=None, mass=None): + # Note: If frame=None, no auto-generated frame is created, because a + # Particle does not need to have a frame by default. + if not isinstance(name, str): + raise TypeError('Supply a valid name.') + self._name = name + if mass is None: + mass = Symbol(f'{name}_mass') + if masscenter is None: + masscenter = Point(f'{name}_masscenter') + self.mass = mass + self.masscenter = masscenter + self.potential_energy = 0 + self.points = [] + + def __str__(self): + return self.name + + def __repr__(self): + return (f'{self.__class__.__name__}({repr(self.name)}, masscenter=' + f'{repr(self.masscenter)}, mass={repr(self.mass)})') + + @property + def name(self): + """The name of the body.""" + return self._name + + @property + def masscenter(self): + """The body's center of mass.""" + return self._masscenter + + @masscenter.setter + def masscenter(self, point): + if not isinstance(point, Point): + raise TypeError("The body's center of mass must be a Point object.") + self._masscenter = point + + @property + def mass(self): + """The body's mass.""" + return self._mass + + @mass.setter + def mass(self, mass): + self._mass = sympify(mass) + + @property + def potential_energy(self): + """The potential energy of the body. + + Examples + ======== + + >>> from sympy.physics.mechanics import Particle, Point + >>> from sympy import symbols + >>> m, g, h = symbols('m g h') + >>> O = Point('O') + >>> P = Particle('P', O, m) + >>> P.potential_energy = m * g * h + >>> P.potential_energy + g*h*m + + """ + return self._potential_energy + + @potential_energy.setter + def potential_energy(self, scalar): + self._potential_energy = sympify(scalar) + + @abstractmethod + def kinetic_energy(self, frame): + pass + + @abstractmethod + def linear_momentum(self, frame): + pass + + @abstractmethod + def angular_momentum(self, point, frame): + pass + + @abstractmethod + def parallel_axis(self, point, frame): + pass diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/functions.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..42abe2b7fe608b4602cdab518f209b446b2dbe03 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/functions.py @@ -0,0 +1,735 @@ +from sympy.utilities import dict_merge +from sympy.utilities.iterables import iterable +from sympy.physics.vector import (Dyadic, Vector, ReferenceFrame, + Point, dynamicsymbols) +from sympy.physics.vector.printing import (vprint, vsprint, vpprint, vlatex, + init_vprinting) +from sympy.physics.mechanics.particle import Particle +from sympy.physics.mechanics.rigidbody import RigidBody +from sympy.simplify.simplify import simplify +from sympy import Matrix, Mul, Derivative, sin, cos, tan, S +from sympy.core.function import AppliedUndef +from sympy.physics.mechanics.inertia import (inertia as _inertia, + inertia_of_point_mass as _inertia_of_point_mass) +from sympy.utilities.exceptions import sympy_deprecation_warning + +__all__ = ['linear_momentum', + 'angular_momentum', + 'kinetic_energy', + 'potential_energy', + 'Lagrangian', + 'mechanics_printing', + 'mprint', + 'msprint', + 'mpprint', + 'mlatex', + 'msubs', + 'find_dynamicsymbols'] + +# These are functions that we've moved and renamed during extracting the +# basic vector calculus code from the mechanics packages. + +mprint = vprint +msprint = vsprint +mpprint = vpprint +mlatex = vlatex + + +def mechanics_printing(**kwargs): + """ + Initializes time derivative printing for all SymPy objects in + mechanics module. + """ + + init_vprinting(**kwargs) + +mechanics_printing.__doc__ = init_vprinting.__doc__ + + +def inertia(frame, ixx, iyy, izz, ixy=0, iyz=0, izx=0): + sympy_deprecation_warning( + """ + The inertia function has been moved. + Import it from "sympy.physics.mechanics". + """, + deprecated_since_version="1.13", + active_deprecations_target="moved-mechanics-functions" + ) + return _inertia(frame, ixx, iyy, izz, ixy, iyz, izx) + + +def inertia_of_point_mass(mass, pos_vec, frame): + sympy_deprecation_warning( + """ + The inertia_of_point_mass function has been moved. + Import it from "sympy.physics.mechanics". + """, + deprecated_since_version="1.13", + active_deprecations_target="moved-mechanics-functions" + ) + return _inertia_of_point_mass(mass, pos_vec, frame) + + +def linear_momentum(frame, *body): + """Linear momentum of the system. + + Explanation + =========== + + This function returns the linear momentum of a system of Particle's and/or + RigidBody's. The linear momentum of a system is equal to the vector sum of + the linear momentum of its constituents. Consider a system, S, comprised of + a rigid body, A, and a particle, P. The linear momentum of the system, L, + is equal to the vector sum of the linear momentum of the particle, L1, and + the linear momentum of the rigid body, L2, i.e. + + L = L1 + L2 + + Parameters + ========== + + frame : ReferenceFrame + The frame in which linear momentum is desired. + body1, body2, body3... : Particle and/or RigidBody + The body (or bodies) whose linear momentum is required. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame + >>> from sympy.physics.mechanics import RigidBody, outer, linear_momentum + >>> N = ReferenceFrame('N') + >>> P = Point('P') + >>> P.set_vel(N, 10 * N.x) + >>> Pa = Particle('Pa', P, 1) + >>> Ac = Point('Ac') + >>> Ac.set_vel(N, 25 * N.y) + >>> I = outer(N.x, N.x) + >>> A = RigidBody('A', Ac, N, 20, (I, Ac)) + >>> linear_momentum(N, A, Pa) + 10*N.x + 500*N.y + + """ + + if not isinstance(frame, ReferenceFrame): + raise TypeError('Please specify a valid ReferenceFrame') + else: + linear_momentum_sys = Vector(0) + for e in body: + if isinstance(e, (RigidBody, Particle)): + linear_momentum_sys += e.linear_momentum(frame) + else: + raise TypeError('*body must have only Particle or RigidBody') + return linear_momentum_sys + + +def angular_momentum(point, frame, *body): + """Angular momentum of a system. + + Explanation + =========== + + This function returns the angular momentum of a system of Particle's and/or + RigidBody's. The angular momentum of such a system is equal to the vector + sum of the angular momentum of its constituents. Consider a system, S, + comprised of a rigid body, A, and a particle, P. The angular momentum of + the system, H, is equal to the vector sum of the angular momentum of the + particle, H1, and the angular momentum of the rigid body, H2, i.e. + + H = H1 + H2 + + Parameters + ========== + + point : Point + The point about which angular momentum of the system is desired. + frame : ReferenceFrame + The frame in which angular momentum is desired. + body1, body2, body3... : Particle and/or RigidBody + The body (or bodies) whose angular momentum is required. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame + >>> from sympy.physics.mechanics import RigidBody, outer, angular_momentum + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> O.set_vel(N, 0 * N.x) + >>> P = O.locatenew('P', 1 * N.x) + >>> P.set_vel(N, 10 * N.x) + >>> Pa = Particle('Pa', P, 1) + >>> Ac = O.locatenew('Ac', 2 * N.y) + >>> Ac.set_vel(N, 5 * N.y) + >>> a = ReferenceFrame('a') + >>> a.set_ang_vel(N, 10 * N.z) + >>> I = outer(N.z, N.z) + >>> A = RigidBody('A', Ac, a, 20, (I, Ac)) + >>> angular_momentum(O, N, Pa, A) + 10*N.z + + """ + + if not isinstance(frame, ReferenceFrame): + raise TypeError('Please enter a valid ReferenceFrame') + if not isinstance(point, Point): + raise TypeError('Please specify a valid Point') + else: + angular_momentum_sys = Vector(0) + for e in body: + if isinstance(e, (RigidBody, Particle)): + angular_momentum_sys += e.angular_momentum(point, frame) + else: + raise TypeError('*body must have only Particle or RigidBody') + return angular_momentum_sys + + +def kinetic_energy(frame, *body): + """Kinetic energy of a multibody system. + + Explanation + =========== + + This function returns the kinetic energy of a system of Particle's and/or + RigidBody's. The kinetic energy of such a system is equal to the sum of + the kinetic energies of its constituents. Consider a system, S, comprising + a rigid body, A, and a particle, P. The kinetic energy of the system, T, + is equal to the vector sum of the kinetic energy of the particle, T1, and + the kinetic energy of the rigid body, T2, i.e. + + T = T1 + T2 + + Kinetic energy is a scalar. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which the velocity or angular velocity of the body is + defined. + body1, body2, body3... : Particle and/or RigidBody + The body (or bodies) whose kinetic energy is required. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame + >>> from sympy.physics.mechanics import RigidBody, outer, kinetic_energy + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> O.set_vel(N, 0 * N.x) + >>> P = O.locatenew('P', 1 * N.x) + >>> P.set_vel(N, 10 * N.x) + >>> Pa = Particle('Pa', P, 1) + >>> Ac = O.locatenew('Ac', 2 * N.y) + >>> Ac.set_vel(N, 5 * N.y) + >>> a = ReferenceFrame('a') + >>> a.set_ang_vel(N, 10 * N.z) + >>> I = outer(N.z, N.z) + >>> A = RigidBody('A', Ac, a, 20, (I, Ac)) + >>> kinetic_energy(N, Pa, A) + 350 + + """ + + if not isinstance(frame, ReferenceFrame): + raise TypeError('Please enter a valid ReferenceFrame') + ke_sys = S.Zero + for e in body: + if isinstance(e, (RigidBody, Particle)): + ke_sys += e.kinetic_energy(frame) + else: + raise TypeError('*body must have only Particle or RigidBody') + return ke_sys + + +def potential_energy(*body): + """Potential energy of a multibody system. + + Explanation + =========== + + This function returns the potential energy of a system of Particle's and/or + RigidBody's. The potential energy of such a system is equal to the sum of + the potential energy of its constituents. Consider a system, S, comprising + a rigid body, A, and a particle, P. The potential energy of the system, V, + is equal to the vector sum of the potential energy of the particle, V1, and + the potential energy of the rigid body, V2, i.e. + + V = V1 + V2 + + Potential energy is a scalar. + + Parameters + ========== + + body1, body2, body3... : Particle and/or RigidBody + The body (or bodies) whose potential energy is required. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame + >>> from sympy.physics.mechanics import RigidBody, outer, potential_energy + >>> from sympy import symbols + >>> M, m, g, h = symbols('M m g h') + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> O.set_vel(N, 0 * N.x) + >>> P = O.locatenew('P', 1 * N.x) + >>> Pa = Particle('Pa', P, m) + >>> Ac = O.locatenew('Ac', 2 * N.y) + >>> a = ReferenceFrame('a') + >>> I = outer(N.z, N.z) + >>> A = RigidBody('A', Ac, a, M, (I, Ac)) + >>> Pa.potential_energy = m * g * h + >>> A.potential_energy = M * g * h + >>> potential_energy(Pa, A) + M*g*h + g*h*m + + """ + + pe_sys = S.Zero + for e in body: + if isinstance(e, (RigidBody, Particle)): + pe_sys += e.potential_energy + else: + raise TypeError('*body must have only Particle or RigidBody') + return pe_sys + + +def gravity(acceleration, *bodies): + from sympy.physics.mechanics.loads import gravity as _gravity + sympy_deprecation_warning( + """ + The gravity function has been moved. + Import it from "sympy.physics.mechanics.loads". + """, + deprecated_since_version="1.13", + active_deprecations_target="moved-mechanics-functions" + ) + return _gravity(acceleration, *bodies) + + +def center_of_mass(point, *bodies): + """ + Returns the position vector from the given point to the center of mass + of the given bodies(particles or rigidbodies). + + Example + ======= + + >>> from sympy import symbols, S + >>> from sympy.physics.vector import Point + >>> from sympy.physics.mechanics import Particle, ReferenceFrame, RigidBody, outer + >>> from sympy.physics.mechanics.functions import center_of_mass + >>> a = ReferenceFrame('a') + >>> m = symbols('m', real=True) + >>> p1 = Particle('p1', Point('p1_pt'), S(1)) + >>> p2 = Particle('p2', Point('p2_pt'), S(2)) + >>> p3 = Particle('p3', Point('p3_pt'), S(3)) + >>> p4 = Particle('p4', Point('p4_pt'), m) + >>> b_f = ReferenceFrame('b_f') + >>> b_cm = Point('b_cm') + >>> mb = symbols('mb') + >>> b = RigidBody('b', b_cm, b_f, mb, (outer(b_f.x, b_f.x), b_cm)) + >>> p2.point.set_pos(p1.point, a.x) + >>> p3.point.set_pos(p1.point, a.x + a.y) + >>> p4.point.set_pos(p1.point, a.y) + >>> b.masscenter.set_pos(p1.point, a.y + a.z) + >>> point_o=Point('o') + >>> point_o.set_pos(p1.point, center_of_mass(p1.point, p1, p2, p3, p4, b)) + >>> expr = 5/(m + mb + 6)*a.x + (m + mb + 3)/(m + mb + 6)*a.y + mb/(m + mb + 6)*a.z + >>> point_o.pos_from(p1.point) + 5/(m + mb + 6)*a.x + (m + mb + 3)/(m + mb + 6)*a.y + mb/(m + mb + 6)*a.z + + """ + if not bodies: + raise TypeError("No bodies(instances of Particle or Rigidbody) were passed.") + + total_mass = 0 + vec = Vector(0) + for i in bodies: + total_mass += i.mass + + masscenter = getattr(i, 'masscenter', None) + if masscenter is None: + masscenter = i.point + vec += i.mass*masscenter.pos_from(point) + + return vec/total_mass + + +def Lagrangian(frame, *body): + """Lagrangian of a multibody system. + + Explanation + =========== + + This function returns the Lagrangian of a system of Particle's and/or + RigidBody's. The Lagrangian of such a system is equal to the difference + between the kinetic energies and potential energies of its constituents. If + T and V are the kinetic and potential energies of a system then it's + Lagrangian, L, is defined as + + L = T - V + + The Lagrangian is a scalar. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which the velocity or angular velocity of the body is + defined to determine the kinetic energy. + + body1, body2, body3... : Particle and/or RigidBody + The body (or bodies) whose Lagrangian is required. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame + >>> from sympy.physics.mechanics import RigidBody, outer, Lagrangian + >>> from sympy import symbols + >>> M, m, g, h = symbols('M m g h') + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> O.set_vel(N, 0 * N.x) + >>> P = O.locatenew('P', 1 * N.x) + >>> P.set_vel(N, 10 * N.x) + >>> Pa = Particle('Pa', P, 1) + >>> Ac = O.locatenew('Ac', 2 * N.y) + >>> Ac.set_vel(N, 5 * N.y) + >>> a = ReferenceFrame('a') + >>> a.set_ang_vel(N, 10 * N.z) + >>> I = outer(N.z, N.z) + >>> A = RigidBody('A', Ac, a, 20, (I, Ac)) + >>> Pa.potential_energy = m * g * h + >>> A.potential_energy = M * g * h + >>> Lagrangian(N, Pa, A) + -M*g*h - g*h*m + 350 + + """ + + if not isinstance(frame, ReferenceFrame): + raise TypeError('Please supply a valid ReferenceFrame') + for e in body: + if not isinstance(e, (RigidBody, Particle)): + raise TypeError('*body must have only Particle or RigidBody') + return kinetic_energy(frame, *body) - potential_energy(*body) + + +def find_dynamicsymbols(expression, exclude=None, reference_frame=None): + """Find all dynamicsymbols in expression. + + Explanation + =========== + + If the optional ``exclude`` kwarg is used, only dynamicsymbols + not in the iterable ``exclude`` are returned. + If we intend to apply this function on a vector, the optional + ``reference_frame`` is also used to inform about the corresponding frame + with respect to which the dynamic symbols of the given vector is to be + determined. + + Parameters + ========== + + expression : SymPy expression + + exclude : iterable of dynamicsymbols, optional + + reference_frame : ReferenceFrame, optional + The frame with respect to which the dynamic symbols of the + given vector is to be determined. + + Examples + ======== + + >>> from sympy.physics.mechanics import dynamicsymbols, find_dynamicsymbols + >>> from sympy.physics.mechanics import ReferenceFrame + >>> x, y = dynamicsymbols('x, y') + >>> expr = x + x.diff()*y + >>> find_dynamicsymbols(expr) + {x(t), y(t), Derivative(x(t), t)} + >>> find_dynamicsymbols(expr, exclude=[x, y]) + {Derivative(x(t), t)} + >>> a, b, c = dynamicsymbols('a, b, c') + >>> A = ReferenceFrame('A') + >>> v = a * A.x + b * A.y + c * A.z + >>> find_dynamicsymbols(v, reference_frame=A) + {a(t), b(t), c(t)} + + """ + t_set = {dynamicsymbols._t} + if exclude: + if iterable(exclude): + exclude_set = set(exclude) + else: + raise TypeError("exclude kwarg must be iterable") + else: + exclude_set = set() + if isinstance(expression, Vector): + if reference_frame is None: + raise ValueError("You must provide reference_frame when passing a " + "vector expression, got %s." % reference_frame) + else: + expression = expression.to_matrix(reference_frame) + return {i for i in expression.atoms(AppliedUndef, Derivative) if + i.free_symbols == t_set} - exclude_set + + +def msubs(expr, *sub_dicts, smart=False, **kwargs): + """A custom subs for use on expressions derived in physics.mechanics. + + Traverses the expression tree once, performing the subs found in sub_dicts. + Terms inside ``Derivative`` expressions are ignored: + + Examples + ======== + + >>> from sympy.physics.mechanics import dynamicsymbols, msubs + >>> x = dynamicsymbols('x') + >>> msubs(x.diff() + x, {x: 1}) + Derivative(x(t), t) + 1 + + Note that sub_dicts can be a single dictionary, or several dictionaries: + + >>> x, y, z = dynamicsymbols('x, y, z') + >>> sub1 = {x: 1, y: 2} + >>> sub2 = {z: 3, x.diff(): 4} + >>> msubs(x.diff() + x + y + z, sub1, sub2) + 10 + + If smart=True (default False), also checks for conditions that may result + in ``nan``, but if simplified would yield a valid expression. For example: + + >>> from sympy import sin, tan + >>> (sin(x)/tan(x)).subs(x, 0) + nan + >>> msubs(sin(x)/tan(x), {x: 0}, smart=True) + 1 + + It does this by first replacing all ``tan`` with ``sin/cos``. Then each + node is traversed. If the node is a fraction, subs is first evaluated on + the denominator. If this results in 0, simplification of the entire + fraction is attempted. Using this selective simplification, only + subexpressions that result in 1/0 are targeted, resulting in faster + performance. + + """ + + sub_dict = dict_merge(*sub_dicts) + if smart: + func = _smart_subs + elif hasattr(expr, 'msubs'): + return expr.msubs(sub_dict) + else: + func = lambda expr, sub_dict: _crawl(expr, _sub_func, sub_dict) + if isinstance(expr, (Matrix, Vector, Dyadic)): + return expr.applyfunc(lambda x: func(x, sub_dict)) + else: + return func(expr, sub_dict) + + +def _crawl(expr, func, *args, **kwargs): + """Crawl the expression tree, and apply func to every node.""" + val = func(expr, *args, **kwargs) + if val is not None: + return val + new_args = (_crawl(arg, func, *args, **kwargs) for arg in expr.args) + return expr.func(*new_args) + + +def _sub_func(expr, sub_dict): + """Perform direct matching substitution, ignoring derivatives.""" + if expr in sub_dict: + return sub_dict[expr] + elif not expr.args or expr.is_Derivative: + return expr + + +def _tan_repl_func(expr): + """Replace tan with sin/cos.""" + if isinstance(expr, tan): + return sin(*expr.args) / cos(*expr.args) + elif not expr.args or expr.is_Derivative: + return expr + + +def _smart_subs(expr, sub_dict): + """Performs subs, checking for conditions that may result in `nan` or + `oo`, and attempts to simplify them out. + + The expression tree is traversed twice, and the following steps are + performed on each expression node: + - First traverse: + Replace all `tan` with `sin/cos`. + - Second traverse: + If node is a fraction, check if the denominator evaluates to 0. + If so, attempt to simplify it out. Then if node is in sub_dict, + sub in the corresponding value. + + """ + expr = _crawl(expr, _tan_repl_func) + + def _recurser(expr, sub_dict): + # Decompose the expression into num, den + num, den = _fraction_decomp(expr) + if den != 1: + # If there is a non trivial denominator, we need to handle it + denom_subbed = _recurser(den, sub_dict) + if denom_subbed.evalf() == 0: + # If denom is 0 after this, attempt to simplify the bad expr + expr = simplify(expr) + else: + # Expression won't result in nan, find numerator + num_subbed = _recurser(num, sub_dict) + return num_subbed / denom_subbed + # We have to crawl the tree manually, because `expr` may have been + # modified in the simplify step. First, perform subs as normal: + val = _sub_func(expr, sub_dict) + if val is not None: + return val + new_args = (_recurser(arg, sub_dict) for arg in expr.args) + return expr.func(*new_args) + return _recurser(expr, sub_dict) + + +def _fraction_decomp(expr): + """Return num, den such that expr = num/den.""" + if not isinstance(expr, Mul): + return expr, 1 + num = [] + den = [] + for a in expr.args: + if a.is_Pow and a.args[1] < 0: + den.append(1 / a) + else: + num.append(a) + if not den: + return expr, 1 + num = Mul(*num) + den = Mul(*den) + return num, den + + +def _f_list_parser(fl, ref_frame): + """Parses the provided forcelist composed of items + of the form (obj, force). + Returns a tuple containing: + vel_list: The velocity (ang_vel for Frames, vel for Points) in + the provided reference frame. + f_list: The forces. + + Used internally in the KanesMethod and LagrangesMethod classes. + + """ + def flist_iter(): + for pair in fl: + obj, force = pair + if isinstance(obj, ReferenceFrame): + yield obj.ang_vel_in(ref_frame), force + elif isinstance(obj, Point): + yield obj.vel(ref_frame), force + else: + raise TypeError('First entry in each forcelist pair must ' + 'be a point or frame.') + + if not fl: + vel_list, f_list = (), () + else: + unzip = lambda l: list(zip(*l)) if l[0] else [(), ()] + vel_list, f_list = unzip(list(flist_iter())) + return vel_list, f_list + + +def _validate_coordinates(coordinates=None, speeds=None, check_duplicates=True, + is_dynamicsymbols=True, u_auxiliary=None): + """Validate the generalized coordinates and generalized speeds. + + Parameters + ========== + coordinates : iterable, optional + Generalized coordinates to be validated. + speeds : iterable, optional + Generalized speeds to be validated. + check_duplicates : bool, optional + Checks if there are duplicates in the generalized coordinates and + generalized speeds. If so it will raise a ValueError. The default is + True. + is_dynamicsymbols : iterable, optional + Checks if all the generalized coordinates and generalized speeds are + dynamicsymbols. If any is not a dynamicsymbol, a ValueError will be + raised. The default is True. + u_auxiliary : iterable, optional + Auxiliary generalized speeds to be validated. + + """ + t_set = {dynamicsymbols._t} + # Convert input to iterables + if coordinates is None: + coordinates = [] + elif not iterable(coordinates): + coordinates = [coordinates] + if speeds is None: + speeds = [] + elif not iterable(speeds): + speeds = [speeds] + if u_auxiliary is None: + u_auxiliary = [] + elif not iterable(u_auxiliary): + u_auxiliary = [u_auxiliary] + + msgs = [] + if check_duplicates: # Check for duplicates + seen = set() + coord_duplicates = {x for x in coordinates if x in seen or seen.add(x)} + seen = set() + speed_duplicates = {x for x in speeds if x in seen or seen.add(x)} + seen = set() + aux_duplicates = {x for x in u_auxiliary if x in seen or seen.add(x)} + overlap_coords = set(coordinates).intersection(speeds) + overlap_aux = set(coordinates).union(speeds).intersection(u_auxiliary) + if coord_duplicates: + msgs.append(f'The generalized coordinates {coord_duplicates} are ' + f'duplicated, all generalized coordinates should be ' + f'unique.') + if speed_duplicates: + msgs.append(f'The generalized speeds {speed_duplicates} are ' + f'duplicated, all generalized speeds should be unique.') + if aux_duplicates: + msgs.append(f'The auxiliary speeds {aux_duplicates} are duplicated,' + f' all auxiliary speeds should be unique.') + if overlap_coords: + msgs.append(f'{overlap_coords} are defined as both generalized ' + f'coordinates and generalized speeds.') + if overlap_aux: + msgs.append(f'The auxiliary speeds {overlap_aux} are also defined ' + f'as generalized coordinates or generalized speeds.') + if is_dynamicsymbols: # Check whether all coordinates are dynamicsymbols + for coordinate in coordinates: + if not (isinstance(coordinate, (AppliedUndef, Derivative)) and + coordinate.free_symbols == t_set): + msgs.append(f'Generalized coordinate "{coordinate}" is not a ' + f'dynamicsymbol.') + for speed in speeds: + if not (isinstance(speed, (AppliedUndef, Derivative)) and + speed.free_symbols == t_set): + msgs.append( + f'Generalized speed "{speed}" is not a dynamicsymbol.') + for aux in u_auxiliary: + if not (isinstance(aux, (AppliedUndef, Derivative)) and + aux.free_symbols == t_set): + msgs.append( + f'Auxiliary speed "{aux}" is not a dynamicsymbol.') + if msgs: + raise ValueError('\n'.join(msgs)) + + +def _parse_linear_solver(linear_solver): + """Helper function to retrieve a specified linear solver.""" + if callable(linear_solver): + return linear_solver + return lambda A, b: Matrix.solve(A, b, method=linear_solver) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/inertia.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/inertia.py new file mode 100644 index 0000000000000000000000000000000000000000..683c1f630f3cedb82d02a9c5ba2309ae438b7fff --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/inertia.py @@ -0,0 +1,199 @@ +from sympy import sympify +from sympy.physics.vector import Point, Dyadic, ReferenceFrame, outer +from collections import namedtuple + +__all__ = ['inertia', 'inertia_of_point_mass', 'Inertia'] + + +def inertia(frame, ixx, iyy, izz, ixy=0, iyz=0, izx=0): + """Simple way to create inertia Dyadic object. + + Explanation + =========== + + Creates an inertia Dyadic based on the given tensor values and a body-fixed + reference frame. + + Parameters + ========== + + frame : ReferenceFrame + The frame the inertia is defined in. + ixx : Sympifyable + The xx element in the inertia dyadic. + iyy : Sympifyable + The yy element in the inertia dyadic. + izz : Sympifyable + The zz element in the inertia dyadic. + ixy : Sympifyable + The xy element in the inertia dyadic. + iyz : Sympifyable + The yz element in the inertia dyadic. + izx : Sympifyable + The zx element in the inertia dyadic. + + Examples + ======== + + >>> from sympy.physics.mechanics import ReferenceFrame, inertia + >>> N = ReferenceFrame('N') + >>> inertia(N, 1, 2, 3) + (N.x|N.x) + 2*(N.y|N.y) + 3*(N.z|N.z) + + """ + + if not isinstance(frame, ReferenceFrame): + raise TypeError('Need to define the inertia in a frame') + ixx, iyy, izz = sympify(ixx), sympify(iyy), sympify(izz) + ixy, iyz, izx = sympify(ixy), sympify(iyz), sympify(izx) + return (ixx*outer(frame.x, frame.x) + ixy*outer(frame.x, frame.y) + + izx*outer(frame.x, frame.z) + ixy*outer(frame.y, frame.x) + + iyy*outer(frame.y, frame.y) + iyz*outer(frame.y, frame.z) + + izx*outer(frame.z, frame.x) + iyz*outer(frame.z, frame.y) + + izz*outer(frame.z, frame.z)) + + +def inertia_of_point_mass(mass, pos_vec, frame): + """Inertia dyadic of a point mass relative to point O. + + Parameters + ========== + + mass : Sympifyable + Mass of the point mass + pos_vec : Vector + Position from point O to point mass + frame : ReferenceFrame + Reference frame to express the dyadic in + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import ReferenceFrame, inertia_of_point_mass + >>> N = ReferenceFrame('N') + >>> r, m = symbols('r m') + >>> px = r * N.x + >>> inertia_of_point_mass(m, px, N) + m*r**2*(N.y|N.y) + m*r**2*(N.z|N.z) + + """ + + return mass*( + (outer(frame.x, frame.x) + + outer(frame.y, frame.y) + + outer(frame.z, frame.z)) * + (pos_vec.dot(pos_vec)) - outer(pos_vec, pos_vec)) + + +class Inertia(namedtuple('Inertia', ['dyadic', 'point'])): + """Inertia object consisting of a Dyadic and a Point of reference. + + Explanation + =========== + + This is a simple class to store the Point and Dyadic, belonging to an + inertia. + + Attributes + ========== + + dyadic : Dyadic + The dyadic of the inertia. + point : Point + The reference point of the inertia. + + Examples + ======== + + >>> from sympy.physics.mechanics import ReferenceFrame, Point, Inertia + >>> N = ReferenceFrame('N') + >>> Po = Point('Po') + >>> Inertia(N.x.outer(N.x) + N.y.outer(N.y) + N.z.outer(N.z), Po) + ((N.x|N.x) + (N.y|N.y) + (N.z|N.z), Po) + + In the example above the Dyadic was created manually, one can however also + use the ``inertia`` function for this or the class method ``from_tensor`` as + shown below. + + >>> Inertia.from_inertia_scalars(Po, N, 1, 1, 1) + ((N.x|N.x) + (N.y|N.y) + (N.z|N.z), Po) + + """ + __slots__ = () + + def __new__(cls, dyadic, point): + # Switch order if given in the wrong order + if isinstance(dyadic, Point) and isinstance(point, Dyadic): + point, dyadic = dyadic, point + if not isinstance(point, Point): + raise TypeError('Reference point should be of type Point') + if not isinstance(dyadic, Dyadic): + raise TypeError('Inertia value should be expressed as a Dyadic') + return super().__new__(cls, dyadic, point) + + @classmethod + def from_inertia_scalars(cls, point, frame, ixx, iyy, izz, ixy=0, iyz=0, + izx=0): + """Simple way to create an Inertia object based on the tensor values. + + Explanation + =========== + + This class method uses the :func`~.inertia` to create the Dyadic based + on the tensor values. + + Parameters + ========== + + point : Point + The reference point of the inertia. + frame : ReferenceFrame + The frame the inertia is defined in. + ixx : Sympifyable + The xx element in the inertia dyadic. + iyy : Sympifyable + The yy element in the inertia dyadic. + izz : Sympifyable + The zz element in the inertia dyadic. + ixy : Sympifyable + The xy element in the inertia dyadic. + iyz : Sympifyable + The yz element in the inertia dyadic. + izx : Sympifyable + The zx element in the inertia dyadic. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import ReferenceFrame, Point, Inertia + >>> ixx, iyy, izz, ixy, iyz, izx = symbols('ixx iyy izz ixy iyz izx') + >>> N = ReferenceFrame('N') + >>> P = Point('P') + >>> I = Inertia.from_inertia_scalars(P, N, ixx, iyy, izz, ixy, iyz, izx) + + The tensor values can easily be seen when converting the dyadic to a + matrix. + + >>> I.dyadic.to_matrix(N) + Matrix([ + [ixx, ixy, izx], + [ixy, iyy, iyz], + [izx, iyz, izz]]) + + """ + return cls(inertia(frame, ixx, iyy, izz, ixy, iyz, izx), point) + + def __add__(self, other): + raise TypeError(f"unsupported operand type(s) for +: " + f"'{self.__class__.__name__}' and " + f"'{other.__class__.__name__}'") + + def __mul__(self, other): + raise TypeError(f"unsupported operand type(s) for *: " + f"'{self.__class__.__name__}' and " + f"'{other.__class__.__name__}'") + + __radd__ = __add__ + __rmul__ = __mul__ diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/joint.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/joint.py new file mode 100644 index 0000000000000000000000000000000000000000..6f3fe661532cff6bf8dda4ab4383fc09f75e9e44 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/joint.py @@ -0,0 +1,2188 @@ +# coding=utf-8 + +from abc import ABC, abstractmethod + +from sympy import pi, Derivative, Matrix +from sympy.core.function import AppliedUndef +from sympy.physics.mechanics.body_base import BodyBase +from sympy.physics.mechanics.functions import _validate_coordinates +from sympy.physics.vector import (Vector, dynamicsymbols, cross, Point, + ReferenceFrame) +from sympy.utilities.iterables import iterable +from sympy.utilities.exceptions import sympy_deprecation_warning + +__all__ = ['Joint', 'PinJoint', 'PrismaticJoint', 'CylindricalJoint', + 'PlanarJoint', 'SphericalJoint', 'WeldJoint'] + + +class Joint(ABC): + """Abstract base class for all specific joints. + + Explanation + =========== + + A joint subtracts degrees of freedom from a body. This is the base class + for all specific joints and holds all common methods acting as an interface + for all joints. Custom joint can be created by inheriting Joint class and + defining all abstract functions. + + The abstract methods are: + + - ``_generate_coordinates`` + - ``_generate_speeds`` + - ``_orient_frames`` + - ``_set_angular_velocity`` + - ``_set_linear_velocity`` + + Parameters + ========== + + name : string + A unique name for the joint. + parent : Particle or RigidBody + The parent body of joint. + child : Particle or RigidBody + The child body of joint. + coordinates : iterable of dynamicsymbols, optional + Generalized coordinates of the joint. + speeds : iterable of dynamicsymbols, optional + Generalized speeds of joint. + parent_point : Point or Vector, optional + Attachment point where the joint is fixed to the parent body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the parent's mass + center. + child_point : Point or Vector, optional + Attachment point where the joint is fixed to the child body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the child's mass + center. + parent_axis : Vector, optional + .. deprecated:: 1.12 + Axis fixed in the parent body which aligns with an axis fixed in the + child body. The default is the x axis of parent's reference frame. + For more information on this deprecation, see + :ref:`deprecated-mechanics-joint-axis`. + child_axis : Vector, optional + .. deprecated:: 1.12 + Axis fixed in the child body which aligns with an axis fixed in the + parent body. The default is the x axis of child's reference frame. + For more information on this deprecation, see + :ref:`deprecated-mechanics-joint-axis`. + parent_interframe : ReferenceFrame, optional + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the parent's own frame. + child_interframe : ReferenceFrame, optional + Intermediate frame of the child body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the child's own frame. + parent_joint_pos : Point or Vector, optional + .. deprecated:: 1.12 + This argument is replaced by parent_point and will be removed in a + future version. + See :ref:`deprecated-mechanics-joint-pos` for more information. + child_joint_pos : Point or Vector, optional + .. deprecated:: 1.12 + This argument is replaced by child_point and will be removed in a + future version. + See :ref:`deprecated-mechanics-joint-pos` for more information. + + Attributes + ========== + + name : string + The joint's name. + parent : Particle or RigidBody + The joint's parent body. + child : Particle or RigidBody + The joint's child body. + coordinates : Matrix + Matrix of the joint's generalized coordinates. + speeds : Matrix + Matrix of the joint's generalized speeds. + parent_point : Point + Attachment point where the joint is fixed to the parent body. + child_point : Point + Attachment point where the joint is fixed to the child body. + parent_axis : Vector + The axis fixed in the parent frame that represents the joint. + child_axis : Vector + The axis fixed in the child frame that represents the joint. + parent_interframe : ReferenceFrame + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. + child_interframe : ReferenceFrame + Intermediate frame of the child body with respect to which the joint + transformation is formulated. + kdes : Matrix + Kinematical differential equations of the joint. + + Notes + ===== + + When providing a vector as the intermediate frame, a new intermediate frame + is created which aligns its X axis with the provided vector. This is done + with a single fixed rotation about a rotation axis. This rotation axis is + determined by taking the cross product of the ``body.x`` axis with the + provided vector. In the case where the provided vector is in the ``-body.x`` + direction, the rotation is done about the ``body.y`` axis. + + """ + + def __init__(self, name, parent, child, coordinates=None, speeds=None, + parent_point=None, child_point=None, parent_interframe=None, + child_interframe=None, parent_axis=None, child_axis=None, + parent_joint_pos=None, child_joint_pos=None): + + if not isinstance(name, str): + raise TypeError('Supply a valid name.') + self._name = name + + if not isinstance(parent, BodyBase): + raise TypeError('Parent must be a body.') + self._parent = parent + + if not isinstance(child, BodyBase): + raise TypeError('Child must be a body.') + self._child = child + + if parent_axis is not None or child_axis is not None: + sympy_deprecation_warning( + """ + The parent_axis and child_axis arguments for the Joint classes + are deprecated. Instead use parent_interframe, child_interframe. + """, + deprecated_since_version="1.12", + active_deprecations_target="deprecated-mechanics-joint-axis", + stacklevel=4 + ) + if parent_interframe is None: + parent_interframe = parent_axis + if child_interframe is None: + child_interframe = child_axis + + # Set parent and child frame attributes + if hasattr(self._parent, 'frame'): + self._parent_frame = self._parent.frame + else: + if isinstance(parent_interframe, ReferenceFrame): + self._parent_frame = parent_interframe + else: + self._parent_frame = ReferenceFrame( + f'{self.name}_{self._parent.name}_frame') + if hasattr(self._child, 'frame'): + self._child_frame = self._child.frame + else: + if isinstance(child_interframe, ReferenceFrame): + self._child_frame = child_interframe + else: + self._child_frame = ReferenceFrame( + f'{self.name}_{self._child.name}_frame') + + self._parent_interframe = self._locate_joint_frame( + self._parent, parent_interframe, self._parent_frame) + self._child_interframe = self._locate_joint_frame( + self._child, child_interframe, self._child_frame) + self._parent_axis = self._axis(parent_axis, self._parent_frame) + self._child_axis = self._axis(child_axis, self._child_frame) + + if parent_joint_pos is not None or child_joint_pos is not None: + sympy_deprecation_warning( + """ + The parent_joint_pos and child_joint_pos arguments for the Joint + classes are deprecated. Instead use parent_point and child_point. + """, + deprecated_since_version="1.12", + active_deprecations_target="deprecated-mechanics-joint-pos", + stacklevel=4 + ) + if parent_point is None: + parent_point = parent_joint_pos + if child_point is None: + child_point = child_joint_pos + self._parent_point = self._locate_joint_pos( + self._parent, parent_point, self._parent_frame) + self._child_point = self._locate_joint_pos( + self._child, child_point, self._child_frame) + + self._coordinates = self._generate_coordinates(coordinates) + self._speeds = self._generate_speeds(speeds) + _validate_coordinates(self.coordinates, self.speeds) + self._kdes = self._generate_kdes() + + self._orient_frames() + self._set_angular_velocity() + self._set_linear_velocity() + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + @property + def name(self): + """Name of the joint.""" + return self._name + + @property + def parent(self): + """Parent body of Joint.""" + return self._parent + + @property + def child(self): + """Child body of Joint.""" + return self._child + + @property + def coordinates(self): + """Matrix of the joint's generalized coordinates.""" + return self._coordinates + + @property + def speeds(self): + """Matrix of the joint's generalized speeds.""" + return self._speeds + + @property + def kdes(self): + """Kinematical differential equations of the joint.""" + return self._kdes + + @property + def parent_axis(self): + """The axis of parent frame.""" + # Will be removed with `deprecated-mechanics-joint-axis` + return self._parent_axis + + @property + def child_axis(self): + """The axis of child frame.""" + # Will be removed with `deprecated-mechanics-joint-axis` + return self._child_axis + + @property + def parent_point(self): + """Attachment point where the joint is fixed to the parent body.""" + return self._parent_point + + @property + def child_point(self): + """Attachment point where the joint is fixed to the child body.""" + return self._child_point + + @property + def parent_interframe(self): + return self._parent_interframe + + @property + def child_interframe(self): + return self._child_interframe + + @abstractmethod + def _generate_coordinates(self, coordinates): + """Generate Matrix of the joint's generalized coordinates.""" + pass + + @abstractmethod + def _generate_speeds(self, speeds): + """Generate Matrix of the joint's generalized speeds.""" + pass + + @abstractmethod + def _orient_frames(self): + """Orient frames as per the joint.""" + pass + + @abstractmethod + def _set_angular_velocity(self): + """Set angular velocity of the joint related frames.""" + pass + + @abstractmethod + def _set_linear_velocity(self): + """Set velocity of related points to the joint.""" + pass + + @staticmethod + def _to_vector(matrix, frame): + """Converts a matrix to a vector in the given frame.""" + return Vector([(matrix, frame)]) + + @staticmethod + def _axis(ax, *frames): + """Check whether an axis is fixed in one of the frames.""" + if ax is None: + ax = frames[0].x + return ax + if not isinstance(ax, Vector): + raise TypeError("Axis must be a Vector.") + ref_frame = None # Find a body in which the axis can be expressed + for frame in frames: + try: + ax.to_matrix(frame) + except ValueError: + pass + else: + ref_frame = frame + break + if ref_frame is None: + raise ValueError("Axis cannot be expressed in one of the body's " + "frames.") + if not ax.dt(ref_frame) == 0: + raise ValueError('Axis cannot be time-varying when viewed from the ' + 'associated body.') + return ax + + @staticmethod + def _choose_rotation_axis(frame, axis): + components = axis.to_matrix(frame) + x, y, z = components[0], components[1], components[2] + + if x != 0: + if y != 0: + if z != 0: + return cross(axis, frame.x) + if z != 0: + return frame.y + return frame.z + else: + if y != 0: + return frame.x + return frame.y + + @staticmethod + def _create_aligned_interframe(frame, align_axis, frame_axis=None, + frame_name=None): + """ + Returns an intermediate frame, where the ``frame_axis`` defined in + ``frame`` is aligned with ``axis``. By default this means that the X + axis will be aligned with ``axis``. + + Parameters + ========== + + frame : BodyBase or ReferenceFrame + The body or reference frame with respect to which the intermediate + frame is oriented. + align_axis : Vector + The vector with respect to which the intermediate frame will be + aligned. + frame_axis : Vector + The vector of the frame which should get aligned with ``axis``. The + default is the X axis of the frame. + frame_name : string + Name of the to be created intermediate frame. The default adds + "_int_frame" to the name of ``frame``. + + Example + ======= + + An intermediate frame, where the X axis of the parent becomes aligned + with ``parent.y + parent.z`` can be created as follows: + + >>> from sympy.physics.mechanics.joint import Joint + >>> from sympy.physics.mechanics import RigidBody + >>> parent = RigidBody('parent') + >>> parent_interframe = Joint._create_aligned_interframe( + ... parent, parent.y + parent.z) + >>> parent_interframe + parent_int_frame + >>> parent.frame.dcm(parent_interframe) + Matrix([ + [ 0, -sqrt(2)/2, -sqrt(2)/2], + [sqrt(2)/2, 1/2, -1/2], + [sqrt(2)/2, -1/2, 1/2]]) + >>> (parent.y + parent.z).express(parent_interframe) + sqrt(2)*parent_int_frame.x + + Notes + ===== + + The direction cosine matrix between the given frame and intermediate + frame is formed using a simple rotation about an axis that is normal to + both ``align_axis`` and ``frame_axis``. In general, the normal axis is + formed by crossing the ``frame_axis`` with the ``align_axis``. The + exception is if the axes are parallel with opposite directions, in which + case the rotation vector is chosen using the rules in the following + table with the vectors expressed in the given frame: + + .. list-table:: + :header-rows: 1 + + * - ``align_axis`` + - ``frame_axis`` + - ``rotation_axis`` + * - ``-x`` + - ``x`` + - ``z`` + * - ``-y`` + - ``y`` + - ``x`` + * - ``-z`` + - ``z`` + - ``y`` + * - ``-x-y`` + - ``x+y`` + - ``z`` + * - ``-y-z`` + - ``y+z`` + - ``x`` + * - ``-x-z`` + - ``x+z`` + - ``y`` + * - ``-x-y-z`` + - ``x+y+z`` + - ``(x+y+z) × x`` + + """ + if isinstance(frame, BodyBase): + frame = frame.frame + if frame_axis is None: + frame_axis = frame.x + if frame_name is None: + if frame.name[-6:] == '_frame': + frame_name = f'{frame.name[:-6]}_int_frame' + else: + frame_name = f'{frame.name}_int_frame' + angle = frame_axis.angle_between(align_axis) + rotation_axis = cross(frame_axis, align_axis) + if rotation_axis == Vector(0) and angle == 0: + return frame + if angle == pi: + rotation_axis = Joint._choose_rotation_axis(frame, align_axis) + + int_frame = ReferenceFrame(frame_name) + int_frame.orient_axis(frame, rotation_axis, angle) + int_frame.set_ang_vel(frame, 0 * rotation_axis) + return int_frame + + def _generate_kdes(self): + """Generate kinematical differential equations.""" + kdes = [] + t = dynamicsymbols._t + for i in range(len(self.coordinates)): + kdes.append(-self.coordinates[i].diff(t) + self.speeds[i]) + return Matrix(kdes) + + def _locate_joint_pos(self, body, joint_pos, body_frame=None): + """Returns the attachment point of a body.""" + if body_frame is None: + body_frame = body.frame + if joint_pos is None: + return body.masscenter + if not isinstance(joint_pos, (Point, Vector)): + raise TypeError('Attachment point must be a Point or Vector.') + if isinstance(joint_pos, Vector): + point_name = f'{self.name}_{body.name}_joint' + joint_pos = body.masscenter.locatenew(point_name, joint_pos) + if not joint_pos.pos_from(body.masscenter).dt(body_frame) == 0: + raise ValueError('Attachment point must be fixed to the associated ' + 'body.') + return joint_pos + + def _locate_joint_frame(self, body, interframe, body_frame=None): + """Returns the attachment frame of a body.""" + if body_frame is None: + body_frame = body.frame + if interframe is None: + return body_frame + if isinstance(interframe, Vector): + interframe = Joint._create_aligned_interframe( + body_frame, interframe, + frame_name=f'{self.name}_{body.name}_int_frame') + elif not isinstance(interframe, ReferenceFrame): + raise TypeError('Interframe must be a ReferenceFrame.') + if not interframe.ang_vel_in(body_frame) == 0: + raise ValueError(f'Interframe {interframe} is not fixed to body ' + f'{body}.') + body.masscenter.set_vel(interframe, 0) # Fixate interframe to body + return interframe + + def _fill_coordinate_list(self, coordinates, n_coords, label='q', offset=0, + number_single=False): + """Helper method for _generate_coordinates and _generate_speeds. + + Parameters + ========== + + coordinates : iterable + Iterable of coordinates or speeds that have been provided. + n_coords : Integer + Number of coordinates that should be returned. + label : String, optional + Coordinate type either 'q' (coordinates) or 'u' (speeds). The + Default is 'q'. + offset : Integer + Count offset when creating new dynamicsymbols. The default is 0. + number_single : Boolean + Boolean whether if n_coords == 1, number should still be used. The + default is False. + + """ + + def create_symbol(number): + if n_coords == 1 and not number_single: + return dynamicsymbols(f'{label}_{self.name}') + return dynamicsymbols(f'{label}{number}_{self.name}') + + name = 'generalized coordinate' if label == 'q' else 'generalized speed' + generated_coordinates = [] + if coordinates is None: + coordinates = [] + elif not iterable(coordinates): + coordinates = [coordinates] + if not (len(coordinates) == 0 or len(coordinates) == n_coords): + raise ValueError(f'Expected {n_coords} {name}s, instead got ' + f'{len(coordinates)} {name}s.') + # Supports more iterables, also Matrix + for i, coord in enumerate(coordinates): + if coord is None: + generated_coordinates.append(create_symbol(i + offset)) + elif isinstance(coord, (AppliedUndef, Derivative)): + generated_coordinates.append(coord) + else: + raise TypeError(f'The {name} {coord} should have been a ' + f'dynamicsymbol.') + for i in range(len(coordinates) + offset, n_coords + offset): + generated_coordinates.append(create_symbol(i)) + return Matrix(generated_coordinates) + + +class PinJoint(Joint): + """Pin (Revolute) Joint. + + .. raw:: html + :file: ../../../doc/src/explanation/modules/physics/mechanics/PinJoint.svg + + Explanation + =========== + + A pin joint is defined such that the joint rotation axis is fixed in both + the child and parent and the location of the joint is relative to the mass + center of each body. The child rotates an angle, θ, from the parent about + the rotation axis and has a simple angular speed, ω, relative to the + parent. The direction cosine matrix between the child interframe and + parent interframe is formed using a simple rotation about the joint axis. + The page on the joints framework gives a more detailed explanation of the + intermediate frames. + + Parameters + ========== + + name : string + A unique name for the joint. + parent : Particle or RigidBody + The parent body of joint. + child : Particle or RigidBody + The child body of joint. + coordinates : dynamicsymbol, optional + Generalized coordinates of the joint. + speeds : dynamicsymbol, optional + Generalized speeds of joint. + parent_point : Point or Vector, optional + Attachment point where the joint is fixed to the parent body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the parent's mass + center. + child_point : Point or Vector, optional + Attachment point where the joint is fixed to the child body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the child's mass + center. + parent_axis : Vector, optional + .. deprecated:: 1.12 + Axis fixed in the parent body which aligns with an axis fixed in the + child body. The default is the x axis of parent's reference frame. + For more information on this deprecation, see + :ref:`deprecated-mechanics-joint-axis`. + child_axis : Vector, optional + .. deprecated:: 1.12 + Axis fixed in the child body which aligns with an axis fixed in the + parent body. The default is the x axis of child's reference frame. + For more information on this deprecation, see + :ref:`deprecated-mechanics-joint-axis`. + parent_interframe : ReferenceFrame, optional + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the parent's own frame. + child_interframe : ReferenceFrame, optional + Intermediate frame of the child body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the child's own frame. + joint_axis : Vector + The axis about which the rotation occurs. Note that the components + of this axis are the same in the parent_interframe and child_interframe. + parent_joint_pos : Point or Vector, optional + .. deprecated:: 1.12 + This argument is replaced by parent_point and will be removed in a + future version. + See :ref:`deprecated-mechanics-joint-pos` for more information. + child_joint_pos : Point or Vector, optional + .. deprecated:: 1.12 + This argument is replaced by child_point and will be removed in a + future version. + See :ref:`deprecated-mechanics-joint-pos` for more information. + + Attributes + ========== + + name : string + The joint's name. + parent : Particle or RigidBody + The joint's parent body. + child : Particle or RigidBody + The joint's child body. + coordinates : Matrix + Matrix of the joint's generalized coordinates. The default value is + ``dynamicsymbols(f'q_{joint.name}')``. + speeds : Matrix + Matrix of the joint's generalized speeds. The default value is + ``dynamicsymbols(f'u_{joint.name}')``. + parent_point : Point + Attachment point where the joint is fixed to the parent body. + child_point : Point + Attachment point where the joint is fixed to the child body. + parent_axis : Vector + The axis fixed in the parent frame that represents the joint. + child_axis : Vector + The axis fixed in the child frame that represents the joint. + parent_interframe : ReferenceFrame + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. + child_interframe : ReferenceFrame + Intermediate frame of the child body with respect to which the joint + transformation is formulated. + joint_axis : Vector + The axis about which the rotation occurs. Note that the components of + this axis are the same in the parent_interframe and child_interframe. + kdes : Matrix + Kinematical differential equations of the joint. + + Examples + ========= + + A single pin joint is created from two bodies and has the following basic + attributes: + + >>> from sympy.physics.mechanics import RigidBody, PinJoint + >>> parent = RigidBody('P') + >>> parent + P + >>> child = RigidBody('C') + >>> child + C + >>> joint = PinJoint('PC', parent, child) + >>> joint + PinJoint: PC parent: P child: C + >>> joint.name + 'PC' + >>> joint.parent + P + >>> joint.child + C + >>> joint.parent_point + P_masscenter + >>> joint.child_point + C_masscenter + >>> joint.parent_axis + P_frame.x + >>> joint.child_axis + C_frame.x + >>> joint.coordinates + Matrix([[q_PC(t)]]) + >>> joint.speeds + Matrix([[u_PC(t)]]) + >>> child.frame.ang_vel_in(parent.frame) + u_PC(t)*P_frame.x + >>> child.frame.dcm(parent.frame) + Matrix([ + [1, 0, 0], + [0, cos(q_PC(t)), sin(q_PC(t))], + [0, -sin(q_PC(t)), cos(q_PC(t))]]) + >>> joint.child_point.pos_from(joint.parent_point) + 0 + + To further demonstrate the use of the pin joint, the kinematics of simple + double pendulum that rotates about the Z axis of each connected body can be + created as follows. + + >>> from sympy import symbols, trigsimp + >>> from sympy.physics.mechanics import RigidBody, PinJoint + >>> l1, l2 = symbols('l1 l2') + + First create bodies to represent the fixed ceiling and one to represent + each pendulum bob. + + >>> ceiling = RigidBody('C') + >>> upper_bob = RigidBody('U') + >>> lower_bob = RigidBody('L') + + The first joint will connect the upper bob to the ceiling by a distance of + ``l1`` and the joint axis will be about the Z axis for each body. + + >>> ceiling_joint = PinJoint('P1', ceiling, upper_bob, + ... child_point=-l1*upper_bob.frame.x, + ... joint_axis=ceiling.frame.z) + + The second joint will connect the lower bob to the upper bob by a distance + of ``l2`` and the joint axis will also be about the Z axis for each body. + + >>> pendulum_joint = PinJoint('P2', upper_bob, lower_bob, + ... child_point=-l2*lower_bob.frame.x, + ... joint_axis=upper_bob.frame.z) + + Once the joints are established the kinematics of the connected bodies can + be accessed. First the direction cosine matrices of pendulum link relative + to the ceiling are found: + + >>> upper_bob.frame.dcm(ceiling.frame) + Matrix([ + [ cos(q_P1(t)), sin(q_P1(t)), 0], + [-sin(q_P1(t)), cos(q_P1(t)), 0], + [ 0, 0, 1]]) + >>> trigsimp(lower_bob.frame.dcm(ceiling.frame)) + Matrix([ + [ cos(q_P1(t) + q_P2(t)), sin(q_P1(t) + q_P2(t)), 0], + [-sin(q_P1(t) + q_P2(t)), cos(q_P1(t) + q_P2(t)), 0], + [ 0, 0, 1]]) + + The position of the lower bob's masscenter is found with: + + >>> lower_bob.masscenter.pos_from(ceiling.masscenter) + l1*U_frame.x + l2*L_frame.x + + The angular velocities of the two pendulum links can be computed with + respect to the ceiling. + + >>> upper_bob.frame.ang_vel_in(ceiling.frame) + u_P1(t)*C_frame.z + >>> lower_bob.frame.ang_vel_in(ceiling.frame) + u_P1(t)*C_frame.z + u_P2(t)*U_frame.z + + And finally, the linear velocities of the two pendulum bobs can be computed + with respect to the ceiling. + + >>> upper_bob.masscenter.vel(ceiling.frame) + l1*u_P1(t)*U_frame.y + >>> lower_bob.masscenter.vel(ceiling.frame) + l1*u_P1(t)*U_frame.y + l2*(u_P1(t) + u_P2(t))*L_frame.y + + """ + + def __init__(self, name, parent, child, coordinates=None, speeds=None, + parent_point=None, child_point=None, parent_interframe=None, + child_interframe=None, parent_axis=None, child_axis=None, + joint_axis=None, parent_joint_pos=None, child_joint_pos=None): + + self._joint_axis = joint_axis + super().__init__(name, parent, child, coordinates, speeds, parent_point, + child_point, parent_interframe, child_interframe, + parent_axis, child_axis, parent_joint_pos, + child_joint_pos) + + def __str__(self): + return (f'PinJoint: {self.name} parent: {self.parent} ' + f'child: {self.child}') + + @property + def joint_axis(self): + """Axis about which the child rotates with respect to the parent.""" + return self._joint_axis + + def _generate_coordinates(self, coordinate): + return self._fill_coordinate_list(coordinate, 1, 'q') + + def _generate_speeds(self, speed): + return self._fill_coordinate_list(speed, 1, 'u') + + def _orient_frames(self): + self._joint_axis = self._axis(self.joint_axis, self.parent_interframe) + self.child_interframe.orient_axis( + self.parent_interframe, self.joint_axis, self.coordinates[0]) + + def _set_angular_velocity(self): + self.child_interframe.set_ang_vel(self.parent_interframe, self.speeds[ + 0] * self.joint_axis.normalize()) + + def _set_linear_velocity(self): + self.child_point.set_pos(self.parent_point, 0) + self.parent_point.set_vel(self._parent_frame, 0) + self.child_point.set_vel(self._child_frame, 0) + self.child.masscenter.v2pt_theory(self.parent_point, + self._parent_frame, self._child_frame) + + +class PrismaticJoint(Joint): + """Prismatic (Sliding) Joint. + + .. image:: PrismaticJoint.svg + + Explanation + =========== + + It is defined such that the child body translates with respect to the parent + body along the body-fixed joint axis. The location of the joint is defined + by two points, one in each body, which coincide when the generalized + coordinate is zero. The direction cosine matrix between the + parent_interframe and child_interframe is the identity matrix. Therefore, + the direction cosine matrix between the parent and child frames is fully + defined by the definition of the intermediate frames. The page on the joints + framework gives a more detailed explanation of the intermediate frames. + + Parameters + ========== + + name : string + A unique name for the joint. + parent : Particle or RigidBody + The parent body of joint. + child : Particle or RigidBody + The child body of joint. + coordinates : dynamicsymbol, optional + Generalized coordinates of the joint. The default value is + ``dynamicsymbols(f'q_{joint.name}')``. + speeds : dynamicsymbol, optional + Generalized speeds of joint. The default value is + ``dynamicsymbols(f'u_{joint.name}')``. + parent_point : Point or Vector, optional + Attachment point where the joint is fixed to the parent body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the parent's mass + center. + child_point : Point or Vector, optional + Attachment point where the joint is fixed to the child body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the child's mass + center. + parent_axis : Vector, optional + .. deprecated:: 1.12 + Axis fixed in the parent body which aligns with an axis fixed in the + child body. The default is the x axis of parent's reference frame. + For more information on this deprecation, see + :ref:`deprecated-mechanics-joint-axis`. + child_axis : Vector, optional + .. deprecated:: 1.12 + Axis fixed in the child body which aligns with an axis fixed in the + parent body. The default is the x axis of child's reference frame. + For more information on this deprecation, see + :ref:`deprecated-mechanics-joint-axis`. + parent_interframe : ReferenceFrame, optional + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the parent's own frame. + child_interframe : ReferenceFrame, optional + Intermediate frame of the child body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the child's own frame. + joint_axis : Vector + The axis along which the translation occurs. Note that the components + of this axis are the same in the parent_interframe and child_interframe. + parent_joint_pos : Point or Vector, optional + .. deprecated:: 1.12 + This argument is replaced by parent_point and will be removed in a + future version. + See :ref:`deprecated-mechanics-joint-pos` for more information. + child_joint_pos : Point or Vector, optional + .. deprecated:: 1.12 + This argument is replaced by child_point and will be removed in a + future version. + See :ref:`deprecated-mechanics-joint-pos` for more information. + + Attributes + ========== + + name : string + The joint's name. + parent : Particle or RigidBody + The joint's parent body. + child : Particle or RigidBody + The joint's child body. + coordinates : Matrix + Matrix of the joint's generalized coordinates. + speeds : Matrix + Matrix of the joint's generalized speeds. + parent_point : Point + Attachment point where the joint is fixed to the parent body. + child_point : Point + Attachment point where the joint is fixed to the child body. + parent_axis : Vector + The axis fixed in the parent frame that represents the joint. + child_axis : Vector + The axis fixed in the child frame that represents the joint. + parent_interframe : ReferenceFrame + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. + child_interframe : ReferenceFrame + Intermediate frame of the child body with respect to which the joint + transformation is formulated. + kdes : Matrix + Kinematical differential equations of the joint. + + Examples + ========= + + A single prismatic joint is created from two bodies and has the following + basic attributes: + + >>> from sympy.physics.mechanics import RigidBody, PrismaticJoint + >>> parent = RigidBody('P') + >>> parent + P + >>> child = RigidBody('C') + >>> child + C + >>> joint = PrismaticJoint('PC', parent, child) + >>> joint + PrismaticJoint: PC parent: P child: C + >>> joint.name + 'PC' + >>> joint.parent + P + >>> joint.child + C + >>> joint.parent_point + P_masscenter + >>> joint.child_point + C_masscenter + >>> joint.parent_axis + P_frame.x + >>> joint.child_axis + C_frame.x + >>> joint.coordinates + Matrix([[q_PC(t)]]) + >>> joint.speeds + Matrix([[u_PC(t)]]) + >>> child.frame.ang_vel_in(parent.frame) + 0 + >>> child.frame.dcm(parent.frame) + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + >>> joint.child_point.pos_from(joint.parent_point) + q_PC(t)*P_frame.x + + To further demonstrate the use of the prismatic joint, the kinematics of two + masses sliding, one moving relative to a fixed body and the other relative + to the moving body. about the X axis of each connected body can be created + as follows. + + >>> from sympy.physics.mechanics import PrismaticJoint, RigidBody + + First create bodies to represent the fixed ceiling and one to represent + a particle. + + >>> wall = RigidBody('W') + >>> Part1 = RigidBody('P1') + >>> Part2 = RigidBody('P2') + + The first joint will connect the particle to the ceiling and the + joint axis will be about the X axis for each body. + + >>> J1 = PrismaticJoint('J1', wall, Part1) + + The second joint will connect the second particle to the first particle + and the joint axis will also be about the X axis for each body. + + >>> J2 = PrismaticJoint('J2', Part1, Part2) + + Once the joint is established the kinematics of the connected bodies can + be accessed. First the direction cosine matrices of Part relative + to the ceiling are found: + + >>> Part1.frame.dcm(wall.frame) + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + + >>> Part2.frame.dcm(wall.frame) + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + + The position of the particles' masscenter is found with: + + >>> Part1.masscenter.pos_from(wall.masscenter) + q_J1(t)*W_frame.x + + >>> Part2.masscenter.pos_from(wall.masscenter) + q_J1(t)*W_frame.x + q_J2(t)*P1_frame.x + + The angular velocities of the two particle links can be computed with + respect to the ceiling. + + >>> Part1.frame.ang_vel_in(wall.frame) + 0 + + >>> Part2.frame.ang_vel_in(wall.frame) + 0 + + And finally, the linear velocities of the two particles can be computed + with respect to the ceiling. + + >>> Part1.masscenter.vel(wall.frame) + u_J1(t)*W_frame.x + + >>> Part2.masscenter.vel(wall.frame) + u_J1(t)*W_frame.x + Derivative(q_J2(t), t)*P1_frame.x + + """ + + def __init__(self, name, parent, child, coordinates=None, speeds=None, + parent_point=None, child_point=None, parent_interframe=None, + child_interframe=None, parent_axis=None, child_axis=None, + joint_axis=None, parent_joint_pos=None, child_joint_pos=None): + + self._joint_axis = joint_axis + super().__init__(name, parent, child, coordinates, speeds, parent_point, + child_point, parent_interframe, child_interframe, + parent_axis, child_axis, parent_joint_pos, + child_joint_pos) + + def __str__(self): + return (f'PrismaticJoint: {self.name} parent: {self.parent} ' + f'child: {self.child}') + + @property + def joint_axis(self): + """Axis along which the child translates with respect to the parent.""" + return self._joint_axis + + def _generate_coordinates(self, coordinate): + return self._fill_coordinate_list(coordinate, 1, 'q') + + def _generate_speeds(self, speed): + return self._fill_coordinate_list(speed, 1, 'u') + + def _orient_frames(self): + self._joint_axis = self._axis(self.joint_axis, self.parent_interframe) + self.child_interframe.orient_axis( + self.parent_interframe, self.joint_axis, 0) + + def _set_angular_velocity(self): + self.child_interframe.set_ang_vel(self.parent_interframe, 0) + + def _set_linear_velocity(self): + axis = self.joint_axis.normalize() + self.child_point.set_pos(self.parent_point, self.coordinates[0] * axis) + self.parent_point.set_vel(self._parent_frame, 0) + self.child_point.set_vel(self._child_frame, 0) + self.child_point.set_vel(self._parent_frame, self.speeds[0] * axis) + self.child.masscenter.set_vel(self._parent_frame, self.speeds[0] * axis) + + +class CylindricalJoint(Joint): + """Cylindrical Joint. + + .. image:: CylindricalJoint.svg + :align: center + :width: 600 + + Explanation + =========== + + A cylindrical joint is defined such that the child body both rotates about + and translates along the body-fixed joint axis with respect to the parent + body. The joint axis is both the rotation axis and translation axis. The + location of the joint is defined by two points, one in each body, which + coincide when the generalized coordinate corresponding to the translation is + zero. The direction cosine matrix between the child interframe and parent + interframe is formed using a simple rotation about the joint axis. The page + on the joints framework gives a more detailed explanation of the + intermediate frames. + + Parameters + ========== + + name : string + A unique name for the joint. + parent : Particle or RigidBody + The parent body of joint. + child : Particle or RigidBody + The child body of joint. + rotation_coordinate : dynamicsymbol, optional + Generalized coordinate corresponding to the rotation angle. The default + value is ``dynamicsymbols(f'q0_{joint.name}')``. + translation_coordinate : dynamicsymbol, optional + Generalized coordinate corresponding to the translation distance. The + default value is ``dynamicsymbols(f'q1_{joint.name}')``. + rotation_speed : dynamicsymbol, optional + Generalized speed corresponding to the angular velocity. The default + value is ``dynamicsymbols(f'u0_{joint.name}')``. + translation_speed : dynamicsymbol, optional + Generalized speed corresponding to the translation velocity. The default + value is ``dynamicsymbols(f'u1_{joint.name}')``. + parent_point : Point or Vector, optional + Attachment point where the joint is fixed to the parent body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the parent's mass + center. + child_point : Point or Vector, optional + Attachment point where the joint is fixed to the child body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the child's mass + center. + parent_interframe : ReferenceFrame, optional + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the parent's own frame. + child_interframe : ReferenceFrame, optional + Intermediate frame of the child body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the child's own frame. + joint_axis : Vector, optional + The rotation as well as translation axis. Note that the components of + this axis are the same in the parent_interframe and child_interframe. + + Attributes + ========== + + name : string + The joint's name. + parent : Particle or RigidBody + The joint's parent body. + child : Particle or RigidBody + The joint's child body. + rotation_coordinate : dynamicsymbol + Generalized coordinate corresponding to the rotation angle. + translation_coordinate : dynamicsymbol + Generalized coordinate corresponding to the translation distance. + rotation_speed : dynamicsymbol + Generalized speed corresponding to the angular velocity. + translation_speed : dynamicsymbol + Generalized speed corresponding to the translation velocity. + coordinates : Matrix + Matrix of the joint's generalized coordinates. + speeds : Matrix + Matrix of the joint's generalized speeds. + parent_point : Point + Attachment point where the joint is fixed to the parent body. + child_point : Point + Attachment point where the joint is fixed to the child body. + parent_interframe : ReferenceFrame + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. + child_interframe : ReferenceFrame + Intermediate frame of the child body with respect to which the joint + transformation is formulated. + kdes : Matrix + Kinematical differential equations of the joint. + joint_axis : Vector + The axis of rotation and translation. + + Examples + ========= + + A single cylindrical joint is created between two bodies and has the + following basic attributes: + + >>> from sympy.physics.mechanics import RigidBody, CylindricalJoint + >>> parent = RigidBody('P') + >>> parent + P + >>> child = RigidBody('C') + >>> child + C + >>> joint = CylindricalJoint('PC', parent, child) + >>> joint + CylindricalJoint: PC parent: P child: C + >>> joint.name + 'PC' + >>> joint.parent + P + >>> joint.child + C + >>> joint.parent_point + P_masscenter + >>> joint.child_point + C_masscenter + >>> joint.parent_axis + P_frame.x + >>> joint.child_axis + C_frame.x + >>> joint.coordinates + Matrix([ + [q0_PC(t)], + [q1_PC(t)]]) + >>> joint.speeds + Matrix([ + [u0_PC(t)], + [u1_PC(t)]]) + >>> child.frame.ang_vel_in(parent.frame) + u0_PC(t)*P_frame.x + >>> child.frame.dcm(parent.frame) + Matrix([ + [1, 0, 0], + [0, cos(q0_PC(t)), sin(q0_PC(t))], + [0, -sin(q0_PC(t)), cos(q0_PC(t))]]) + >>> joint.child_point.pos_from(joint.parent_point) + q1_PC(t)*P_frame.x + >>> child.masscenter.vel(parent.frame) + u1_PC(t)*P_frame.x + + To further demonstrate the use of the cylindrical joint, the kinematics of + two cylindrical joints perpendicular to each other can be created as follows. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import RigidBody, CylindricalJoint + >>> r, l, w = symbols('r l w') + + First create bodies to represent the fixed floor with a fixed pole on it. + The second body represents a freely moving tube around that pole. The third + body represents a solid flag freely translating along and rotating around + the Y axis of the tube. + + >>> floor = RigidBody('floor') + >>> tube = RigidBody('tube') + >>> flag = RigidBody('flag') + + The first joint will connect the first tube to the floor with it translating + along and rotating around the Z axis of both bodies. + + >>> floor_joint = CylindricalJoint('C1', floor, tube, joint_axis=floor.z) + + The second joint will connect the tube perpendicular to the flag along the Y + axis of both the tube and the flag, with the joint located at a distance + ``r`` from the tube's center of mass and a combination of the distances + ``l`` and ``w`` from the flag's center of mass. + + >>> flag_joint = CylindricalJoint('C2', tube, flag, + ... parent_point=r * tube.y, + ... child_point=-w * flag.y + l * flag.z, + ... joint_axis=tube.y) + + Once the joints are established the kinematics of the connected bodies can + be accessed. First the direction cosine matrices of both the body and the + flag relative to the floor are found: + + >>> tube.frame.dcm(floor.frame) + Matrix([ + [ cos(q0_C1(t)), sin(q0_C1(t)), 0], + [-sin(q0_C1(t)), cos(q0_C1(t)), 0], + [ 0, 0, 1]]) + >>> flag.frame.dcm(floor.frame) + Matrix([ + [cos(q0_C1(t))*cos(q0_C2(t)), sin(q0_C1(t))*cos(q0_C2(t)), -sin(q0_C2(t))], + [ -sin(q0_C1(t)), cos(q0_C1(t)), 0], + [sin(q0_C2(t))*cos(q0_C1(t)), sin(q0_C1(t))*sin(q0_C2(t)), cos(q0_C2(t))]]) + + The position of the flag's center of mass is found with: + + >>> flag.masscenter.pos_from(floor.masscenter) + q1_C1(t)*floor_frame.z + (r + q1_C2(t))*tube_frame.y + w*flag_frame.y - l*flag_frame.z + + The angular velocities of the two tubes can be computed with respect to the + floor. + + >>> tube.frame.ang_vel_in(floor.frame) + u0_C1(t)*floor_frame.z + >>> flag.frame.ang_vel_in(floor.frame) + u0_C1(t)*floor_frame.z + u0_C2(t)*tube_frame.y + + Finally, the linear velocities of the two tube centers of mass can be + computed with respect to the floor, while expressed in the tube's frame. + + >>> tube.masscenter.vel(floor.frame).to_matrix(tube.frame) + Matrix([ + [ 0], + [ 0], + [u1_C1(t)]]) + >>> flag.masscenter.vel(floor.frame).to_matrix(tube.frame).simplify() + Matrix([ + [-l*u0_C2(t)*cos(q0_C2(t)) - r*u0_C1(t) - w*u0_C1(t) - q1_C2(t)*u0_C1(t)], + [ -l*u0_C1(t)*sin(q0_C2(t)) + Derivative(q1_C2(t), t)], + [ l*u0_C2(t)*sin(q0_C2(t)) + u1_C1(t)]]) + + """ + + def __init__(self, name, parent, child, rotation_coordinate=None, + translation_coordinate=None, rotation_speed=None, + translation_speed=None, parent_point=None, child_point=None, + parent_interframe=None, child_interframe=None, + joint_axis=None): + self._joint_axis = joint_axis + coordinates = (rotation_coordinate, translation_coordinate) + speeds = (rotation_speed, translation_speed) + super().__init__(name, parent, child, coordinates, speeds, + parent_point, child_point, + parent_interframe=parent_interframe, + child_interframe=child_interframe) + + def __str__(self): + return (f'CylindricalJoint: {self.name} parent: {self.parent} ' + f'child: {self.child}') + + @property + def joint_axis(self): + """Axis about and along which the rotation and translation occurs.""" + return self._joint_axis + + @property + def rotation_coordinate(self): + """Generalized coordinate corresponding to the rotation angle.""" + return self.coordinates[0] + + @property + def translation_coordinate(self): + """Generalized coordinate corresponding to the translation distance.""" + return self.coordinates[1] + + @property + def rotation_speed(self): + """Generalized speed corresponding to the angular velocity.""" + return self.speeds[0] + + @property + def translation_speed(self): + """Generalized speed corresponding to the translation velocity.""" + return self.speeds[1] + + def _generate_coordinates(self, coordinates): + return self._fill_coordinate_list(coordinates, 2, 'q') + + def _generate_speeds(self, speeds): + return self._fill_coordinate_list(speeds, 2, 'u') + + def _orient_frames(self): + self._joint_axis = self._axis(self.joint_axis, self.parent_interframe) + self.child_interframe.orient_axis( + self.parent_interframe, self.joint_axis, self.rotation_coordinate) + + def _set_angular_velocity(self): + self.child_interframe.set_ang_vel( + self.parent_interframe, + self.rotation_speed * self.joint_axis.normalize()) + + def _set_linear_velocity(self): + self.child_point.set_pos( + self.parent_point, + self.translation_coordinate * self.joint_axis.normalize()) + self.parent_point.set_vel(self._parent_frame, 0) + self.child_point.set_vel(self._child_frame, 0) + self.child_point.set_vel( + self._parent_frame, + self.translation_speed * self.joint_axis.normalize()) + self.child.masscenter.v2pt_theory(self.child_point, self._parent_frame, + self.child_interframe) + + +class PlanarJoint(Joint): + """Planar Joint. + + .. raw:: html + :file: ../../../doc/src/modules/physics/mechanics/api/PlanarJoint.svg + + Explanation + =========== + + A planar joint is defined such that the child body translates over a fixed + plane of the parent body as well as rotate about the rotation axis, which + is perpendicular to that plane. The origin of this plane is the + ``parent_point`` and the plane is spanned by two nonparallel planar vectors. + The location of the ``child_point`` is based on the planar vectors + ($\\vec{v}_1$, $\\vec{v}_2$) and generalized coordinates ($q_1$, $q_2$), + i.e. $\\vec{r} = q_1 \\hat{v}_1 + q_2 \\hat{v}_2$. The direction cosine + matrix between the ``child_interframe`` and ``parent_interframe`` is formed + using a simple rotation ($q_0$) about the rotation axis. + + In order to simplify the definition of the ``PlanarJoint``, the + ``rotation_axis`` and ``planar_vectors`` are set to be the unit vectors of + the ``parent_interframe`` according to the table below. This ensures that + you can only define these vectors by creating a separate frame and supplying + that as the interframe. If you however would only like to supply the normals + of the plane with respect to the parent and child bodies, then you can also + supply those to the ``parent_interframe`` and ``child_interframe`` + arguments. An example of both of these cases is in the examples section + below and the page on the joints framework provides a more detailed + explanation of the intermediate frames. + + .. list-table:: + + * - ``rotation_axis`` + - ``parent_interframe.x`` + * - ``planar_vectors[0]`` + - ``parent_interframe.y`` + * - ``planar_vectors[1]`` + - ``parent_interframe.z`` + + Parameters + ========== + + name : string + A unique name for the joint. + parent : Particle or RigidBody + The parent body of joint. + child : Particle or RigidBody + The child body of joint. + rotation_coordinate : dynamicsymbol, optional + Generalized coordinate corresponding to the rotation angle. The default + value is ``dynamicsymbols(f'q0_{joint.name}')``. + planar_coordinates : iterable of dynamicsymbols, optional + Two generalized coordinates used for the planar translation. The default + value is ``dynamicsymbols(f'q1_{joint.name} q2_{joint.name}')``. + rotation_speed : dynamicsymbol, optional + Generalized speed corresponding to the angular velocity. The default + value is ``dynamicsymbols(f'u0_{joint.name}')``. + planar_speeds : dynamicsymbols, optional + Two generalized speeds used for the planar translation velocity. The + default value is ``dynamicsymbols(f'u1_{joint.name} u2_{joint.name}')``. + parent_point : Point or Vector, optional + Attachment point where the joint is fixed to the parent body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the parent's mass + center. + child_point : Point or Vector, optional + Attachment point where the joint is fixed to the child body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the child's mass + center. + parent_interframe : ReferenceFrame, optional + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the parent's own frame. + child_interframe : ReferenceFrame, optional + Intermediate frame of the child body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the child's own frame. + + Attributes + ========== + + name : string + The joint's name. + parent : Particle or RigidBody + The joint's parent body. + child : Particle or RigidBody + The joint's child body. + rotation_coordinate : dynamicsymbol + Generalized coordinate corresponding to the rotation angle. + planar_coordinates : Matrix + Two generalized coordinates used for the planar translation. + rotation_speed : dynamicsymbol + Generalized speed corresponding to the angular velocity. + planar_speeds : Matrix + Two generalized speeds used for the planar translation velocity. + coordinates : Matrix + Matrix of the joint's generalized coordinates. + speeds : Matrix + Matrix of the joint's generalized speeds. + parent_point : Point + Attachment point where the joint is fixed to the parent body. + child_point : Point + Attachment point where the joint is fixed to the child body. + parent_interframe : ReferenceFrame + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. + child_interframe : ReferenceFrame + Intermediate frame of the child body with respect to which the joint + transformation is formulated. + kdes : Matrix + Kinematical differential equations of the joint. + rotation_axis : Vector + The axis about which the rotation occurs. + planar_vectors : list + The vectors that describe the planar translation directions. + + Examples + ========= + + A single planar joint is created between two bodies and has the following + basic attributes: + + >>> from sympy.physics.mechanics import RigidBody, PlanarJoint + >>> parent = RigidBody('P') + >>> parent + P + >>> child = RigidBody('C') + >>> child + C + >>> joint = PlanarJoint('PC', parent, child) + >>> joint + PlanarJoint: PC parent: P child: C + >>> joint.name + 'PC' + >>> joint.parent + P + >>> joint.child + C + >>> joint.parent_point + P_masscenter + >>> joint.child_point + C_masscenter + >>> joint.rotation_axis + P_frame.x + >>> joint.planar_vectors + [P_frame.y, P_frame.z] + >>> joint.rotation_coordinate + q0_PC(t) + >>> joint.planar_coordinates + Matrix([ + [q1_PC(t)], + [q2_PC(t)]]) + >>> joint.coordinates + Matrix([ + [q0_PC(t)], + [q1_PC(t)], + [q2_PC(t)]]) + >>> joint.rotation_speed + u0_PC(t) + >>> joint.planar_speeds + Matrix([ + [u1_PC(t)], + [u2_PC(t)]]) + >>> joint.speeds + Matrix([ + [u0_PC(t)], + [u1_PC(t)], + [u2_PC(t)]]) + >>> child.frame.ang_vel_in(parent.frame) + u0_PC(t)*P_frame.x + >>> child.frame.dcm(parent.frame) + Matrix([ + [1, 0, 0], + [0, cos(q0_PC(t)), sin(q0_PC(t))], + [0, -sin(q0_PC(t)), cos(q0_PC(t))]]) + >>> joint.child_point.pos_from(joint.parent_point) + q1_PC(t)*P_frame.y + q2_PC(t)*P_frame.z + >>> child.masscenter.vel(parent.frame) + u1_PC(t)*P_frame.y + u2_PC(t)*P_frame.z + + To further demonstrate the use of the planar joint, the kinematics of a + block sliding on a slope, can be created as follows. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import PlanarJoint, RigidBody, ReferenceFrame + >>> a, d, h = symbols('a d h') + + First create bodies to represent the slope and the block. + + >>> ground = RigidBody('G') + >>> block = RigidBody('B') + + To define the slope you can either define the plane by specifying the + ``planar_vectors`` or/and the ``rotation_axis``. However it is advisable to + create a rotated intermediate frame, so that the ``parent_vectors`` and + ``rotation_axis`` will be the unit vectors of this intermediate frame. + + >>> slope = ReferenceFrame('A') + >>> slope.orient_axis(ground.frame, ground.y, a) + + The planar joint can be created using these bodies and intermediate frame. + We can specify the origin of the slope to be ``d`` above the slope's center + of mass and the block's center of mass to be a distance ``h`` above the + slope's surface. Note that we can specify the normal of the plane using the + rotation axis argument. + + >>> joint = PlanarJoint('PC', ground, block, parent_point=d * ground.x, + ... child_point=-h * block.x, parent_interframe=slope) + + Once the joint is established the kinematics of the bodies can be accessed. + First the ``rotation_axis``, which is normal to the plane and the + ``plane_vectors``, can be found. + + >>> joint.rotation_axis + A.x + >>> joint.planar_vectors + [A.y, A.z] + + The direction cosine matrix of the block with respect to the ground can be + found with: + + >>> block.frame.dcm(ground.frame) + Matrix([ + [ cos(a), 0, -sin(a)], + [sin(a)*sin(q0_PC(t)), cos(q0_PC(t)), sin(q0_PC(t))*cos(a)], + [sin(a)*cos(q0_PC(t)), -sin(q0_PC(t)), cos(a)*cos(q0_PC(t))]]) + + The angular velocity of the block can be computed with respect to the + ground. + + >>> block.frame.ang_vel_in(ground.frame) + u0_PC(t)*A.x + + The position of the block's center of mass can be found with: + + >>> block.masscenter.pos_from(ground.masscenter) + d*G_frame.x + h*B_frame.x + q1_PC(t)*A.y + q2_PC(t)*A.z + + Finally, the linear velocity of the block's center of mass can be + computed with respect to the ground. + + >>> block.masscenter.vel(ground.frame) + u1_PC(t)*A.y + u2_PC(t)*A.z + + In some cases it could be your preference to only define the normals of the + plane with respect to both bodies. This can most easily be done by supplying + vectors to the ``interframe`` arguments. What will happen in this case is + that an interframe will be created with its ``x`` axis aligned with the + provided vector. For a further explanation of how this is done see the notes + of the ``Joint`` class. In the code below, the above example (with the block + on the slope) is recreated by supplying vectors to the interframe arguments. + Note that the previously described option is however more computationally + efficient, because the algorithm now has to compute the rotation angle + between the provided vector and the 'x' axis. + + >>> from sympy import symbols, cos, sin + >>> from sympy.physics.mechanics import PlanarJoint, RigidBody + >>> a, d, h = symbols('a d h') + >>> ground = RigidBody('G') + >>> block = RigidBody('B') + >>> joint = PlanarJoint( + ... 'PC', ground, block, parent_point=d * ground.x, + ... child_point=-h * block.x, child_interframe=block.x, + ... parent_interframe=cos(a) * ground.x + sin(a) * ground.z) + >>> block.frame.dcm(ground.frame).simplify() + Matrix([ + [ cos(a), 0, sin(a)], + [-sin(a)*sin(q0_PC(t)), cos(q0_PC(t)), sin(q0_PC(t))*cos(a)], + [-sin(a)*cos(q0_PC(t)), -sin(q0_PC(t)), cos(a)*cos(q0_PC(t))]]) + + """ + + def __init__(self, name, parent, child, rotation_coordinate=None, + planar_coordinates=None, rotation_speed=None, + planar_speeds=None, parent_point=None, child_point=None, + parent_interframe=None, child_interframe=None): + # A ready to merge implementation of setting the planar_vectors and + # rotation_axis was added and removed in PR #24046 + coordinates = (rotation_coordinate, planar_coordinates) + speeds = (rotation_speed, planar_speeds) + super().__init__(name, parent, child, coordinates, speeds, + parent_point, child_point, + parent_interframe=parent_interframe, + child_interframe=child_interframe) + + def __str__(self): + return (f'PlanarJoint: {self.name} parent: {self.parent} ' + f'child: {self.child}') + + @property + def rotation_coordinate(self): + """Generalized coordinate corresponding to the rotation angle.""" + return self.coordinates[0] + + @property + def planar_coordinates(self): + """Two generalized coordinates used for the planar translation.""" + return self.coordinates[1:, 0] + + @property + def rotation_speed(self): + """Generalized speed corresponding to the angular velocity.""" + return self.speeds[0] + + @property + def planar_speeds(self): + """Two generalized speeds used for the planar translation velocity.""" + return self.speeds[1:, 0] + + @property + def rotation_axis(self): + """The axis about which the rotation occurs.""" + return self.parent_interframe.x + + @property + def planar_vectors(self): + """The vectors that describe the planar translation directions.""" + return [self.parent_interframe.y, self.parent_interframe.z] + + def _generate_coordinates(self, coordinates): + rotation_speed = self._fill_coordinate_list(coordinates[0], 1, 'q', + number_single=True) + planar_speeds = self._fill_coordinate_list(coordinates[1], 2, 'q', 1) + return rotation_speed.col_join(planar_speeds) + + def _generate_speeds(self, speeds): + rotation_speed = self._fill_coordinate_list(speeds[0], 1, 'u', + number_single=True) + planar_speeds = self._fill_coordinate_list(speeds[1], 2, 'u', 1) + return rotation_speed.col_join(planar_speeds) + + def _orient_frames(self): + self.child_interframe.orient_axis( + self.parent_interframe, self.rotation_axis, + self.rotation_coordinate) + + def _set_angular_velocity(self): + self.child_interframe.set_ang_vel( + self.parent_interframe, + self.rotation_speed * self.rotation_axis) + + def _set_linear_velocity(self): + self.child_point.set_pos( + self.parent_point, + self.planar_coordinates[0] * self.planar_vectors[0] + + self.planar_coordinates[1] * self.planar_vectors[1]) + self.parent_point.set_vel(self.parent_interframe, 0) + self.child_point.set_vel(self.child_interframe, 0) + self.child_point.set_vel( + self._parent_frame, self.planar_speeds[0] * self.planar_vectors[0] + + self.planar_speeds[1] * self.planar_vectors[1]) + self.child.masscenter.v2pt_theory(self.child_point, self._parent_frame, + self._child_frame) + + +class SphericalJoint(Joint): + """Spherical (Ball-and-Socket) Joint. + + .. image:: SphericalJoint.svg + :align: center + :width: 600 + + Explanation + =========== + + A spherical joint is defined such that the child body is free to rotate in + any direction, without allowing a translation of the ``child_point``. As can + also be seen in the image, the ``parent_point`` and ``child_point`` are + fixed on top of each other, i.e. the ``joint_point``. This rotation is + defined using the :func:`parent_interframe.orient(child_interframe, + rot_type, amounts, rot_order) + ` method. The default + rotation consists of three relative rotations, i.e. body-fixed rotations. + Based on the direction cosine matrix following from these rotations, the + angular velocity is computed based on the generalized coordinates and + generalized speeds. + + Parameters + ========== + + name : string + A unique name for the joint. + parent : Particle or RigidBody + The parent body of joint. + child : Particle or RigidBody + The child body of joint. + coordinates: iterable of dynamicsymbols, optional + Generalized coordinates of the joint. + speeds : iterable of dynamicsymbols, optional + Generalized speeds of joint. + parent_point : Point or Vector, optional + Attachment point where the joint is fixed to the parent body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the parent's mass + center. + child_point : Point or Vector, optional + Attachment point where the joint is fixed to the child body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the child's mass + center. + parent_interframe : ReferenceFrame, optional + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the parent's own frame. + child_interframe : ReferenceFrame, optional + Intermediate frame of the child body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the child's own frame. + rot_type : str, optional + The method used to generate the direction cosine matrix. Supported + methods are: + + - ``'Body'``: three successive rotations about new intermediate axes, + also called "Euler and Tait-Bryan angles" + - ``'Space'``: three successive rotations about the parent frames' unit + vectors + + The default method is ``'Body'``. + amounts : + Expressions defining the rotation angles or direction cosine matrix. + These must match the ``rot_type``. See examples below for details. The + input types are: + + - ``'Body'``: 3-tuple of expressions, symbols, or functions + - ``'Space'``: 3-tuple of expressions, symbols, or functions + + The default amounts are the given ``coordinates``. + rot_order : str or int, optional + If applicable, the order of the successive of rotations. The string + ``'123'`` and integer ``123`` are equivalent, for example. Required for + ``'Body'`` and ``'Space'``. The default value is ``123``. + + Attributes + ========== + + name : string + The joint's name. + parent : Particle or RigidBody + The joint's parent body. + child : Particle or RigidBody + The joint's child body. + coordinates : Matrix + Matrix of the joint's generalized coordinates. + speeds : Matrix + Matrix of the joint's generalized speeds. + parent_point : Point + Attachment point where the joint is fixed to the parent body. + child_point : Point + Attachment point where the joint is fixed to the child body. + parent_interframe : ReferenceFrame + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. + child_interframe : ReferenceFrame + Intermediate frame of the child body with respect to which the joint + transformation is formulated. + kdes : Matrix + Kinematical differential equations of the joint. + + Examples + ========= + + A single spherical joint is created from two bodies and has the following + basic attributes: + + >>> from sympy.physics.mechanics import RigidBody, SphericalJoint + >>> parent = RigidBody('P') + >>> parent + P + >>> child = RigidBody('C') + >>> child + C + >>> joint = SphericalJoint('PC', parent, child) + >>> joint + SphericalJoint: PC parent: P child: C + >>> joint.name + 'PC' + >>> joint.parent + P + >>> joint.child + C + >>> joint.parent_point + P_masscenter + >>> joint.child_point + C_masscenter + >>> joint.parent_interframe + P_frame + >>> joint.child_interframe + C_frame + >>> joint.coordinates + Matrix([ + [q0_PC(t)], + [q1_PC(t)], + [q2_PC(t)]]) + >>> joint.speeds + Matrix([ + [u0_PC(t)], + [u1_PC(t)], + [u2_PC(t)]]) + >>> child.frame.ang_vel_in(parent.frame).to_matrix(child.frame) + Matrix([ + [ u0_PC(t)*cos(q1_PC(t))*cos(q2_PC(t)) + u1_PC(t)*sin(q2_PC(t))], + [-u0_PC(t)*sin(q2_PC(t))*cos(q1_PC(t)) + u1_PC(t)*cos(q2_PC(t))], + [ u0_PC(t)*sin(q1_PC(t)) + u2_PC(t)]]) + >>> child.frame.x.to_matrix(parent.frame) + Matrix([ + [ cos(q1_PC(t))*cos(q2_PC(t))], + [sin(q0_PC(t))*sin(q1_PC(t))*cos(q2_PC(t)) + sin(q2_PC(t))*cos(q0_PC(t))], + [sin(q0_PC(t))*sin(q2_PC(t)) - sin(q1_PC(t))*cos(q0_PC(t))*cos(q2_PC(t))]]) + >>> joint.child_point.pos_from(joint.parent_point) + 0 + + To further demonstrate the use of the spherical joint, the kinematics of a + spherical joint with a ZXZ rotation can be created as follows. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import RigidBody, SphericalJoint + >>> l1 = symbols('l1') + + First create bodies to represent the fixed floor and a pendulum bob. + + >>> floor = RigidBody('F') + >>> bob = RigidBody('B') + + The joint will connect the bob to the floor, with the joint located at a + distance of ``l1`` from the child's center of mass and the rotation set to a + body-fixed ZXZ rotation. + + >>> joint = SphericalJoint('S', floor, bob, child_point=l1 * bob.y, + ... rot_type='body', rot_order='ZXZ') + + Now that the joint is established, the kinematics of the connected body can + be accessed. + + The position of the bob's masscenter is found with: + + >>> bob.masscenter.pos_from(floor.masscenter) + - l1*B_frame.y + + The angular velocities of the pendulum link can be computed with respect to + the floor. + + >>> bob.frame.ang_vel_in(floor.frame).to_matrix( + ... floor.frame).simplify() + Matrix([ + [u1_S(t)*cos(q0_S(t)) + u2_S(t)*sin(q0_S(t))*sin(q1_S(t))], + [u1_S(t)*sin(q0_S(t)) - u2_S(t)*sin(q1_S(t))*cos(q0_S(t))], + [ u0_S(t) + u2_S(t)*cos(q1_S(t))]]) + + Finally, the linear velocity of the bob's center of mass can be computed. + + >>> bob.masscenter.vel(floor.frame).to_matrix(bob.frame) + Matrix([ + [ l1*(u0_S(t)*cos(q1_S(t)) + u2_S(t))], + [ 0], + [-l1*(u0_S(t)*sin(q1_S(t))*sin(q2_S(t)) + u1_S(t)*cos(q2_S(t)))]]) + + """ + def __init__(self, name, parent, child, coordinates=None, speeds=None, + parent_point=None, child_point=None, parent_interframe=None, + child_interframe=None, rot_type='BODY', amounts=None, + rot_order=123): + self._rot_type = rot_type + self._amounts = amounts + self._rot_order = rot_order + super().__init__(name, parent, child, coordinates, speeds, + parent_point, child_point, + parent_interframe=parent_interframe, + child_interframe=child_interframe) + + def __str__(self): + return (f'SphericalJoint: {self.name} parent: {self.parent} ' + f'child: {self.child}') + + def _generate_coordinates(self, coordinates): + return self._fill_coordinate_list(coordinates, 3, 'q') + + def _generate_speeds(self, speeds): + return self._fill_coordinate_list(speeds, len(self.coordinates), 'u') + + def _orient_frames(self): + supported_rot_types = ('BODY', 'SPACE') + if self._rot_type.upper() not in supported_rot_types: + raise NotImplementedError( + f'Rotation type "{self._rot_type}" is not implemented. ' + f'Implemented rotation types are: {supported_rot_types}') + amounts = self.coordinates if self._amounts is None else self._amounts + self.child_interframe.orient(self.parent_interframe, self._rot_type, + amounts, self._rot_order) + + def _set_angular_velocity(self): + t = dynamicsymbols._t + vel = self.child_interframe.ang_vel_in(self.parent_interframe).xreplace( + {q.diff(t): u for q, u in zip(self.coordinates, self.speeds)} + ) + self.child_interframe.set_ang_vel(self.parent_interframe, vel) + + def _set_linear_velocity(self): + self.child_point.set_pos(self.parent_point, 0) + self.parent_point.set_vel(self._parent_frame, 0) + self.child_point.set_vel(self._child_frame, 0) + self.child.masscenter.v2pt_theory(self.parent_point, self._parent_frame, + self._child_frame) + + +class WeldJoint(Joint): + """Weld Joint. + + .. raw:: html + :file: ../../../doc/src/modules/physics/mechanics/api/WeldJoint.svg + + Explanation + =========== + + A weld joint is defined such that there is no relative motion between the + child and parent bodies. The direction cosine matrix between the attachment + frame (``parent_interframe`` and ``child_interframe``) is the identity + matrix and the attachment points (``parent_point`` and ``child_point``) are + coincident. The page on the joints framework gives a more detailed + explanation of the intermediate frames. + + Parameters + ========== + + name : string + A unique name for the joint. + parent : Particle or RigidBody + The parent body of joint. + child : Particle or RigidBody + The child body of joint. + parent_point : Point or Vector, optional + Attachment point where the joint is fixed to the parent body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the parent's mass + center. + child_point : Point or Vector, optional + Attachment point where the joint is fixed to the child body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the child's mass + center. + parent_interframe : ReferenceFrame, optional + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the parent's own frame. + child_interframe : ReferenceFrame, optional + Intermediate frame of the child body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the child's own frame. + + Attributes + ========== + + name : string + The joint's name. + parent : Particle or RigidBody + The joint's parent body. + child : Particle or RigidBody + The joint's child body. + coordinates : Matrix + Matrix of the joint's generalized coordinates. The default value is + ``dynamicsymbols(f'q_{joint.name}')``. + speeds : Matrix + Matrix of the joint's generalized speeds. The default value is + ``dynamicsymbols(f'u_{joint.name}')``. + parent_point : Point + Attachment point where the joint is fixed to the parent body. + child_point : Point + Attachment point where the joint is fixed to the child body. + parent_interframe : ReferenceFrame + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. + child_interframe : ReferenceFrame + Intermediate frame of the child body with respect to which the joint + transformation is formulated. + kdes : Matrix + Kinematical differential equations of the joint. + + Examples + ========= + + A single weld joint is created from two bodies and has the following basic + attributes: + + >>> from sympy.physics.mechanics import RigidBody, WeldJoint + >>> parent = RigidBody('P') + >>> parent + P + >>> child = RigidBody('C') + >>> child + C + >>> joint = WeldJoint('PC', parent, child) + >>> joint + WeldJoint: PC parent: P child: C + >>> joint.name + 'PC' + >>> joint.parent + P + >>> joint.child + C + >>> joint.parent_point + P_masscenter + >>> joint.child_point + C_masscenter + >>> joint.coordinates + Matrix(0, 0, []) + >>> joint.speeds + Matrix(0, 0, []) + >>> child.frame.ang_vel_in(parent.frame) + 0 + >>> child.frame.dcm(parent.frame) + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + >>> joint.child_point.pos_from(joint.parent_point) + 0 + + To further demonstrate the use of the weld joint, two relatively-fixed + bodies rotated by a quarter turn about the Y axis can be created as follows: + + >>> from sympy import symbols, pi + >>> from sympy.physics.mechanics import ReferenceFrame, RigidBody, WeldJoint + >>> l1, l2 = symbols('l1 l2') + + First create the bodies to represent the parent and rotated child body. + + >>> parent = RigidBody('P') + >>> child = RigidBody('C') + + Next the intermediate frame specifying the fixed rotation with respect to + the parent can be created. + + >>> rotated_frame = ReferenceFrame('Pr') + >>> rotated_frame.orient_axis(parent.frame, parent.y, pi / 2) + + The weld between the parent body and child body is located at a distance + ``l1`` from the parent's center of mass in the X direction and ``l2`` from + the child's center of mass in the child's negative X direction. + + >>> weld = WeldJoint('weld', parent, child, parent_point=l1 * parent.x, + ... child_point=-l2 * child.x, + ... parent_interframe=rotated_frame) + + Now that the joint has been established, the kinematics of the bodies can be + accessed. The direction cosine matrix of the child body with respect to the + parent can be found: + + >>> child.frame.dcm(parent.frame) + Matrix([ + [0, 0, -1], + [0, 1, 0], + [1, 0, 0]]) + + As can also been seen from the direction cosine matrix, the parent X axis is + aligned with the child's Z axis: + >>> parent.x == child.z + True + + The position of the child's center of mass with respect to the parent's + center of mass can be found with: + + >>> child.masscenter.pos_from(parent.masscenter) + l1*P_frame.x + l2*C_frame.x + + The angular velocity of the child with respect to the parent is 0 as one + would expect. + + >>> child.frame.ang_vel_in(parent.frame) + 0 + + """ + + def __init__(self, name, parent, child, parent_point=None, child_point=None, + parent_interframe=None, child_interframe=None): + super().__init__(name, parent, child, [], [], parent_point, + child_point, parent_interframe=parent_interframe, + child_interframe=child_interframe) + self._kdes = Matrix(1, 0, []).T # Removes stackability problems #10770 + + def __str__(self): + return (f'WeldJoint: {self.name} parent: {self.parent} ' + f'child: {self.child}') + + def _generate_coordinates(self, coordinate): + return Matrix() + + def _generate_speeds(self, speed): + return Matrix() + + def _orient_frames(self): + self.child_interframe.orient_axis(self.parent_interframe, + self.parent_interframe.x, 0) + + def _set_angular_velocity(self): + self.child_interframe.set_ang_vel(self.parent_interframe, 0) + + def _set_linear_velocity(self): + self.child_point.set_pos(self.parent_point, 0) + self.parent_point.set_vel(self._parent_frame, 0) + self.child_point.set_vel(self._child_frame, 0) + self.child.masscenter.set_vel(self._parent_frame, 0) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/jointsmethod.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/jointsmethod.py new file mode 100644 index 0000000000000000000000000000000000000000..df7bd56360072feb57a65e5f78c2d116f0d4842d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/jointsmethod.py @@ -0,0 +1,318 @@ +from sympy.physics.mechanics import (Body, Lagrangian, KanesMethod, LagrangesMethod, + RigidBody, Particle) +from sympy.physics.mechanics.body_base import BodyBase +from sympy.physics.mechanics.method import _Methods +from sympy import Matrix +from sympy.utilities.exceptions import sympy_deprecation_warning + +__all__ = ['JointsMethod'] + + +class JointsMethod(_Methods): + """Method for formulating the equations of motion using a set of interconnected bodies with joints. + + .. deprecated:: 1.13 + The JointsMethod class is deprecated. Its functionality has been + replaced by the new :class:`~.System` class. + + Parameters + ========== + + newtonion : Body or ReferenceFrame + The newtonion(inertial) frame. + *joints : Joint + The joints in the system + + Attributes + ========== + + q, u : iterable + Iterable of the generalized coordinates and speeds + bodies : iterable + Iterable of Body objects in the system. + loads : iterable + Iterable of (Point, vector) or (ReferenceFrame, vector) tuples + describing the forces on the system. + mass_matrix : Matrix, shape(n, n) + The system's mass matrix + forcing : Matrix, shape(n, 1) + The system's forcing vector + mass_matrix_full : Matrix, shape(2*n, 2*n) + The "mass matrix" for the u's and q's + forcing_full : Matrix, shape(2*n, 1) + The "forcing vector" for the u's and q's + method : KanesMethod or Lagrange's method + Method's object. + kdes : iterable + Iterable of kde in they system. + + Examples + ======== + + As Body and JointsMethod have been deprecated, the following examples are + for illustrative purposes only. The functionality of Body is fully captured + by :class:`~.RigidBody` and :class:`~.Particle` and the functionality of + JointsMethod is fully captured by :class:`~.System`. To ignore the + deprecation warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + + This is a simple example for a one degree of freedom translational + spring-mass-damper. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import Body, JointsMethod, PrismaticJoint + >>> from sympy.physics.vector import dynamicsymbols + >>> c, k = symbols('c k') + >>> x, v = dynamicsymbols('x v') + >>> with ignore_warnings(DeprecationWarning): + ... wall = Body('W') + ... body = Body('B') + >>> J = PrismaticJoint('J', wall, body, coordinates=x, speeds=v) + >>> wall.apply_force(c*v*wall.x, reaction_body=body) + >>> wall.apply_force(k*x*wall.x, reaction_body=body) + >>> with ignore_warnings(DeprecationWarning): + ... method = JointsMethod(wall, J) + >>> method.form_eoms() + Matrix([[-B_mass*Derivative(v(t), t) - c*v(t) - k*x(t)]]) + >>> M = method.mass_matrix_full + >>> F = method.forcing_full + >>> rhs = M.LUsolve(F) + >>> rhs + Matrix([ + [ v(t)], + [(-c*v(t) - k*x(t))/B_mass]]) + + Notes + ===== + + ``JointsMethod`` currently only works with systems that do not have any + configuration or motion constraints. + + """ + + def __init__(self, newtonion, *joints): + sympy_deprecation_warning( + """ + The JointsMethod class is deprecated. + Its functionality has been replaced by the new System class. + """, + deprecated_since_version="1.13", + active_deprecations_target="deprecated-mechanics-jointsmethod" + ) + if isinstance(newtonion, BodyBase): + self.frame = newtonion.frame + else: + self.frame = newtonion + + self._joints = joints + self._bodies = self._generate_bodylist() + self._loads = self._generate_loadlist() + self._q = self._generate_q() + self._u = self._generate_u() + self._kdes = self._generate_kdes() + + self._method = None + + @property + def bodies(self): + """List of bodies in they system.""" + return self._bodies + + @property + def loads(self): + """List of loads on the system.""" + return self._loads + + @property + def q(self): + """List of the generalized coordinates.""" + return self._q + + @property + def u(self): + """List of the generalized speeds.""" + return self._u + + @property + def kdes(self): + """List of the generalized coordinates.""" + return self._kdes + + @property + def forcing_full(self): + """The "forcing vector" for the u's and q's.""" + return self.method.forcing_full + + @property + def mass_matrix_full(self): + """The "mass matrix" for the u's and q's.""" + return self.method.mass_matrix_full + + @property + def mass_matrix(self): + """The system's mass matrix.""" + return self.method.mass_matrix + + @property + def forcing(self): + """The system's forcing vector.""" + return self.method.forcing + + @property + def method(self): + """Object of method used to form equations of systems.""" + return self._method + + def _generate_bodylist(self): + bodies = [] + for joint in self._joints: + if joint.child not in bodies: + bodies.append(joint.child) + if joint.parent not in bodies: + bodies.append(joint.parent) + return bodies + + def _generate_loadlist(self): + load_list = [] + for body in self.bodies: + if isinstance(body, Body): + load_list.extend(body.loads) + return load_list + + def _generate_q(self): + q_ind = [] + for joint in self._joints: + for coordinate in joint.coordinates: + if coordinate in q_ind: + raise ValueError('Coordinates of joints should be unique.') + q_ind.append(coordinate) + return Matrix(q_ind) + + def _generate_u(self): + u_ind = [] + for joint in self._joints: + for speed in joint.speeds: + if speed in u_ind: + raise ValueError('Speeds of joints should be unique.') + u_ind.append(speed) + return Matrix(u_ind) + + def _generate_kdes(self): + kd_ind = Matrix(1, 0, []).T + for joint in self._joints: + kd_ind = kd_ind.col_join(joint.kdes) + return kd_ind + + def _convert_bodies(self): + # Convert `Body` to `Particle` and `RigidBody` + bodylist = [] + for body in self.bodies: + if not isinstance(body, Body): + bodylist.append(body) + continue + if body.is_rigidbody: + rb = RigidBody(body.name, body.masscenter, body.frame, body.mass, + (body.central_inertia, body.masscenter)) + rb.potential_energy = body.potential_energy + bodylist.append(rb) + else: + part = Particle(body.name, body.masscenter, body.mass) + part.potential_energy = body.potential_energy + bodylist.append(part) + return bodylist + + def form_eoms(self, method=KanesMethod): + """Method to form system's equation of motions. + + Parameters + ========== + + method : Class + Class name of method. + + Returns + ======== + + Matrix + Vector of equations of motions. + + Examples + ======== + + As Body and JointsMethod have been deprecated, the following examples + are for illustrative purposes only. The functionality of Body is fully + captured by :class:`~.RigidBody` and :class:`~.Particle` and the + functionality of JointsMethod is fully captured by :class:`~.System`. To + ignore the deprecation warning we can use the ignore_warnings context + manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + + This is a simple example for a one degree of freedom translational + spring-mass-damper. + + >>> from sympy import S, symbols + >>> from sympy.physics.mechanics import LagrangesMethod, dynamicsymbols, Body + >>> from sympy.physics.mechanics import PrismaticJoint, JointsMethod + >>> q = dynamicsymbols('q') + >>> qd = dynamicsymbols('q', 1) + >>> m, k, b = symbols('m k b') + >>> with ignore_warnings(DeprecationWarning): + ... wall = Body('W') + ... part = Body('P', mass=m) + >>> part.potential_energy = k * q**2 / S(2) + >>> J = PrismaticJoint('J', wall, part, coordinates=q, speeds=qd) + >>> wall.apply_force(b * qd * wall.x, reaction_body=part) + >>> with ignore_warnings(DeprecationWarning): + ... method = JointsMethod(wall, J) + >>> method.form_eoms(LagrangesMethod) + Matrix([[b*Derivative(q(t), t) + k*q(t) + m*Derivative(q(t), (t, 2))]]) + + We can also solve for the states using the 'rhs' method. + + >>> method.rhs() + Matrix([ + [ Derivative(q(t), t)], + [(-b*Derivative(q(t), t) - k*q(t))/m]]) + + """ + + bodylist = self._convert_bodies() + if issubclass(method, LagrangesMethod): #LagrangesMethod or similar + L = Lagrangian(self.frame, *bodylist) + self._method = method(L, self.q, self.loads, bodylist, self.frame) + else: #KanesMethod or similar + self._method = method(self.frame, q_ind=self.q, u_ind=self.u, kd_eqs=self.kdes, + forcelist=self.loads, bodies=bodylist) + soln = self.method._form_eoms() + return soln + + def rhs(self, inv_method=None): + """Returns equations that can be solved numerically. + + Parameters + ========== + + inv_method : str + The specific sympy inverse matrix calculation method to use. For a + list of valid methods, see + :meth:`~sympy.matrices.matrixbase.MatrixBase.inv` + + Returns + ======== + + Matrix + Numerically solvable equations. + + See Also + ======== + + sympy.physics.mechanics.kane.KanesMethod.rhs: + KanesMethod's rhs function. + sympy.physics.mechanics.lagrange.LagrangesMethod.rhs: + LagrangesMethod's rhs function. + + """ + + return self.method.rhs(inv_method=inv_method) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/kane.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/kane.py new file mode 100644 index 0000000000000000000000000000000000000000..805587a4fe9d7696f45c5815ee5406b103150698 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/kane.py @@ -0,0 +1,859 @@ +from sympy import zeros, Matrix, diff, eye, linear_eq_to_matrix +from sympy.core.sorting import default_sort_key +from sympy.physics.vector import (ReferenceFrame, dynamicsymbols, + partial_velocity) +from sympy.physics.mechanics.method import _Methods +from sympy.physics.mechanics.particle import Particle +from sympy.physics.mechanics.rigidbody import RigidBody +from sympy.physics.mechanics.functions import (msubs, find_dynamicsymbols, + _f_list_parser, + _validate_coordinates, + _parse_linear_solver) +from sympy.physics.mechanics.linearize import Linearizer +from sympy.utilities.iterables import iterable + + +__all__ = ['KanesMethod'] + + +class KanesMethod(_Methods): + r"""Kane's method object. + + Explanation + =========== + + This object is used to do the "book-keeping" as you go through and form + equations of motion in the way Kane presents in: + Kane, T., Levinson, D. Dynamics Theory and Applications. 1985 McGraw-Hill + + The attributes are for equations in the form [M] udot = forcing. + + Attributes + ========== + + q, u : Matrix + Matrices of the generalized coordinates and speeds + bodies : iterable + Iterable of Particle and RigidBody objects in the system. + loads : iterable + Iterable of (Point, vector) or (ReferenceFrame, vector) tuples + describing the forces on the system. + auxiliary_eqs : Matrix + If applicable, the set of auxiliary Kane's + equations used to solve for non-contributing + forces. + mass_matrix : Matrix + The system's dynamics mass matrix: [k_d; k_dnh] + forcing : Matrix + The system's dynamics forcing vector: -[f_d; f_dnh] + mass_matrix_kin : Matrix + The "mass matrix" for kinematic differential equations: k_kqdot + forcing_kin : Matrix + The forcing vector for kinematic differential equations: -(k_ku*u + f_k) + mass_matrix_full : Matrix + The "mass matrix" for the u's and q's with dynamics and kinematics + forcing_full : Matrix + The "forcing vector" for the u's and q's with dynamics and kinematics + + Parameters + ========== + + frame : ReferenceFrame + The inertial reference frame for the system. + q_ind : iterable of dynamicsymbols + Independent generalized coordinates. + u_ind : iterable of dynamicsymbols + Independent generalized speeds. + kd_eqs : iterable of Expr, optional + Kinematic differential equations, which linearly relate the generalized + speeds to the time-derivatives of the generalized coordinates. + q_dependent : iterable of dynamicsymbols, optional + Dependent generalized coordinates. + configuration_constraints : iterable of Expr, optional + Constraints on the system's configuration, i.e. holonomic constraints. + u_dependent : iterable of dynamicsymbols, optional + Dependent generalized speeds. + velocity_constraints : iterable of Expr, optional + Constraints on the system's velocity, i.e. the combination of the + nonholonomic constraints and the time-derivative of the holonomic + constraints. + acceleration_constraints : iterable of Expr, optional + Constraints on the system's acceleration, by default these are the + time-derivative of the velocity constraints. + u_auxiliary : iterable of dynamicsymbols, optional + Auxiliary generalized speeds. + bodies : iterable of Particle and/or RigidBody, optional + The particles and rigid bodies in the system. + forcelist : iterable of tuple[Point | ReferenceFrame, Vector], optional + Forces and torques applied on the system. + explicit_kinematics : bool + Boolean whether the mass matrices and forcing vectors should use the + explicit form (default) or implicit form for kinematics. + See the notes for more details. + kd_eqs_solver : str, callable + Method used to solve the kinematic differential equations. If a string + is supplied, it should be a valid method that can be used with the + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`. If a callable is + supplied, it should have the format ``f(A, rhs)``, where it solves the + equations and returns the solution. The default utilizes LU solve. See + the notes for more information. + constraint_solver : str, callable + Method used to solve the velocity constraints. If a string is + supplied, it should be a valid method that can be used with the + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`. If a callable is + supplied, it should have the format ``f(A, rhs)``, where it solves the + equations and returns the solution. The default utilizes LU solve. See + the notes for more information. + + Notes + ===== + + The mass matrices and forcing vectors related to kinematic equations + are given in the explicit form by default. In other words, the kinematic + mass matrix is $\mathbf{k_{k\dot{q}}} = \mathbf{I}$. + In order to get the implicit form of those matrices/vectors, you can set the + ``explicit_kinematics`` attribute to ``False``. So $\mathbf{k_{k\dot{q}}}$ + is not necessarily an identity matrix. This can provide more compact + equations for non-simple kinematics. + + Two linear solvers can be supplied to ``KanesMethod``: one for solving the + kinematic differential equations and one to solve the velocity constraints. + Both of these sets of equations can be expressed as a linear system ``Ax = rhs``, + which have to be solved in order to obtain the equations of motion. + + The default solver ``'LU'``, which stands for LU solve, results relatively low + number of operations. The weakness of this method is that it can result in zero + division errors. + + If zero divisions are encountered, a possible solver which may solve the problem + is ``"CRAMER"``. This method uses Cramer's rule to solve the system. This method + is slower and results in more operations than the default solver. However it only + uses a single division by default per entry of the solution. + + While a valid list of solvers can be found at + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`, it is also possible to supply a + `callable`. This way it is possible to use a different solver routine. If the + kinematic differential equations are not too complex it can be worth it to simplify + the solution by using ``lambda A, b: simplify(Matrix.LUsolve(A, b))``. Another + option solver one may use is :func:`sympy.solvers.solveset.linsolve`. This can be + done using `lambda A, b: tuple(linsolve((A, b)))[0]`, where we select the first + solution as our system should have only one unique solution. + + Examples + ======== + + This is a simple example for a one degree of freedom translational + spring-mass-damper. + + In this example, we first need to do the kinematics. + This involves creating generalized speeds and coordinates and their + derivatives. + Then we create a point and set its velocity in a frame. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import dynamicsymbols, ReferenceFrame + >>> from sympy.physics.mechanics import Point, Particle, KanesMethod + >>> q, u = dynamicsymbols('q u') + >>> qd, ud = dynamicsymbols('q u', 1) + >>> m, c, k = symbols('m c k') + >>> N = ReferenceFrame('N') + >>> P = Point('P') + >>> P.set_vel(N, u * N.x) + + Next we need to arrange/store information in the way that KanesMethod + requires. The kinematic differential equations should be an iterable of + expressions. A list of forces/torques must be constructed, where each entry + in the list is a (Point, Vector) or (ReferenceFrame, Vector) tuple, where + the Vectors represent the Force or Torque. + Next a particle needs to be created, and it needs to have a point and mass + assigned to it. + Finally, a list of all bodies and particles needs to be created. + + >>> kd = [qd - u] + >>> FL = [(P, (-k * q - c * u) * N.x)] + >>> pa = Particle('pa', P, m) + >>> BL = [pa] + + Finally we can generate the equations of motion. + First we create the KanesMethod object and supply an inertial frame, + coordinates, generalized speeds, and the kinematic differential equations. + Additional quantities such as configuration and motion constraints, + dependent coordinates and speeds, and auxiliary speeds are also supplied + here (see the online documentation). + Next we form FR* and FR to complete: Fr + Fr* = 0. + We have the equations of motion at this point. + It makes sense to rearrange them though, so we calculate the mass matrix and + the forcing terms, for E.o.M. in the form: [MM] udot = forcing, where MM is + the mass matrix, udot is a vector of the time derivatives of the + generalized speeds, and forcing is a vector representing "forcing" terms. + + >>> KM = KanesMethod(N, q_ind=[q], u_ind=[u], kd_eqs=kd) + >>> (fr, frstar) = KM.kanes_equations(BL, FL) + >>> MM = KM.mass_matrix + >>> forcing = KM.forcing + >>> rhs = MM.inv() * forcing + >>> rhs + Matrix([[(-c*u(t) - k*q(t))/m]]) + >>> KM.linearize(A_and_B=True)[0] + Matrix([ + [ 0, 1], + [-k/m, -c/m]]) + + Please look at the documentation pages for more information on how to + perform linearization and how to deal with dependent coordinates & speeds, + and how do deal with bringing non-contributing forces into evidence. + + """ + + def __init__(self, frame, q_ind, u_ind, kd_eqs=None, q_dependent=None, + configuration_constraints=None, u_dependent=None, + velocity_constraints=None, acceleration_constraints=None, + u_auxiliary=None, bodies=None, forcelist=None, + explicit_kinematics=True, kd_eqs_solver='LU', + constraint_solver='LU'): + + """Please read the online documentation. """ + if not q_ind: + q_ind = [dynamicsymbols('dummy_q')] + kd_eqs = [dynamicsymbols('dummy_kd')] + + if not isinstance(frame, ReferenceFrame): + raise TypeError('An inertial ReferenceFrame must be supplied') + self._inertial = frame + + self._fr = None + self._frstar = None + + self._forcelist = forcelist + self._bodylist = bodies + + self.explicit_kinematics = explicit_kinematics + self._constraint_solver = constraint_solver + self._initialize_vectors(q_ind, q_dependent, u_ind, u_dependent, + u_auxiliary) + _validate_coordinates(self.q, self.u) + self._initialize_kindiffeq_matrices(kd_eqs, kd_eqs_solver) + self._initialize_constraint_matrices( + configuration_constraints, velocity_constraints, + acceleration_constraints, constraint_solver) + + def _initialize_vectors(self, q_ind, q_dep, u_ind, u_dep, u_aux): + """Initialize the coordinate and speed vectors.""" + + none_handler = lambda x: Matrix(x) if x else Matrix() + + # Initialize generalized coordinates + q_dep = none_handler(q_dep) + if not iterable(q_ind): + raise TypeError('Generalized coordinates must be an iterable.') + if not iterable(q_dep): + raise TypeError('Dependent coordinates must be an iterable.') + q_ind = Matrix(q_ind) + self._qdep = q_dep + self._q = Matrix([q_ind, q_dep]) + self._qdot = self.q.diff(dynamicsymbols._t) + + # Initialize generalized speeds + u_dep = none_handler(u_dep) + if not iterable(u_ind): + raise TypeError('Generalized speeds must be an iterable.') + if not iterable(u_dep): + raise TypeError('Dependent speeds must be an iterable.') + u_ind = Matrix(u_ind) + self._udep = u_dep + self._u = Matrix([u_ind, u_dep]) + self._udot = self.u.diff(dynamicsymbols._t) + self._uaux = none_handler(u_aux) + + def _initialize_constraint_matrices(self, config, vel, acc, linear_solver='LU'): + """Initializes constraint matrices.""" + linear_solver = _parse_linear_solver(linear_solver) + # Define vector dimensions + o = len(self.u) + m = len(self._udep) + p = o - m + none_handler = lambda x: Matrix(x) if x else Matrix() + + # Initialize configuration constraints + config = none_handler(config) + if len(self._qdep) != len(config): + raise ValueError('There must be an equal number of dependent ' + 'coordinates and configuration constraints.') + self._f_h = none_handler(config) + + # Initialize velocity and acceleration constraints + vel = none_handler(vel) + acc = none_handler(acc) + if len(vel) != m: + raise ValueError('There must be an equal number of dependent ' + 'speeds and velocity constraints.') + if acc and (len(acc) != m): + raise ValueError('There must be an equal number of dependent ' + 'speeds and acceleration constraints.') + if vel: + + # When calling kanes_equations, another class instance will be + # created if auxiliary u's are present. In this case, the + # computation of kinetic differential equation matrices will be + # skipped as this was computed during the original KanesMethod + # object, and the qd_u_map will not be available. + if self._qdot_u_map is not None: + vel = msubs(vel, self._qdot_u_map) + self._k_nh, f_nh_neg = linear_eq_to_matrix(vel, self.u[:]) + self._f_nh = -f_nh_neg + + # If no acceleration constraints given, calculate them. + if not acc: + _f_dnh = (self._k_nh.diff(dynamicsymbols._t) * self.u + + self._f_nh.diff(dynamicsymbols._t)) + if self._qdot_u_map is not None: + _f_dnh = msubs(_f_dnh, self._qdot_u_map) + self._f_dnh = _f_dnh + self._k_dnh = self._k_nh + else: + if self._qdot_u_map is not None: + acc = msubs(acc, self._qdot_u_map) + + self._k_dnh, f_dnh_neg = linear_eq_to_matrix(acc, self._udot[:]) + self._f_dnh = -f_dnh_neg + # Form of non-holonomic constraints is B*u + C = 0. + # We partition B into independent and dependent columns: + # Ars is then -B_dep.inv() * B_ind, and it relates dependent speeds + # to independent speeds as: udep = Ars*uind, neglecting the C term. + B_ind = self._k_nh[:, :p] + B_dep = self._k_nh[:, p:o] + self._Ars = -linear_solver(B_dep, B_ind) + else: + self._f_nh = Matrix() + self._k_nh = Matrix() + self._f_dnh = Matrix() + self._k_dnh = Matrix() + self._Ars = Matrix() + + def _initialize_kindiffeq_matrices(self, kdeqs, linear_solver='LU'): + """Initialize the kinematic differential equation matrices. + + Parameters + ========== + kdeqs : sequence of sympy expressions + Kinematic differential equations in the form of f(u,q',q,t) where + f() = 0. The equations have to be linear in the time-derivatives of + the generalized coordinates and in the generalized speeds. + + """ + linear_solver = _parse_linear_solver(linear_solver) + if kdeqs: + if len(self.q) != len(kdeqs): + raise ValueError('There must be an equal number of kinematic ' + 'differential equations and coordinates.') + + u = self.u + qdot = self._qdot + + kdeqs = Matrix(kdeqs) + + u_zero = dict.fromkeys(u, 0) + uaux_zero = dict.fromkeys(self._uaux, 0) + qdot_zero = dict.fromkeys(qdot, 0) + + # Extract the linear coefficient matrices as per the following + # equation: + # + # k_ku(q,t)*u(t) + k_kqdot(q,t)*q'(t) + f_k(q,t) = 0 + # + k_ku = kdeqs.jacobian(u) + k_kqdot = kdeqs.jacobian(qdot) + f_k = kdeqs.xreplace(u_zero).xreplace(qdot_zero) + + # The kinematic differential equations should be linear in both q' + # and u so check for u and q' in the components. + dy_syms = find_dynamicsymbols(k_ku.row_join(k_kqdot).row_join(f_k)) + nonlin_vars = [vari for vari in u[:] + qdot[:] if vari in dy_syms] + if nonlin_vars: + msg = ('The provided kinematic differential equations are ' + 'nonlinear in {}. They must be linear in the ' + 'generalized speeds and derivatives of the generalized ' + 'coordinates.') + raise ValueError(msg.format(nonlin_vars)) + + self._f_k_implicit = f_k.xreplace(uaux_zero) + self._k_ku_implicit = k_ku.xreplace(uaux_zero) + self._k_kqdot_implicit = k_kqdot + + # Solve for q'(t) such that the coefficient matrices are now in + # this form: + # + # k_kqdot^-1*k_ku*u(t) + I*q'(t) + k_kqdot^-1*f_k = 0 + # + # NOTE : Solving the kinematic differential equations here is not + # necessary and prevents the equations from being provided in fully + # implicit form. + f_k_explicit = linear_solver(k_kqdot, f_k) + k_ku_explicit = linear_solver(k_kqdot, k_ku) + self._qdot_u_map = dict(zip(qdot, -(k_ku_explicit*u + f_k_explicit))) + + self._f_k = f_k_explicit.xreplace(uaux_zero) + self._k_ku = k_ku_explicit.xreplace(uaux_zero) + self._k_kqdot = eye(len(qdot)) + + else: + self._qdot_u_map = None + self._f_k_implicit = self._f_k = Matrix() + self._k_ku_implicit = self._k_ku = Matrix() + self._k_kqdot_implicit = self._k_kqdot = Matrix() + + def _form_fr(self, fl): + """Form the generalized active force.""" + if fl is not None and (len(fl) == 0 or not iterable(fl)): + raise ValueError('Force pairs must be supplied in an ' + 'non-empty iterable or None.') + + N = self._inertial + # pull out relevant velocities for constructing partial velocities + vel_list, f_list = _f_list_parser(fl, N) + vel_list = [msubs(i, self._qdot_u_map) for i in vel_list] + f_list = [msubs(i, self._qdot_u_map) for i in f_list] + + # Fill Fr with dot product of partial velocities and forces + o = len(self.u) + b = len(f_list) + FR = zeros(o, 1) + partials = partial_velocity(vel_list, self.u, N) + for i in range(o): + FR[i] = sum(partials[j][i].dot(f_list[j]) for j in range(b)) + + # In case there are dependent speeds + if self._udep: + p = o - len(self._udep) + FRtilde = FR[:p, 0] + FRold = FR[p:o, 0] + FRtilde += self._Ars.T * FRold + FR = FRtilde + + self._forcelist = fl + self._fr = FR + return FR + + def _form_frstar(self, bl): + """Form the generalized inertia force.""" + + if not iterable(bl): + raise TypeError('Bodies must be supplied in an iterable.') + + t = dynamicsymbols._t + N = self._inertial + # Dicts setting things to zero + udot_zero = dict.fromkeys(self._udot, 0) + uaux_zero = dict.fromkeys(self._uaux, 0) + uauxdot = [diff(i, t) for i in self._uaux] + uauxdot_zero = dict.fromkeys(uauxdot, 0) + # Dictionary of q' and q'' to u and u' + q_ddot_u_map = {k.diff(t): v.diff(t).xreplace( + self._qdot_u_map) for (k, v) in self._qdot_u_map.items()} + q_ddot_u_map.update(self._qdot_u_map) + + # Fill up the list of partials: format is a list with num elements + # equal to number of entries in body list. Each of these elements is a + # list - either of length 1 for the translational components of + # particles or of length 2 for the translational and rotational + # components of rigid bodies. The inner most list is the list of + # partial velocities. + def get_partial_velocity(body): + if isinstance(body, RigidBody): + vlist = [body.masscenter.vel(N), body.frame.ang_vel_in(N)] + elif isinstance(body, Particle): + vlist = [body.point.vel(N),] + else: + raise TypeError('The body list may only contain either ' + 'RigidBody or Particle as list elements.') + v = [msubs(vel, self._qdot_u_map) for vel in vlist] + return partial_velocity(v, self.u, N) + partials = [get_partial_velocity(body) for body in bl] + + # Compute fr_star in two components: + # fr_star = -(MM*u' + nonMM) + o = len(self.u) + MM = zeros(o, o) + nonMM = zeros(o, 1) + zero_uaux = lambda expr: msubs(expr, uaux_zero) + zero_udot_uaux = lambda expr: msubs(msubs(expr, udot_zero), uaux_zero) + for i, body in enumerate(bl): + if isinstance(body, RigidBody): + M = zero_uaux(body.mass) + I = zero_uaux(body.central_inertia) + vel = zero_uaux(body.masscenter.vel(N)) + omega = zero_uaux(body.frame.ang_vel_in(N)) + acc = zero_udot_uaux(body.masscenter.acc(N)) + inertial_force = (M.diff(t) * vel + M * acc) + inertial_torque = zero_uaux((I.dt(body.frame).dot(omega)) + + msubs(I.dot(body.frame.ang_acc_in(N)), udot_zero) + + (omega.cross(I.dot(omega)))) + for j in range(o): + tmp_vel = zero_uaux(partials[i][0][j]) + tmp_ang = zero_uaux(I.dot(partials[i][1][j])) + for k in range(o): + # translational + MM[j, k] += M*tmp_vel.dot(partials[i][0][k]) + # rotational + MM[j, k] += tmp_ang.dot(partials[i][1][k]) + nonMM[j] += inertial_force.dot(partials[i][0][j]) + nonMM[j] += inertial_torque.dot(partials[i][1][j]) + else: + M = zero_uaux(body.mass) + vel = zero_uaux(body.point.vel(N)) + acc = zero_udot_uaux(body.point.acc(N)) + inertial_force = (M.diff(t) * vel + M * acc) + for j in range(o): + temp = zero_uaux(partials[i][0][j]) + for k in range(o): + MM[j, k] += M*temp.dot(partials[i][0][k]) + nonMM[j] += inertial_force.dot(partials[i][0][j]) + # Compose fr_star out of MM and nonMM + MM = zero_uaux(msubs(MM, q_ddot_u_map)) + nonMM = msubs(msubs(nonMM, q_ddot_u_map), + udot_zero, uauxdot_zero, uaux_zero) + fr_star = -(MM * msubs(Matrix(self._udot), uauxdot_zero) + nonMM) + + # If there are dependent speeds, we need to find fr_star_tilde + if self._udep: + p = o - len(self._udep) + fr_star_ind = fr_star[:p, 0] + fr_star_dep = fr_star[p:o, 0] + fr_star = fr_star_ind + (self._Ars.T * fr_star_dep) + # Apply the same to MM + MMi = MM[:p, :] + MMd = MM[p:o, :] + MM = MMi + (self._Ars.T * MMd) + # Apply the same to nonMM + nonMM = nonMM[:p, :] + (self._Ars.T * nonMM[p:o, :]) + + self._bodylist = bl + self._frstar = fr_star + self._k_d = MM + self._f_d = -(self._fr - nonMM) + return fr_star + + def to_linearizer(self, linear_solver='LU'): + """Returns an instance of the Linearizer class, initiated from the + data in the KanesMethod class. This may be more desirable than using + the linearize class method, as the Linearizer object will allow more + efficient recalculation (i.e. about varying operating points). + + Parameters + ========== + linear_solver : str, callable + Method used to solve the several symbolic linear systems of the + form ``A*x=b`` in the linearization process. If a string is + supplied, it should be a valid method that can be used with the + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`. If a callable is + supplied, it should have the format ``x = f(A, b)``, where it + solves the equations and returns the solution. The default is + ``'LU'`` which corresponds to SymPy's ``A.LUsolve(b)``. + ``LUsolve()`` is fast to compute but will often result in + divide-by-zero and thus ``nan`` results. + + Returns + ======= + Linearizer + An instantiated + :class:`sympy.physics.mechanics.linearize.Linearizer`. + + """ + + if (self._fr is None) or (self._frstar is None): + raise ValueError('Need to compute Fr, Fr* first.') + + # Get required equation components. The Kane's method class breaks + # these into pieces. Need to reassemble + f_c = self._f_h + if self._f_nh and self._k_nh: + f_v = self._f_nh + self._k_nh*Matrix(self.u) + else: + f_v = Matrix() + if self._f_dnh and self._k_dnh: + f_a = self._f_dnh + self._k_dnh*Matrix(self._udot) + else: + f_a = Matrix() + # Dicts to sub to zero, for splitting up expressions + u_zero = dict.fromkeys(self.u, 0) + ud_zero = dict.fromkeys(self._udot, 0) + qd_zero = dict.fromkeys(self._qdot, 0) + qd_u_zero = dict.fromkeys(Matrix([self._qdot, self.u]), 0) + # Break the kinematic differential eqs apart into f_0 and f_1 + f_0 = msubs(self._f_k, u_zero) + self._k_kqdot*Matrix(self._qdot) + f_1 = msubs(self._f_k, qd_zero) + self._k_ku*Matrix(self.u) + # Break the dynamic differential eqs into f_2 and f_3 + f_2 = msubs(self._frstar, qd_u_zero) + f_3 = msubs(self._frstar, ud_zero) + self._fr + f_4 = zeros(len(f_2), 1) + + # Get the required vector components + q = self.q + u = self.u + if self._qdep: + q_i = q[:-len(self._qdep)] + else: + q_i = q + q_d = self._qdep + if self._udep: + u_i = u[:-len(self._udep)] + else: + u_i = u + u_d = self._udep + + # Form dictionary to set auxiliary speeds & their derivatives to 0. + uaux = self._uaux + uauxdot = uaux.diff(dynamicsymbols._t) + uaux_zero = dict.fromkeys(Matrix([uaux, uauxdot]), 0) + + # Checking for dynamic symbols outside the dynamic differential + # equations; throws error if there is. + sym_list = set(Matrix([q, self._qdot, u, self._udot, uaux, uauxdot])) + if any(find_dynamicsymbols(i, sym_list) for i in [self._k_kqdot, + self._k_ku, self._f_k, self._k_dnh, self._f_dnh, self._k_d]): + raise ValueError('Cannot have dynamicsymbols outside dynamic \ + forcing vector.') + + # Find all other dynamic symbols, forming the forcing vector r. + # Sort r to make it canonical. + r = list(find_dynamicsymbols(msubs(self._f_d, uaux_zero), sym_list)) + r.sort(key=default_sort_key) + + # Check for any derivatives of variables in r that are also found in r. + for i in r: + if diff(i, dynamicsymbols._t) in r: + raise ValueError('Cannot have derivatives of specified \ + quantities when linearizing forcing terms.') + return Linearizer(f_0, f_1, f_2, f_3, f_4, f_c, f_v, f_a, q, u, q_i, + q_d, u_i, u_d, r, linear_solver=linear_solver) + + # TODO : Remove `new_method` after 1.1 has been released. + def linearize(self, *, new_method=None, linear_solver='LU', **kwargs): + """ Linearize the equations of motion about a symbolic operating point. + + Parameters + ========== + new_method + Deprecated, does nothing and will be removed. + linear_solver : str, callable + Method used to solve the several symbolic linear systems of the + form ``A*x=b`` in the linearization process. If a string is + supplied, it should be a valid method that can be used with the + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`. If a callable is + supplied, it should have the format ``x = f(A, b)``, where it + solves the equations and returns the solution. The default is + ``'LU'`` which corresponds to SymPy's ``A.LUsolve(b)``. + ``LUsolve()`` is fast to compute but will often result in + divide-by-zero and thus ``nan`` results. + **kwargs + Extra keyword arguments are passed to + :meth:`sympy.physics.mechanics.linearize.Linearizer.linearize`. + + Explanation + =========== + + If kwarg A_and_B is False (default), returns M, A, B, r for the + linearized form, M*[q', u']^T = A*[q_ind, u_ind]^T + B*r. + + If kwarg A_and_B is True, returns A, B, r for the linearized form + dx = A*x + B*r, where x = [q_ind, u_ind]^T. Note that this is + computationally intensive if there are many symbolic parameters. For + this reason, it may be more desirable to use the default A_and_B=False, + returning M, A, and B. Values may then be substituted in to these + matrices, and the state space form found as + A = P.T*M.inv()*A, B = P.T*M.inv()*B, where P = Linearizer.perm_mat. + + In both cases, r is found as all dynamicsymbols in the equations of + motion that are not part of q, u, q', or u'. They are sorted in + canonical form. + + The operating points may be also entered using the ``op_point`` kwarg. + This takes a dictionary of {symbol: value}, or a an iterable of such + dictionaries. The values may be numeric or symbolic. The more values + you can specify beforehand, the faster this computation will run. + + For more documentation, please see the ``Linearizer`` class. + + """ + + linearizer = self.to_linearizer(linear_solver=linear_solver) + result = linearizer.linearize(**kwargs) + return result + (linearizer.r,) + + def kanes_equations(self, bodies=None, loads=None): + """ Method to form Kane's equations, Fr + Fr* = 0. + + Explanation + =========== + + Returns (Fr, Fr*). In the case where auxiliary generalized speeds are + present (say, s auxiliary speeds, o generalized speeds, and m motion + constraints) the length of the returned vectors will be o - m + s in + length. The first o - m equations will be the constrained Kane's + equations, then the s auxiliary Kane's equations. These auxiliary + equations can be accessed with the auxiliary_eqs property. + + Parameters + ========== + + bodies : iterable + An iterable of all RigidBody's and Particle's in the system. + A system must have at least one body. + loads : iterable + Takes in an iterable of (Particle, Vector) or (ReferenceFrame, Vector) + tuples which represent the force at a point or torque on a frame. + Must be either a non-empty iterable of tuples or None which corresponds + to a system with no constraints. + """ + if bodies is None: + bodies = self.bodies + if loads is None and self._forcelist is not None: + loads = self._forcelist + if loads == []: + loads = None + if not self._k_kqdot: + raise AttributeError('Create an instance of KanesMethod with ' + 'kinematic differential equations to use this method.') + fr = self._form_fr(loads) + frstar = self._form_frstar(bodies) + if self._uaux: + if not self._udep: + km = KanesMethod(self._inertial, self.q, self._uaux, + u_auxiliary=self._uaux, constraint_solver=self._constraint_solver) + else: + km = KanesMethod(self._inertial, self.q, self._uaux, + u_auxiliary=self._uaux, u_dependent=self._udep, + velocity_constraints=(self._k_nh * self.u + + self._f_nh), + acceleration_constraints=(self._k_dnh * self._udot + + self._f_dnh), + constraint_solver=self._constraint_solver + ) + km._qdot_u_map = self._qdot_u_map + self._km = km + fraux = km._form_fr(loads) + frstaraux = km._form_frstar(bodies) + self._aux_eq = fraux + frstaraux + self._fr = fr.col_join(fraux) + self._frstar = frstar.col_join(frstaraux) + return (self._fr, self._frstar) + + def _form_eoms(self): + fr, frstar = self.kanes_equations(self.bodylist, self.forcelist) + return fr + frstar + + def rhs(self, inv_method=None): + """Returns the system's equations of motion in first order form. The + output is the right hand side of:: + + x' = |q'| =: f(q, u, r, p, t) + |u'| + + The right hand side is what is needed by most numerical ODE + integrators. + + Parameters + ========== + + inv_method : str + The specific sympy inverse matrix calculation method to use. For a + list of valid methods, see + :meth:`~sympy.matrices.matrixbase.MatrixBase.inv` + + """ + rhs = zeros(len(self.q) + len(self.u), 1) + kdes = self.kindiffdict() + for i, q_i in enumerate(self.q): + rhs[i] = kdes[q_i.diff()] + + if inv_method is None: + rhs[len(self.q):, 0] = self.mass_matrix.LUsolve(self.forcing) + else: + rhs[len(self.q):, 0] = (self.mass_matrix.inv(inv_method, + try_block_diag=True) * + self.forcing) + + return rhs + + def kindiffdict(self): + """Returns a dictionary mapping q' to u.""" + if not self._qdot_u_map: + raise AttributeError('Create an instance of KanesMethod with ' + 'kinematic differential equations to use this method.') + return self._qdot_u_map + + @property + def auxiliary_eqs(self): + """A matrix containing the auxiliary equations.""" + if not self._fr or not self._frstar: + raise ValueError('Need to compute Fr, Fr* first.') + if not self._uaux: + raise ValueError('No auxiliary speeds have been declared.') + return self._aux_eq + + @property + def mass_matrix_kin(self): + r"""The kinematic "mass matrix" $\mathbf{k_{k\dot{q}}}$ of the system.""" + return self._k_kqdot if self.explicit_kinematics else self._k_kqdot_implicit + + @property + def forcing_kin(self): + """The kinematic "forcing vector" of the system.""" + if self.explicit_kinematics: + return -(self._k_ku * Matrix(self.u) + self._f_k) + else: + return -(self._k_ku_implicit * Matrix(self.u) + self._f_k_implicit) + + @property + def mass_matrix(self): + """The mass matrix of the system.""" + if not self._fr or not self._frstar: + raise ValueError('Need to compute Fr, Fr* first.') + return Matrix([self._k_d, self._k_dnh]) + + @property + def forcing(self): + """The forcing vector of the system.""" + if not self._fr or not self._frstar: + raise ValueError('Need to compute Fr, Fr* first.') + return -Matrix([self._f_d, self._f_dnh]) + + @property + def mass_matrix_full(self): + """The mass matrix of the system, augmented by the kinematic + differential equations in explicit or implicit form.""" + if not self._fr or not self._frstar: + raise ValueError('Need to compute Fr, Fr* first.') + o, n = len(self.u), len(self.q) + return (self.mass_matrix_kin.row_join(zeros(n, o))).col_join( + zeros(o, n).row_join(self.mass_matrix)) + + @property + def forcing_full(self): + """The forcing vector of the system, augmented by the kinematic + differential equations in explicit or implicit form.""" + return Matrix([self.forcing_kin, self.forcing]) + + @property + def q(self): + return self._q + + @property + def u(self): + return self._u + + @property + def bodylist(self): + return self._bodylist + + @property + def forcelist(self): + return self._forcelist + + @property + def bodies(self): + return self._bodylist + + @property + def loads(self): + return self._forcelist diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/lagrange.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/lagrange.py new file mode 100644 index 0000000000000000000000000000000000000000..282176a404f77762abc3ee8c6a575519b2de1f02 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/lagrange.py @@ -0,0 +1,512 @@ +from sympy import diff, zeros, Matrix, eye, sympify +from sympy.core.sorting import default_sort_key +from sympy.physics.vector import dynamicsymbols, ReferenceFrame +from sympy.physics.mechanics.method import _Methods +from sympy.physics.mechanics.functions import ( + find_dynamicsymbols, msubs, _f_list_parser, _validate_coordinates) +from sympy.physics.mechanics.linearize import Linearizer +from sympy.utilities.iterables import iterable + +__all__ = ['LagrangesMethod'] + + +class LagrangesMethod(_Methods): + """Lagrange's method object. + + Explanation + =========== + + This object generates the equations of motion in a two step procedure. The + first step involves the initialization of LagrangesMethod by supplying the + Lagrangian and the generalized coordinates, at the bare minimum. If there + are any constraint equations, they can be supplied as keyword arguments. + The Lagrange multipliers are automatically generated and are equal in + number to the constraint equations. Similarly any non-conservative forces + can be supplied in an iterable (as described below and also shown in the + example) along with a ReferenceFrame. This is also discussed further in the + __init__ method. + + Attributes + ========== + + q, u : Matrix + Matrices of the generalized coordinates and speeds + loads : iterable + Iterable of (Point, vector) or (ReferenceFrame, vector) tuples + describing the forces on the system. + bodies : iterable + Iterable containing the rigid bodies and particles of the system. + mass_matrix : Matrix + The system's mass matrix + forcing : Matrix + The system's forcing vector + mass_matrix_full : Matrix + The "mass matrix" for the qdot's, qdoubledot's, and the + lagrange multipliers (lam) + forcing_full : Matrix + The forcing vector for the qdot's, qdoubledot's and + lagrange multipliers (lam) + + Examples + ======== + + This is a simple example for a one degree of freedom translational + spring-mass-damper. + + In this example, we first need to do the kinematics. + This involves creating generalized coordinates and their derivatives. + Then we create a point and set its velocity in a frame. + + >>> from sympy.physics.mechanics import LagrangesMethod, Lagrangian + >>> from sympy.physics.mechanics import ReferenceFrame, Particle, Point + >>> from sympy.physics.mechanics import dynamicsymbols + >>> from sympy import symbols + >>> q = dynamicsymbols('q') + >>> qd = dynamicsymbols('q', 1) + >>> m, k, b = symbols('m k b') + >>> N = ReferenceFrame('N') + >>> P = Point('P') + >>> P.set_vel(N, qd * N.x) + + We need to then prepare the information as required by LagrangesMethod to + generate equations of motion. + First we create the Particle, which has a point attached to it. + Following this the lagrangian is created from the kinetic and potential + energies. + Then, an iterable of nonconservative forces/torques must be constructed, + where each item is a (Point, Vector) or (ReferenceFrame, Vector) tuple, + with the Vectors representing the nonconservative forces or torques. + + >>> Pa = Particle('Pa', P, m) + >>> Pa.potential_energy = k * q**2 / 2.0 + >>> L = Lagrangian(N, Pa) + >>> fl = [(P, -b * qd * N.x)] + + Finally we can generate the equations of motion. + First we create the LagrangesMethod object. To do this one must supply + the Lagrangian, and the generalized coordinates. The constraint equations, + the forcelist, and the inertial frame may also be provided, if relevant. + Next we generate Lagrange's equations of motion, such that: + Lagrange's equations of motion = 0. + We have the equations of motion at this point. + + >>> l = LagrangesMethod(L, [q], forcelist = fl, frame = N) + >>> print(l.form_lagranges_equations()) + Matrix([[b*Derivative(q(t), t) + 1.0*k*q(t) + m*Derivative(q(t), (t, 2))]]) + + We can also solve for the states using the 'rhs' method. + + >>> print(l.rhs()) + Matrix([[Derivative(q(t), t)], [(-b*Derivative(q(t), t) - 1.0*k*q(t))/m]]) + + Please refer to the docstrings on each method for more details. + """ + + def __init__(self, Lagrangian, qs, forcelist=None, bodies=None, frame=None, + hol_coneqs=None, nonhol_coneqs=None): + """Supply the following for the initialization of LagrangesMethod. + + Lagrangian : Sympifyable + + qs : array_like + The generalized coordinates + + hol_coneqs : array_like, optional + The holonomic constraint equations + + nonhol_coneqs : array_like, optional + The nonholonomic constraint equations + + forcelist : iterable, optional + Takes an iterable of (Point, Vector) or (ReferenceFrame, Vector) + tuples which represent the force at a point or torque on a frame. + This feature is primarily to account for the nonconservative forces + and/or moments. + + bodies : iterable, optional + Takes an iterable containing the rigid bodies and particles of the + system. + + frame : ReferenceFrame, optional + Supply the inertial frame. This is used to determine the + generalized forces due to non-conservative forces. + """ + + self._L = Matrix([sympify(Lagrangian)]) + self.eom = None + self._m_cd = Matrix() # Mass Matrix of differentiated coneqs + self._m_d = Matrix() # Mass Matrix of dynamic equations + self._f_cd = Matrix() # Forcing part of the diff coneqs + self._f_d = Matrix() # Forcing part of the dynamic equations + self.lam_coeffs = Matrix() # The coeffecients of the multipliers + + forcelist = forcelist if forcelist else [] + if not iterable(forcelist): + raise TypeError('Force pairs must be supplied in an iterable.') + self._forcelist = forcelist + if frame and not isinstance(frame, ReferenceFrame): + raise TypeError('frame must be a valid ReferenceFrame') + self._bodies = bodies + self.inertial = frame + + self.lam_vec = Matrix() + + self._term1 = Matrix() + self._term2 = Matrix() + self._term3 = Matrix() + self._term4 = Matrix() + + # Creating the qs, qdots and qdoubledots + if not iterable(qs): + raise TypeError('Generalized coordinates must be an iterable') + self._q = Matrix(qs) + self._qdots = self.q.diff(dynamicsymbols._t) + self._qdoubledots = self._qdots.diff(dynamicsymbols._t) + _validate_coordinates(self.q) + + mat_build = lambda x: Matrix(x) if x else Matrix() + hol_coneqs = mat_build(hol_coneqs) + nonhol_coneqs = mat_build(nonhol_coneqs) + self.coneqs = Matrix([hol_coneqs.diff(dynamicsymbols._t), + nonhol_coneqs]) + self._hol_coneqs = hol_coneqs + + def form_lagranges_equations(self): + """Method to form Lagrange's equations of motion. + + Returns a vector of equations of motion using Lagrange's equations of + the second kind. + """ + + qds = self._qdots + qdd_zero = dict.fromkeys(self._qdoubledots, 0) + n = len(self.q) + + # Internally we represent the EOM as four terms: + # EOM = term1 - term2 - term3 - term4 = 0 + + # First term + self._term1 = self._L.jacobian(qds) + self._term1 = self._term1.diff(dynamicsymbols._t).T + + # Second term + self._term2 = self._L.jacobian(self.q).T + + # Third term + if self.coneqs: + coneqs = self.coneqs + m = len(coneqs) + # Creating the multipliers + self.lam_vec = Matrix(dynamicsymbols('lam1:' + str(m + 1))) + self.lam_coeffs = -coneqs.jacobian(qds) + self._term3 = self.lam_coeffs.T * self.lam_vec + # Extracting the coeffecients of the qdds from the diff coneqs + diffconeqs = coneqs.diff(dynamicsymbols._t) + self._m_cd = diffconeqs.jacobian(self._qdoubledots) + # The remaining terms i.e. the 'forcing' terms in diff coneqs + self._f_cd = -diffconeqs.subs(qdd_zero) + else: + self._term3 = zeros(n, 1) + + # Fourth term + if self.forcelist: + N = self.inertial + self._term4 = zeros(n, 1) + for i, qd in enumerate(qds): + flist = zip(*_f_list_parser(self.forcelist, N)) + self._term4[i] = sum(v.diff(qd, N).dot(f) for (v, f) in flist) + else: + self._term4 = zeros(n, 1) + + # Form the dynamic mass and forcing matrices + without_lam = self._term1 - self._term2 - self._term4 + self._m_d = without_lam.jacobian(self._qdoubledots) + self._f_d = -without_lam.subs(qdd_zero) + + # Form the EOM + self.eom = without_lam - self._term3 + return self.eom + + def _form_eoms(self): + return self.form_lagranges_equations() + + @property + def mass_matrix(self): + """Returns the mass matrix, which is augmented by the Lagrange + multipliers, if necessary. + + Explanation + =========== + + If the system is described by 'n' generalized coordinates and there are + no constraint equations then an n X n matrix is returned. + + If there are 'n' generalized coordinates and 'm' constraint equations + have been supplied during initialization then an n X (n+m) matrix is + returned. The (n + m - 1)th and (n + m)th columns contain the + coefficients of the Lagrange multipliers. + """ + + if self.eom is None: + raise ValueError('Need to compute the equations of motion first') + if self.coneqs: + return (self._m_d).row_join(self.lam_coeffs.T) + else: + return self._m_d + + @property + def mass_matrix_full(self): + """Augments the coefficients of qdots to the mass_matrix.""" + + if self.eom is None: + raise ValueError('Need to compute the equations of motion first') + n = len(self.q) + m = len(self.coneqs) + row1 = eye(n).row_join(zeros(n, n + m)) + row2 = zeros(n, n).row_join(self.mass_matrix) + if self.coneqs: + row3 = zeros(m, n).row_join(self._m_cd).row_join(zeros(m, m)) + return row1.col_join(row2).col_join(row3) + else: + return row1.col_join(row2) + + @property + def forcing(self): + """Returns the forcing vector from 'lagranges_equations' method.""" + + if self.eom is None: + raise ValueError('Need to compute the equations of motion first') + return self._f_d + + @property + def forcing_full(self): + """Augments qdots to the forcing vector above.""" + + if self.eom is None: + raise ValueError('Need to compute the equations of motion first') + if self.coneqs: + return self._qdots.col_join(self.forcing).col_join(self._f_cd) + else: + return self._qdots.col_join(self.forcing) + + def to_linearizer(self, q_ind=None, qd_ind=None, q_dep=None, qd_dep=None, + linear_solver='LU'): + """Returns an instance of the Linearizer class, initiated from the data + in the LagrangesMethod class. This may be more desirable than using the + linearize class method, as the Linearizer object will allow more + efficient recalculation (i.e. about varying operating points). + + Parameters + ========== + + q_ind, qd_ind : array_like, optional + The independent generalized coordinates and speeds. + q_dep, qd_dep : array_like, optional + The dependent generalized coordinates and speeds. + linear_solver : str, callable + Method used to solve the several symbolic linear systems of the + form ``A*x=b`` in the linearization process. If a string is + supplied, it should be a valid method that can be used with the + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`. If a callable is + supplied, it should have the format ``x = f(A, b)``, where it + solves the equations and returns the solution. The default is + ``'LU'`` which corresponds to SymPy's ``A.LUsolve(b)``. + ``LUsolve()`` is fast to compute but will often result in + divide-by-zero and thus ``nan`` results. + + Returns + ======= + Linearizer + An instantiated + :class:`sympy.physics.mechanics.linearize.Linearizer`. + + """ + + # Compose vectors + t = dynamicsymbols._t + q = self.q + u = self._qdots + ud = u.diff(t) + # Get vector of lagrange multipliers + lams = self.lam_vec + + mat_build = lambda x: Matrix(x) if x else Matrix() + q_i = mat_build(q_ind) + q_d = mat_build(q_dep) + u_i = mat_build(qd_ind) + u_d = mat_build(qd_dep) + + # Compose general form equations + f_c = self._hol_coneqs + f_v = self.coneqs + f_a = f_v.diff(t) + f_0 = u + f_1 = -u + f_2 = self._term1 + f_3 = -(self._term2 + self._term4) + f_4 = -self._term3 + + # Check that there are an appropriate number of independent and + # dependent coordinates + if len(q_d) != len(f_c) or len(u_d) != len(f_v): + raise ValueError(("Must supply {:} dependent coordinates, and " + + "{:} dependent speeds").format(len(f_c), len(f_v))) + if set(Matrix([q_i, q_d])) != set(q): + raise ValueError("Must partition q into q_ind and q_dep, with " + + "no extra or missing symbols.") + if set(Matrix([u_i, u_d])) != set(u): + raise ValueError("Must partition qd into qd_ind and qd_dep, " + + "with no extra or missing symbols.") + + # Find all other dynamic symbols, forming the forcing vector r. + # Sort r to make it canonical. + insyms = set(Matrix([q, u, ud, lams])) + r = list(find_dynamicsymbols(f_3, insyms)) + r.sort(key=default_sort_key) + # Check for any derivatives of variables in r that are also found in r. + for i in r: + if diff(i, dynamicsymbols._t) in r: + raise ValueError('Cannot have derivatives of specified \ + quantities when linearizing forcing terms.') + + return Linearizer(f_0, f_1, f_2, f_3, f_4, f_c, f_v, f_a, q, u, q_i, + q_d, u_i, u_d, r, lams, linear_solver=linear_solver) + + def linearize(self, q_ind=None, qd_ind=None, q_dep=None, qd_dep=None, + linear_solver='LU', **kwargs): + """Linearize the equations of motion about a symbolic operating point. + + Parameters + ========== + linear_solver : str, callable + Method used to solve the several symbolic linear systems of the + form ``A*x=b`` in the linearization process. If a string is + supplied, it should be a valid method that can be used with the + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`. If a callable is + supplied, it should have the format ``x = f(A, b)``, where it + solves the equations and returns the solution. The default is + ``'LU'`` which corresponds to SymPy's ``A.LUsolve(b)``. + ``LUsolve()`` is fast to compute but will often result in + divide-by-zero and thus ``nan`` results. + **kwargs + Extra keyword arguments are passed to + :meth:`sympy.physics.mechanics.linearize.Linearizer.linearize`. + + Explanation + =========== + + If kwarg A_and_B is False (default), returns M, A, B, r for the + linearized form, M*[q', u']^T = A*[q_ind, u_ind]^T + B*r. + + If kwarg A_and_B is True, returns A, B, r for the linearized form + dx = A*x + B*r, where x = [q_ind, u_ind]^T. Note that this is + computationally intensive if there are many symbolic parameters. For + this reason, it may be more desirable to use the default A_and_B=False, + returning M, A, and B. Values may then be substituted in to these + matrices, and the state space form found as + A = P.T*M.inv()*A, B = P.T*M.inv()*B, where P = Linearizer.perm_mat. + + In both cases, r is found as all dynamicsymbols in the equations of + motion that are not part of q, u, q', or u'. They are sorted in + canonical form. + + The operating points may be also entered using the ``op_point`` kwarg. + This takes a dictionary of {symbol: value}, or a an iterable of such + dictionaries. The values may be numeric or symbolic. The more values + you can specify beforehand, the faster this computation will run. + + For more documentation, please see the ``Linearizer`` class.""" + + linearizer = self.to_linearizer(q_ind, qd_ind, q_dep, qd_dep, + linear_solver=linear_solver) + result = linearizer.linearize(**kwargs) + return result + (linearizer.r,) + + def solve_multipliers(self, op_point=None, sol_type='dict'): + """Solves for the values of the lagrange multipliers symbolically at + the specified operating point. + + Parameters + ========== + + op_point : dict or iterable of dicts, optional + Point at which to solve at. The operating point is specified as + a dictionary or iterable of dictionaries of {symbol: value}. The + value may be numeric or symbolic itself. + + sol_type : str, optional + Solution return type. Valid options are: + - 'dict': A dict of {symbol : value} (default) + - 'Matrix': An ordered column matrix of the solution + """ + + # Determine number of multipliers + k = len(self.lam_vec) + if k == 0: + raise ValueError("System has no lagrange multipliers to solve for.") + # Compose dict of operating conditions + if isinstance(op_point, dict): + op_point_dict = op_point + elif iterable(op_point): + op_point_dict = {} + for op in op_point: + op_point_dict.update(op) + elif op_point is None: + op_point_dict = {} + else: + raise TypeError("op_point must be either a dictionary or an " + "iterable of dictionaries.") + # Compose the system to be solved + mass_matrix = self.mass_matrix.col_join(-self.lam_coeffs.row_join( + zeros(k, k))) + force_matrix = self.forcing.col_join(self._f_cd) + # Sub in the operating point + mass_matrix = msubs(mass_matrix, op_point_dict) + force_matrix = msubs(force_matrix, op_point_dict) + # Solve for the multipliers + sol_list = mass_matrix.LUsolve(-force_matrix)[-k:] + if sol_type == 'dict': + return dict(zip(self.lam_vec, sol_list)) + elif sol_type == 'Matrix': + return Matrix(sol_list) + else: + raise ValueError("Unknown sol_type {:}.".format(sol_type)) + + def rhs(self, inv_method=None, **kwargs): + """Returns equations that can be solved numerically. + + Parameters + ========== + + inv_method : str + The specific sympy inverse matrix calculation method to use. For a + list of valid methods, see + :meth:`~sympy.matrices.matrixbase.MatrixBase.inv` + """ + + if inv_method is None: + self._rhs = self.mass_matrix_full.LUsolve(self.forcing_full) + else: + self._rhs = (self.mass_matrix_full.inv(inv_method, + try_block_diag=True) * self.forcing_full) + return self._rhs + + @property + def q(self): + return self._q + + @property + def u(self): + return self._qdots + + @property + def bodies(self): + return self._bodies + + @property + def forcelist(self): + return self._forcelist + + @property + def loads(self): + return self._forcelist diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/linearize.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/linearize.py new file mode 100644 index 0000000000000000000000000000000000000000..b94ddb865a7236a5ac6f1a41ba96679eb8b2cd8f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/linearize.py @@ -0,0 +1,474 @@ +__all__ = ['Linearizer'] + +from sympy import Matrix, eye, zeros +from sympy.core.symbol import Dummy +from sympy.utilities.iterables import flatten +from sympy.physics.vector import dynamicsymbols +from sympy.physics.mechanics.functions import msubs, _parse_linear_solver + +from collections import namedtuple +from collections.abc import Iterable + + +class Linearizer: + """This object holds the general model form for a dynamic system. This + model is used for computing the linearized form of the system, while + properly dealing with constraints leading to dependent coordinates and + speeds. The notation and method is described in [1]_. + + Attributes + ========== + + f_0, f_1, f_2, f_3, f_4, f_c, f_v, f_a : Matrix + Matrices holding the general system form. + q, u, r : Matrix + Matrices holding the generalized coordinates, speeds, and + input vectors. + q_i, u_i : Matrix + Matrices of the independent generalized coordinates and speeds. + q_d, u_d : Matrix + Matrices of the dependent generalized coordinates and speeds. + perm_mat : Matrix + Permutation matrix such that [q_ind, u_ind]^T = perm_mat*[q, u]^T + + References + ========== + + .. [1] D. L. Peterson, G. Gede, and M. Hubbard, "Symbolic linearization of + equations of motion of constrained multibody systems," Multibody + Syst Dyn, vol. 33, no. 2, pp. 143-161, Feb. 2015, doi: + 10.1007/s11044-014-9436-5. + + """ + + def __init__(self, f_0, f_1, f_2, f_3, f_4, f_c, f_v, f_a, q, u, q_i=None, + q_d=None, u_i=None, u_d=None, r=None, lams=None, + linear_solver='LU'): + """ + Parameters + ========== + + f_0, f_1, f_2, f_3, f_4, f_c, f_v, f_a : array_like + System of equations holding the general system form. + Supply empty array or Matrix if the parameter + does not exist. + q : array_like + The generalized coordinates. + u : array_like + The generalized speeds + q_i, u_i : array_like, optional + The independent generalized coordinates and speeds. + q_d, u_d : array_like, optional + The dependent generalized coordinates and speeds. + r : array_like, optional + The input variables. + lams : array_like, optional + The lagrange multipliers + linear_solver : str, callable + Method used to solve the several symbolic linear systems of the + form ``A*x=b`` in the linearization process. If a string is + supplied, it should be a valid method that can be used with the + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`. If a callable is + supplied, it should have the format ``x = f(A, b)``, where it + solves the equations and returns the solution. The default is + ``'LU'`` which corresponds to SymPy's ``A.LUsolve(b)``. + ``LUsolve()`` is fast to compute but will often result in + divide-by-zero and thus ``nan`` results. + + """ + self.linear_solver = _parse_linear_solver(linear_solver) + + # Generalized equation form + self.f_0 = Matrix(f_0) + self.f_1 = Matrix(f_1) + self.f_2 = Matrix(f_2) + self.f_3 = Matrix(f_3) + self.f_4 = Matrix(f_4) + self.f_c = Matrix(f_c) + self.f_v = Matrix(f_v) + self.f_a = Matrix(f_a) + + # Generalized equation variables + self.q = Matrix(q) + self.u = Matrix(u) + none_handler = lambda x: Matrix(x) if x else Matrix() + self.q_i = none_handler(q_i) + self.q_d = none_handler(q_d) + self.u_i = none_handler(u_i) + self.u_d = none_handler(u_d) + self.r = none_handler(r) + self.lams = none_handler(lams) + + # Derivatives of generalized equation variables + self._qd = self.q.diff(dynamicsymbols._t) + self._ud = self.u.diff(dynamicsymbols._t) + # If the user doesn't actually use generalized variables, and the + # qd and u vectors have any intersecting variables, this can cause + # problems. We'll fix this with some hackery, and Dummy variables + dup_vars = set(self._qd).intersection(self.u) + self._qd_dup = Matrix([var if var not in dup_vars else Dummy() for var + in self._qd]) + + # Derive dimension terms + l = len(self.f_c) + m = len(self.f_v) + n = len(self.q) + o = len(self.u) + s = len(self.r) + k = len(self.lams) + dims = namedtuple('dims', ['l', 'm', 'n', 'o', 's', 'k']) + self._dims = dims(l, m, n, o, s, k) + + self._Pq = None + self._Pqi = None + self._Pqd = None + self._Pu = None + self._Pui = None + self._Pud = None + self._C_0 = None + self._C_1 = None + self._C_2 = None + self.perm_mat = None + + self._setup_done = False + + def _setup(self): + # Calculations here only need to be run once. They are moved out of + # the __init__ method to increase the speed of Linearizer creation. + self._form_permutation_matrices() + self._form_block_matrices() + self._form_coefficient_matrices() + self._setup_done = True + + def _form_permutation_matrices(self): + """Form the permutation matrices Pq and Pu.""" + + # Extract dimension variables + l, m, n, o, s, k = self._dims + # Compute permutation matrices + if n != 0: + self._Pq = permutation_matrix(self.q, Matrix([self.q_i, self.q_d])) + if l > 0: + self._Pqi = self._Pq[:, :-l] + self._Pqd = self._Pq[:, -l:] + else: + self._Pqi = self._Pq + self._Pqd = Matrix() + if o != 0: + self._Pu = permutation_matrix(self.u, Matrix([self.u_i, self.u_d])) + if m > 0: + self._Pui = self._Pu[:, :-m] + self._Pud = self._Pu[:, -m:] + else: + self._Pui = self._Pu + self._Pud = Matrix() + # Compute combination permutation matrix for computing A and B + P_col1 = Matrix([self._Pqi, zeros(o + k, n - l)]) + P_col2 = Matrix([zeros(n, o - m), self._Pui, zeros(k, o - m)]) + if P_col1: + if P_col2: + self.perm_mat = P_col1.row_join(P_col2) + else: + self.perm_mat = P_col1 + else: + self.perm_mat = P_col2 + + def _form_coefficient_matrices(self): + """Form the coefficient matrices C_0, C_1, and C_2.""" + + # Extract dimension variables + l, m, n, o, s, k = self._dims + # Build up the coefficient matrices C_0, C_1, and C_2 + # If there are configuration constraints (l > 0), form C_0 as normal. + # If not, C_0 is I_(nxn). Note that this works even if n=0 + if l > 0: + f_c_jac_q = self.f_c.jacobian(self.q) + self._C_0 = (eye(n) - self._Pqd * + self.linear_solver(f_c_jac_q*self._Pqd, + f_c_jac_q))*self._Pqi + else: + self._C_0 = eye(n) + # If there are motion constraints (m > 0), form C_1 and C_2 as normal. + # If not, C_1 is 0, and C_2 is I_(oxo). Note that this works even if + # o = 0. + if m > 0: + f_v_jac_u = self.f_v.jacobian(self.u) + temp = f_v_jac_u * self._Pud + if n != 0: + f_v_jac_q = self.f_v.jacobian(self.q) + self._C_1 = -self._Pud * self.linear_solver(temp, f_v_jac_q) + else: + self._C_1 = zeros(o, n) + self._C_2 = (eye(o) - self._Pud * + self.linear_solver(temp, f_v_jac_u))*self._Pui + else: + self._C_1 = zeros(o, n) + self._C_2 = eye(o) + + def _form_block_matrices(self): + """Form the block matrices for composing M, A, and B.""" + + # Extract dimension variables + l, m, n, o, s, k = self._dims + # Block Matrix Definitions. These are only defined if under certain + # conditions. If undefined, an empty matrix is used instead + if n != 0: + self._M_qq = self.f_0.jacobian(self._qd) + self._A_qq = -(self.f_0 + self.f_1).jacobian(self.q) + else: + self._M_qq = Matrix() + self._A_qq = Matrix() + if n != 0 and m != 0: + self._M_uqc = self.f_a.jacobian(self._qd_dup) + self._A_uqc = -self.f_a.jacobian(self.q) + else: + self._M_uqc = Matrix() + self._A_uqc = Matrix() + if n != 0 and o - m + k != 0: + self._M_uqd = self.f_3.jacobian(self._qd_dup) + self._A_uqd = -(self.f_2 + self.f_3 + self.f_4).jacobian(self.q) + else: + self._M_uqd = Matrix() + self._A_uqd = Matrix() + if o != 0 and m != 0: + self._M_uuc = self.f_a.jacobian(self._ud) + self._A_uuc = -self.f_a.jacobian(self.u) + else: + self._M_uuc = Matrix() + self._A_uuc = Matrix() + if o != 0 and o - m + k != 0: + self._M_uud = self.f_2.jacobian(self._ud) + self._A_uud = -(self.f_2 + self.f_3).jacobian(self.u) + else: + self._M_uud = Matrix() + self._A_uud = Matrix() + if o != 0 and n != 0: + self._A_qu = -self.f_1.jacobian(self.u) + else: + self._A_qu = Matrix() + if k != 0 and o - m + k != 0: + self._M_uld = self.f_4.jacobian(self.lams) + else: + self._M_uld = Matrix() + if s != 0 and o - m + k != 0: + self._B_u = -self.f_3.jacobian(self.r) + else: + self._B_u = Matrix() + + def linearize(self, op_point=None, A_and_B=False, simplify=False): + """Linearize the system about the operating point. Note that + q_op, u_op, qd_op, ud_op must satisfy the equations of motion. + These may be either symbolic or numeric. + + Parameters + ========== + op_point : dict or iterable of dicts, optional + Dictionary or iterable of dictionaries containing the operating + point conditions for all or a subset of the generalized + coordinates, generalized speeds, and time derivatives of the + generalized speeds. These will be substituted into the linearized + system before the linearization is complete. Leave set to ``None`` + if you want the operating point to be an arbitrary set of symbols. + Note that any reduction in symbols (whether substituted for numbers + or expressions with a common parameter) will result in faster + runtime. + A_and_B : bool, optional + If A_and_B=False (default), (M, A, B) is returned and of + A_and_B=True, (A, B) is returned. See below. + simplify : bool, optional + Determines if returned values are simplified before return. + For large expressions this may be time consuming. Default is False. + + Returns + ======= + M, A, B : Matrices, ``A_and_B=False`` + Matrices from the implicit form: + ``[M]*[q', u']^T = [A]*[q_ind, u_ind]^T + [B]*r`` + A, B : Matrices, ``A_and_B=True`` + Matrices from the explicit form: + ``[q_ind', u_ind']^T = [A]*[q_ind, u_ind]^T + [B]*r`` + + Notes + ===== + + Note that the process of solving with A_and_B=True is computationally + intensive if there are many symbolic parameters. For this reason, it + may be more desirable to use the default A_and_B=False, returning M, A, + and B. More values may then be substituted in to these matrices later + on. The state space form can then be found as A = P.T*M.LUsolve(A), B = + P.T*M.LUsolve(B), where P = Linearizer.perm_mat. + + """ + + # Run the setup if needed: + if not self._setup_done: + self._setup() + + # Compose dict of operating conditions + if isinstance(op_point, dict): + op_point_dict = op_point + elif isinstance(op_point, Iterable): + op_point_dict = {} + for op in op_point: + op_point_dict.update(op) + else: + op_point_dict = {} + + # Extract dimension variables + l, m, n, o, s, k = self._dims + + # Rename terms to shorten expressions + M_qq = self._M_qq + M_uqc = self._M_uqc + M_uqd = self._M_uqd + M_uuc = self._M_uuc + M_uud = self._M_uud + M_uld = self._M_uld + A_qq = self._A_qq + A_uqc = self._A_uqc + A_uqd = self._A_uqd + A_qu = self._A_qu + A_uuc = self._A_uuc + A_uud = self._A_uud + B_u = self._B_u + C_0 = self._C_0 + C_1 = self._C_1 + C_2 = self._C_2 + + # Build up Mass Matrix + # |M_qq 0_nxo 0_nxk| + # M = |M_uqc M_uuc 0_mxk| + # |M_uqd M_uud M_uld| + if o != 0: + col2 = Matrix([zeros(n, o), M_uuc, M_uud]) + if k != 0: + col3 = Matrix([zeros(n + m, k), M_uld]) + if n != 0: + col1 = Matrix([M_qq, M_uqc, M_uqd]) + if o != 0 and k != 0: + M = col1.row_join(col2).row_join(col3) + elif o != 0: + M = col1.row_join(col2) + else: + M = col1 + elif k != 0: + M = col2.row_join(col3) + else: + M = col2 + M_eq = msubs(M, op_point_dict) + + # Build up state coefficient matrix A + # |(A_qq + A_qu*C_1)*C_0 A_qu*C_2| + # A = |(A_uqc + A_uuc*C_1)*C_0 A_uuc*C_2| + # |(A_uqd + A_uud*C_1)*C_0 A_uud*C_2| + # Col 1 is only defined if n != 0 + if n != 0: + r1c1 = A_qq + if o != 0: + r1c1 += (A_qu * C_1) + r1c1 = r1c1 * C_0 + if m != 0: + r2c1 = A_uqc + if o != 0: + r2c1 += (A_uuc * C_1) + r2c1 = r2c1 * C_0 + else: + r2c1 = Matrix() + if o - m + k != 0: + r3c1 = A_uqd + if o != 0: + r3c1 += (A_uud * C_1) + r3c1 = r3c1 * C_0 + else: + r3c1 = Matrix() + col1 = Matrix([r1c1, r2c1, r3c1]) + else: + col1 = Matrix() + # Col 2 is only defined if o != 0 + if o != 0: + if n != 0: + r1c2 = A_qu * C_2 + else: + r1c2 = Matrix() + if m != 0: + r2c2 = A_uuc * C_2 + else: + r2c2 = Matrix() + if o - m + k != 0: + r3c2 = A_uud * C_2 + else: + r3c2 = Matrix() + col2 = Matrix([r1c2, r2c2, r3c2]) + else: + col2 = Matrix() + if col1: + if col2: + Amat = col1.row_join(col2) + else: + Amat = col1 + else: + Amat = col2 + Amat_eq = msubs(Amat, op_point_dict) + + # Build up the B matrix if there are forcing variables + # |0_(n + m)xs| + # B = |B_u | + if s != 0 and o - m + k != 0: + Bmat = zeros(n + m, s).col_join(B_u) + Bmat_eq = msubs(Bmat, op_point_dict) + else: + Bmat_eq = Matrix() + + # kwarg A_and_B indicates to return A, B for forming the equation + # dx = [A]x + [B]r, where x = [q_indnd, u_indnd]^T, + if A_and_B: + A_cont = self.perm_mat.T * self.linear_solver(M_eq, Amat_eq) + if Bmat_eq: + B_cont = self.perm_mat.T * self.linear_solver(M_eq, Bmat_eq) + else: + # Bmat = Matrix([]), so no need to sub + B_cont = Bmat_eq + if simplify: + A_cont.simplify() + B_cont.simplify() + return A_cont, B_cont + # Otherwise return M, A, B for forming the equation + # [M]dx = [A]x + [B]r, where x = [q, u]^T + else: + if simplify: + M_eq.simplify() + Amat_eq.simplify() + Bmat_eq.simplify() + return M_eq, Amat_eq, Bmat_eq + + +def permutation_matrix(orig_vec, per_vec): + """Compute the permutation matrix to change order of + orig_vec into order of per_vec. + + Parameters + ========== + + orig_vec : array_like + Symbols in original ordering. + per_vec : array_like + Symbols in new ordering. + + Returns + ======= + + p_matrix : Matrix + Permutation matrix such that orig_vec == (p_matrix * per_vec). + """ + if not isinstance(orig_vec, (list, tuple)): + orig_vec = flatten(orig_vec) + if not isinstance(per_vec, (list, tuple)): + per_vec = flatten(per_vec) + if set(orig_vec) != set(per_vec): + raise ValueError("orig_vec and per_vec must be the same length, " + "and contain the same symbols.") + ind_list = [orig_vec.index(i) for i in per_vec] + p_matrix = zeros(len(orig_vec)) + for i, j in enumerate(ind_list): + p_matrix[i, j] = 1 + return p_matrix diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/loads.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/loads.py new file mode 100644 index 0000000000000000000000000000000000000000..3b9db763ffd6f99905e9d17fdc07f4171de4801b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/loads.py @@ -0,0 +1,177 @@ +from abc import ABC +from collections import namedtuple +from sympy.physics.mechanics.body_base import BodyBase +from sympy.physics.vector import Vector, ReferenceFrame, Point + +__all__ = ['LoadBase', 'Force', 'Torque'] + + +class LoadBase(ABC, namedtuple('LoadBase', ['location', 'vector'])): + """Abstract base class for the various loading types.""" + + def __add__(self, other): + raise TypeError(f"unsupported operand type(s) for +: " + f"'{self.__class__.__name__}' and " + f"'{other.__class__.__name__}'") + + def __mul__(self, other): + raise TypeError(f"unsupported operand type(s) for *: " + f"'{self.__class__.__name__}' and " + f"'{other.__class__.__name__}'") + + __radd__ = __add__ + __rmul__ = __mul__ + + +class Force(LoadBase): + """Force acting upon a point. + + Explanation + =========== + + A force is a vector that is bound to a line of action. This class stores + both a point, which lies on the line of action, and the vector. A tuple can + also be used, with the location as the first entry and the vector as second + entry. + + Examples + ======== + + A force of magnitude 2 along N.x acting on a point Po can be created as + follows: + + >>> from sympy.physics.mechanics import Point, ReferenceFrame, Force + >>> N = ReferenceFrame('N') + >>> Po = Point('Po') + >>> Force(Po, 2 * N.x) + (Po, 2*N.x) + + If a body is supplied, then the center of mass of that body is used. + + >>> from sympy.physics.mechanics import Particle + >>> P = Particle('P', point=Po) + >>> Force(P, 2 * N.x) + (Po, 2*N.x) + + """ + + def __new__(cls, point, force): + if isinstance(point, BodyBase): + point = point.masscenter + if not isinstance(point, Point): + raise TypeError('Force location should be a Point.') + if not isinstance(force, Vector): + raise TypeError('Force vector should be a Vector.') + return super().__new__(cls, point, force) + + def __repr__(self): + return (f'{self.__class__.__name__}(point={self.point}, ' + f'force={self.force})') + + @property + def point(self): + return self.location + + @property + def force(self): + return self.vector + + +class Torque(LoadBase): + """Torque acting upon a frame. + + Explanation + =========== + + A torque is a free vector that is acting on a reference frame, which is + associated with a rigid body. This class stores both the frame and the + vector. A tuple can also be used, with the location as the first item and + the vector as second item. + + Examples + ======== + + A torque of magnitude 2 about N.x acting on a frame N can be created as + follows: + + >>> from sympy.physics.mechanics import ReferenceFrame, Torque + >>> N = ReferenceFrame('N') + >>> Torque(N, 2 * N.x) + (N, 2*N.x) + + If a body is supplied, then the frame fixed to that body is used. + + >>> from sympy.physics.mechanics import RigidBody + >>> rb = RigidBody('rb', frame=N) + >>> Torque(rb, 2 * N.x) + (N, 2*N.x) + + """ + + def __new__(cls, frame, torque): + if isinstance(frame, BodyBase): + frame = frame.frame + if not isinstance(frame, ReferenceFrame): + raise TypeError('Torque location should be a ReferenceFrame.') + if not isinstance(torque, Vector): + raise TypeError('Torque vector should be a Vector.') + return super().__new__(cls, frame, torque) + + def __repr__(self): + return (f'{self.__class__.__name__}(frame={self.frame}, ' + f'torque={self.torque})') + + @property + def frame(self): + return self.location + + @property + def torque(self): + return self.vector + + +def gravity(acceleration, *bodies): + """ + Returns a list of gravity forces given the acceleration + due to gravity and any number of particles or rigidbodies. + + Example + ======= + + >>> from sympy.physics.mechanics import ReferenceFrame, Particle, RigidBody + >>> from sympy.physics.mechanics.loads import gravity + >>> from sympy import symbols + >>> N = ReferenceFrame('N') + >>> g = symbols('g') + >>> P = Particle('P') + >>> B = RigidBody('B') + >>> gravity(g*N.y, P, B) + [(P_masscenter, P_mass*g*N.y), + (B_masscenter, B_mass*g*N.y)] + + """ + + gravity_force = [] + for body in bodies: + if not isinstance(body, BodyBase): + raise TypeError(f'{type(body)} is not a body type') + gravity_force.append(Force(body.masscenter, body.mass * acceleration)) + return gravity_force + + +def _parse_load(load): + """Helper function to parse loads and convert tuples to load objects.""" + if isinstance(load, LoadBase): + return load + elif isinstance(load, tuple): + if len(load) != 2: + raise ValueError(f'Load {load} should have a length of 2.') + if isinstance(load[0], Point): + return Force(load[0], load[1]) + elif isinstance(load[0], ReferenceFrame): + return Torque(load[0], load[1]) + else: + raise ValueError(f'Load not recognized. The load location {load[0]}' + f' should either be a Point or a ReferenceFrame.') + raise TypeError(f'Load type {type(load)} not recognized as a load. It ' + f'should be a Force, Torque or tuple.') diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/method.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/method.py new file mode 100644 index 0000000000000000000000000000000000000000..5c2c4a5f388e56e37bd9ecdf6daffc08ffa51070 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/method.py @@ -0,0 +1,39 @@ +from abc import ABC, abstractmethod + +class _Methods(ABC): + """Abstract Base Class for all methods.""" + + @abstractmethod + def q(self): + pass + + @abstractmethod + def u(self): + pass + + @abstractmethod + def bodies(self): + pass + + @abstractmethod + def loads(self): + pass + + @abstractmethod + def mass_matrix(self): + pass + + @abstractmethod + def forcing(self): + pass + + @abstractmethod + def mass_matrix_full(self): + pass + + @abstractmethod + def forcing_full(self): + pass + + def _form_eoms(self): + raise NotImplementedError("Subclasses must implement this.") diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/models.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/models.py new file mode 100644 index 0000000000000000000000000000000000000000..a89b929ffd540a07787f6f94714850b348c90781 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/models.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python +"""This module contains some sample symbolic models used for testing and +examples.""" + +# Internal imports +from sympy.core import backend as sm +import sympy.physics.mechanics as me + + +def multi_mass_spring_damper(n=1, apply_gravity=False, + apply_external_forces=False): + r"""Returns a system containing the symbolic equations of motion and + associated variables for a simple multi-degree of freedom point mass, + spring, damper system with optional gravitational and external + specified forces. For example, a two mass system under the influence of + gravity and external forces looks like: + + :: + + ---------------- + | | | | g + \ | | | V + k0 / --- c0 | + | | | x0, v0 + --------- V + | m0 | ----- + --------- | + | | | | + \ v | | | + k1 / f0 --- c1 | + | | | x1, v1 + --------- V + | m1 | ----- + --------- + | f1 + V + + Parameters + ========== + + n : integer + The number of masses in the serial chain. + apply_gravity : boolean + If true, gravity will be applied to each mass. + apply_external_forces : boolean + If true, a time varying external force will be applied to each mass. + + Returns + ======= + + kane : sympy.physics.mechanics.kane.KanesMethod + A KanesMethod object. + + """ + + mass = sm.symbols('m:{}'.format(n)) + stiffness = sm.symbols('k:{}'.format(n)) + damping = sm.symbols('c:{}'.format(n)) + + acceleration_due_to_gravity = sm.symbols('g') + + coordinates = me.dynamicsymbols('x:{}'.format(n)) + speeds = me.dynamicsymbols('v:{}'.format(n)) + specifieds = me.dynamicsymbols('f:{}'.format(n)) + + ceiling = me.ReferenceFrame('N') + origin = me.Point('origin') + origin.set_vel(ceiling, 0) + + points = [origin] + kinematic_equations = [] + particles = [] + forces = [] + + for i in range(n): + + center = points[-1].locatenew('center{}'.format(i), + coordinates[i] * ceiling.x) + center.set_vel(ceiling, points[-1].vel(ceiling) + + speeds[i] * ceiling.x) + points.append(center) + + block = me.Particle('block{}'.format(i), center, mass[i]) + + kinematic_equations.append(speeds[i] - coordinates[i].diff()) + + total_force = (-stiffness[i] * coordinates[i] - + damping[i] * speeds[i]) + try: + total_force += (stiffness[i + 1] * coordinates[i + 1] + + damping[i + 1] * speeds[i + 1]) + except IndexError: # no force from below on last mass + pass + + if apply_gravity: + total_force += mass[i] * acceleration_due_to_gravity + + if apply_external_forces: + total_force += specifieds[i] + + forces.append((center, total_force * ceiling.x)) + + particles.append(block) + + kane = me.KanesMethod(ceiling, q_ind=coordinates, u_ind=speeds, + kd_eqs=kinematic_equations) + kane.kanes_equations(particles, forces) + + return kane + + +def n_link_pendulum_on_cart(n=1, cart_force=True, joint_torques=False): + r"""Returns the system containing the symbolic first order equations of + motion for a 2D n-link pendulum on a sliding cart under the influence of + gravity. + + :: + + | + o y v + \ 0 ^ g + \ | + --\-|---- + | \| | + F-> | o --|---> x + | | + --------- + o o + + Parameters + ========== + + n : integer + The number of links in the pendulum. + cart_force : boolean, default=True + If true an external specified lateral force is applied to the cart. + joint_torques : boolean, default=False + If true joint torques will be added as specified inputs at each + joint. + + Returns + ======= + + kane : sympy.physics.mechanics.kane.KanesMethod + A KanesMethod object. + + Notes + ===== + + The degrees of freedom of the system are n + 1, i.e. one for each + pendulum link and one for the lateral motion of the cart. + + M x' = F, where x = [u0, ..., un+1, q0, ..., qn+1] + + The joint angles are all defined relative to the ground where the x axis + defines the ground line and the y axis points up. The joint torques are + applied between each adjacent link and the between the cart and the + lower link where a positive torque corresponds to positive angle. + + """ + if n <= 0: + raise ValueError('The number of links must be a positive integer.') + + q = me.dynamicsymbols('q:{}'.format(n + 1)) + u = me.dynamicsymbols('u:{}'.format(n + 1)) + + if joint_torques is True: + T = me.dynamicsymbols('T1:{}'.format(n + 1)) + + m = sm.symbols('m:{}'.format(n + 1)) + l = sm.symbols('l:{}'.format(n)) + g, t = sm.symbols('g t') + + I = me.ReferenceFrame('I') + O = me.Point('O') + O.set_vel(I, 0) + + P0 = me.Point('P0') + P0.set_pos(O, q[0] * I.x) + P0.set_vel(I, u[0] * I.x) + Pa0 = me.Particle('Pa0', P0, m[0]) + + frames = [I] + points = [P0] + particles = [Pa0] + forces = [(P0, -m[0] * g * I.y)] + kindiffs = [q[0].diff(t) - u[0]] + + if cart_force is True or joint_torques is True: + specified = [] + else: + specified = None + + for i in range(n): + Bi = I.orientnew('B{}'.format(i), 'Axis', [q[i + 1], I.z]) + Bi.set_ang_vel(I, u[i + 1] * I.z) + frames.append(Bi) + + Pi = points[-1].locatenew('P{}'.format(i + 1), l[i] * Bi.y) + Pi.v2pt_theory(points[-1], I, Bi) + points.append(Pi) + + Pai = me.Particle('Pa' + str(i + 1), Pi, m[i + 1]) + particles.append(Pai) + + forces.append((Pi, -m[i + 1] * g * I.y)) + + if joint_torques is True: + + specified.append(T[i]) + + if i == 0: + forces.append((I, -T[i] * I.z)) + + if i == n - 1: + forces.append((Bi, T[i] * I.z)) + else: + forces.append((Bi, T[i] * I.z - T[i + 1] * I.z)) + + kindiffs.append(q[i + 1].diff(t) - u[i + 1]) + + if cart_force is True: + F = me.dynamicsymbols('F') + forces.append((P0, F * I.x)) + specified.append(F) + + kane = me.KanesMethod(I, q_ind=q, u_ind=u, kd_eqs=kindiffs) + kane.kanes_equations(particles, forces) + + return kane diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/particle.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/particle.py new file mode 100644 index 0000000000000000000000000000000000000000..5d49d4f811b8d1c7fff16c71991f5e01da6ded02 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/particle.py @@ -0,0 +1,209 @@ +from sympy import S +from sympy.physics.vector import cross, dot +from sympy.physics.mechanics.body_base import BodyBase +from sympy.physics.mechanics.inertia import inertia_of_point_mass +from sympy.utilities.exceptions import sympy_deprecation_warning + +__all__ = ['Particle'] + + +class Particle(BodyBase): + """A particle. + + Explanation + =========== + + Particles have a non-zero mass and lack spatial extension; they take up no + space. + + Values need to be supplied on initialization, but can be changed later. + + Parameters + ========== + + name : str + Name of particle + point : Point + A physics/mechanics Point which represents the position, velocity, and + acceleration of this Particle + mass : Sympifyable + A SymPy expression representing the Particle's mass + potential_energy : Sympifyable + The potential energy of the Particle. + + Examples + ======== + + >>> from sympy.physics.mechanics import Particle, Point + >>> from sympy import Symbol + >>> po = Point('po') + >>> m = Symbol('m') + >>> pa = Particle('pa', po, m) + >>> # Or you could change these later + >>> pa.mass = m + >>> pa.point = po + + """ + point = BodyBase.masscenter + + def __init__(self, name, point=None, mass=None): + super().__init__(name, point, mass) + + def linear_momentum(self, frame): + """Linear momentum of the particle. + + Explanation + =========== + + The linear momentum L, of a particle P, with respect to frame N is + given by: + + L = m * v + + where m is the mass of the particle, and v is the velocity of the + particle in the frame N. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which linear momentum is desired. + + Examples + ======== + + >>> from sympy.physics.mechanics import Particle, Point, ReferenceFrame + >>> from sympy.physics.mechanics import dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> m, v = dynamicsymbols('m v') + >>> N = ReferenceFrame('N') + >>> P = Point('P') + >>> A = Particle('A', P, m) + >>> P.set_vel(N, v * N.x) + >>> A.linear_momentum(N) + m*v*N.x + + """ + + return self.mass * self.point.vel(frame) + + def angular_momentum(self, point, frame): + """Angular momentum of the particle about the point. + + Explanation + =========== + + The angular momentum H, about some point O of a particle, P, is given + by: + + ``H = cross(r, m * v)`` + + where r is the position vector from point O to the particle P, m is + the mass of the particle, and v is the velocity of the particle in + the inertial frame, N. + + Parameters + ========== + + point : Point + The point about which angular momentum of the particle is desired. + + frame : ReferenceFrame + The frame in which angular momentum is desired. + + Examples + ======== + + >>> from sympy.physics.mechanics import Particle, Point, ReferenceFrame + >>> from sympy.physics.mechanics import dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> m, v, r = dynamicsymbols('m v r') + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> A = O.locatenew('A', r * N.x) + >>> P = Particle('P', A, m) + >>> P.point.set_vel(N, v * N.y) + >>> P.angular_momentum(O, N) + m*r*v*N.z + + """ + + return cross(self.point.pos_from(point), + self.mass * self.point.vel(frame)) + + def kinetic_energy(self, frame): + """Kinetic energy of the particle. + + Explanation + =========== + + The kinetic energy, T, of a particle, P, is given by: + + ``T = 1/2 (dot(m * v, v))`` + + where m is the mass of particle P, and v is the velocity of the + particle in the supplied ReferenceFrame. + + Parameters + ========== + + frame : ReferenceFrame + The Particle's velocity is typically defined with respect to + an inertial frame but any relevant frame in which the velocity is + known can be supplied. + + Examples + ======== + + >>> from sympy.physics.mechanics import Particle, Point, ReferenceFrame + >>> from sympy import symbols + >>> m, v, r = symbols('m v r') + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> P = Particle('P', O, m) + >>> P.point.set_vel(N, v * N.y) + >>> P.kinetic_energy(N) + m*v**2/2 + + """ + + return S.Half * self.mass * dot(self.point.vel(frame), + self.point.vel(frame)) + + def set_potential_energy(self, scalar): + sympy_deprecation_warning( + """ +The sympy.physics.mechanics.Particle.set_potential_energy() +method is deprecated. Instead use + + P.potential_energy = scalar + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-set-potential-energy", + ) + self.potential_energy = scalar + + def parallel_axis(self, point, frame): + """Returns an inertia dyadic of the particle with respect to another + point and frame. + + Parameters + ========== + + point : sympy.physics.vector.Point + The point to express the inertia dyadic about. + frame : sympy.physics.vector.ReferenceFrame + The reference frame used to construct the dyadic. + + Returns + ======= + + inertia : sympy.physics.vector.Dyadic + The inertia dyadic of the particle expressed about the provided + point and frame. + + """ + return inertia_of_point_mass(self.mass, self.point.pos_from(point), + frame) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/pathway.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/pathway.py new file mode 100644 index 0000000000000000000000000000000000000000..b86ba85b1d9d1434c51de3fd7cc429442fdbedb0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/pathway.py @@ -0,0 +1,688 @@ +"""Implementations of pathways for use by actuators.""" + +from abc import ABC, abstractmethod + +from sympy.core.singleton import S +from sympy.physics.mechanics.loads import Force +from sympy.physics.mechanics.wrapping_geometry import WrappingGeometryBase +from sympy.physics.vector import Point, dynamicsymbols + + +__all__ = ['PathwayBase', 'LinearPathway', 'ObstacleSetPathway', + 'WrappingPathway'] + + +class PathwayBase(ABC): + """Abstract base class for all pathway classes to inherit from. + + Notes + ===== + + Instances of this class cannot be directly instantiated by users. However, + it can be used to created custom pathway types through subclassing. + + """ + + def __init__(self, *attachments): + """Initializer for ``PathwayBase``.""" + self.attachments = attachments + + @property + def attachments(self): + """The pair of points defining a pathway's ends.""" + return self._attachments + + @attachments.setter + def attachments(self, attachments): + if hasattr(self, '_attachments'): + msg = ( + f'Can\'t set attribute `attachments` to {repr(attachments)} ' + f'as it is immutable.' + ) + raise AttributeError(msg) + if len(attachments) != 2: + msg = ( + f'Value {repr(attachments)} passed to `attachments` was an ' + f'iterable of length {len(attachments)}, must be an iterable ' + f'of length 2.' + ) + raise ValueError(msg) + for i, point in enumerate(attachments): + if not isinstance(point, Point): + msg = ( + f'Value {repr(point)} passed to `attachments` at index ' + f'{i} was of type {type(point)}, must be {Point}.' + ) + raise TypeError(msg) + self._attachments = tuple(attachments) + + @property + @abstractmethod + def length(self): + """An expression representing the pathway's length.""" + pass + + @property + @abstractmethod + def extension_velocity(self): + """An expression representing the pathway's extension velocity.""" + pass + + @abstractmethod + def to_loads(self, force): + """Loads required by the equations of motion method classes. + + Explanation + =========== + + ``KanesMethod`` requires a list of ``Point``-``Vector`` tuples to be + passed to the ``loads`` parameters of its ``kanes_equations`` method + when constructing the equations of motion. This method acts as a + utility to produce the correctly-structred pairs of points and vectors + required so that these can be easily concatenated with other items in + the list of loads and passed to ``KanesMethod.kanes_equations``. These + loads are also in the correct form to also be passed to the other + equations of motion method classes, e.g. ``LagrangesMethod``. + + """ + pass + + def __repr__(self): + """Default representation of a pathway.""" + attachments = ', '.join(str(a) for a in self.attachments) + return f'{self.__class__.__name__}({attachments})' + + +class LinearPathway(PathwayBase): + """Linear pathway between a pair of attachment points. + + Explanation + =========== + + A linear pathway forms a straight-line segment between two points and is + the simplest pathway that can be formed. It will not interact with any + other objects in the system, i.e. a ``LinearPathway`` will intersect other + objects to ensure that the path between its two ends (its attachments) is + the shortest possible. + + A linear pathway is made up of two points that can move relative to each + other, and a pair of equal and opposite forces acting on the points. If the + positive time-varying Euclidean distance between the two points is defined, + then the "extension velocity" is the time derivative of this distance. The + extension velocity is positive when the two points are moving away from + each other and negative when moving closer to each other. The direction for + the force acting on either point is determined by constructing a unit + vector directed from the other point to this point. This establishes a sign + convention such that a positive force magnitude tends to push the points + apart. The following diagram shows the positive force sense and the + distance between the points:: + + P Q + o<--- F --->o + | | + |<--l(t)--->| + + Examples + ======== + + >>> from sympy.physics.mechanics import LinearPathway + + To construct a pathway, two points are required to be passed to the + ``attachments`` parameter as a ``tuple``. + + >>> from sympy.physics.mechanics import Point + >>> pA, pB = Point('pA'), Point('pB') + >>> linear_pathway = LinearPathway(pA, pB) + >>> linear_pathway + LinearPathway(pA, pB) + + The pathway created above isn't very interesting without the positions and + velocities of its attachment points being described. Without this its not + possible to describe how the pathway moves, i.e. its length or its + extension velocity. + + >>> from sympy.physics.mechanics import ReferenceFrame + >>> from sympy.physics.vector import dynamicsymbols + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> pB.set_pos(pA, q*N.x) + >>> pB.pos_from(pA) + q(t)*N.x + + A pathway's length can be accessed via its ``length`` attribute. + + >>> linear_pathway.length + sqrt(q(t)**2) + + Note how what appears to be an overly-complex expression is returned. This + is actually required as it ensures that a pathway's length is always + positive. + + A pathway's extension velocity can be accessed similarly via its + ``extension_velocity`` attribute. + + >>> linear_pathway.extension_velocity + sqrt(q(t)**2)*Derivative(q(t), t)/q(t) + + Parameters + ========== + + attachments : tuple[Point, Point] + Pair of ``Point`` objects between which the linear pathway spans. + Constructor expects two points to be passed, e.g. + ``LinearPathway(Point('pA'), Point('pB'))``. More or fewer points will + cause an error to be thrown. + + """ + + def __init__(self, *attachments): + """Initializer for ``LinearPathway``. + + Parameters + ========== + + attachments : Point + Pair of ``Point`` objects between which the linear pathway spans. + Constructor expects two points to be passed, e.g. + ``LinearPathway(Point('pA'), Point('pB'))``. More or fewer points + will cause an error to be thrown. + + """ + super().__init__(*attachments) + + @property + def length(self): + """Exact analytical expression for the pathway's length.""" + return _point_pair_length(*self.attachments) + + @property + def extension_velocity(self): + """Exact analytical expression for the pathway's extension velocity.""" + return _point_pair_extension_velocity(*self.attachments) + + def to_loads(self, force): + """Loads required by the equations of motion method classes. + + Explanation + =========== + + ``KanesMethod`` requires a list of ``Point``-``Vector`` tuples to be + passed to the ``loads`` parameters of its ``kanes_equations`` method + when constructing the equations of motion. This method acts as a + utility to produce the correctly-structred pairs of points and vectors + required so that these can be easily concatenated with other items in + the list of loads and passed to ``KanesMethod.kanes_equations``. These + loads are also in the correct form to also be passed to the other + equations of motion method classes, e.g. ``LagrangesMethod``. + + Examples + ======== + + The below example shows how to generate the loads produced in a linear + actuator that produces an expansile force ``F``. First, create a linear + actuator between two points separated by the coordinate ``q`` in the + ``x`` direction of the global frame ``N``. + + >>> from sympy.physics.mechanics import (LinearPathway, Point, + ... ReferenceFrame) + >>> from sympy.physics.vector import dynamicsymbols + >>> q = dynamicsymbols('q') + >>> N = ReferenceFrame('N') + >>> pA, pB = Point('pA'), Point('pB') + >>> pB.set_pos(pA, q*N.x) + >>> linear_pathway = LinearPathway(pA, pB) + + Now create a symbol ``F`` to describe the magnitude of the (expansile) + force that will be produced along the pathway. The list of loads that + ``KanesMethod`` requires can be produced by calling the pathway's + ``to_loads`` method with ``F`` passed as the only argument. + + >>> from sympy import symbols + >>> F = symbols('F') + >>> linear_pathway.to_loads(F) + [(pA, - F*q(t)/sqrt(q(t)**2)*N.x), (pB, F*q(t)/sqrt(q(t)**2)*N.x)] + + Parameters + ========== + + force : Expr + Magnitude of the force acting along the length of the pathway. As + per the sign conventions for the pathway length, pathway extension + velocity, and pair of point forces, if this ``Expr`` is positive + then the force will act to push the pair of points away from one + another (it is expansile). + + """ + relative_position = _point_pair_relative_position(*self.attachments) + loads = [ + Force(self.attachments[0], -force*relative_position/self.length), + Force(self.attachments[-1], force*relative_position/self.length), + ] + return loads + + +class ObstacleSetPathway(PathwayBase): + """Obstacle-set pathway between a set of attachment points. + + Explanation + =========== + + An obstacle-set pathway forms a series of straight-line segment between + pairs of consecutive points in a set of points. It is similar to multiple + linear pathways joined end-to-end. It will not interact with any other + objects in the system, i.e. an ``ObstacleSetPathway`` will intersect other + objects to ensure that the path between its pairs of points (its + attachments) is the shortest possible. + + Examples + ======== + + To construct an obstacle-set pathway, three or more points are required to + be passed to the ``attachments`` parameter as a ``tuple``. + + >>> from sympy.physics.mechanics import ObstacleSetPathway, Point + >>> pA, pB, pC, pD = Point('pA'), Point('pB'), Point('pC'), Point('pD') + >>> obstacle_set_pathway = ObstacleSetPathway(pA, pB, pC, pD) + >>> obstacle_set_pathway + ObstacleSetPathway(pA, pB, pC, pD) + + The pathway created above isn't very interesting without the positions and + velocities of its attachment points being described. Without this its not + possible to describe how the pathway moves, i.e. its length or its + extension velocity. + + >>> from sympy import cos, sin + >>> from sympy.physics.mechanics import ReferenceFrame + >>> from sympy.physics.vector import dynamicsymbols + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> pO = Point('pO') + >>> pA.set_pos(pO, N.y) + >>> pB.set_pos(pO, -N.x) + >>> pC.set_pos(pA, cos(q) * N.x - (sin(q) + 1) * N.y) + >>> pD.set_pos(pA, sin(q) * N.x + (cos(q) - 1) * N.y) + >>> pB.pos_from(pA) + - N.x - N.y + >>> pC.pos_from(pA) + cos(q(t))*N.x + (-sin(q(t)) - 1)*N.y + >>> pD.pos_from(pA) + sin(q(t))*N.x + (cos(q(t)) - 1)*N.y + + A pathway's length can be accessed via its ``length`` attribute. + + >>> obstacle_set_pathway.length.simplify() + sqrt(2)*(sqrt(cos(q(t)) + 1) + 2) + + A pathway's extension velocity can be accessed similarly via its + ``extension_velocity`` attribute. + + >>> obstacle_set_pathway.extension_velocity.simplify() + -sqrt(2)*sin(q(t))*Derivative(q(t), t)/(2*sqrt(cos(q(t)) + 1)) + + Parameters + ========== + + attachments : tuple[Point, ...] + The set of ``Point`` objects that define the segmented obstacle-set + pathway. + + """ + + def __init__(self, *attachments): + """Initializer for ``ObstacleSetPathway``. + + Parameters + ========== + + attachments : tuple[Point, ...] + The set of ``Point`` objects that define the segmented obstacle-set + pathway. + + """ + super().__init__(*attachments) + + @property + def attachments(self): + """The set of points defining a pathway's segmented path.""" + return self._attachments + + @attachments.setter + def attachments(self, attachments): + if hasattr(self, '_attachments'): + msg = ( + f'Can\'t set attribute `attachments` to {repr(attachments)} ' + f'as it is immutable.' + ) + raise AttributeError(msg) + if len(attachments) <= 2: + msg = ( + f'Value {repr(attachments)} passed to `attachments` was an ' + f'iterable of length {len(attachments)}, must be an iterable ' + f'of length 3 or greater.' + ) + raise ValueError(msg) + for i, point in enumerate(attachments): + if not isinstance(point, Point): + msg = ( + f'Value {repr(point)} passed to `attachments` at index ' + f'{i} was of type {type(point)}, must be {Point}.' + ) + raise TypeError(msg) + self._attachments = tuple(attachments) + + @property + def length(self): + """Exact analytical expression for the pathway's length.""" + length = S.Zero + attachment_pairs = zip(self.attachments[:-1], self.attachments[1:]) + for attachment_pair in attachment_pairs: + length += _point_pair_length(*attachment_pair) + return length + + @property + def extension_velocity(self): + """Exact analytical expression for the pathway's extension velocity.""" + extension_velocity = S.Zero + attachment_pairs = zip(self.attachments[:-1], self.attachments[1:]) + for attachment_pair in attachment_pairs: + extension_velocity += _point_pair_extension_velocity(*attachment_pair) + return extension_velocity + + def to_loads(self, force): + """Loads required by the equations of motion method classes. + + Explanation + =========== + + ``KanesMethod`` requires a list of ``Point``-``Vector`` tuples to be + passed to the ``loads`` parameters of its ``kanes_equations`` method + when constructing the equations of motion. This method acts as a + utility to produce the correctly-structred pairs of points and vectors + required so that these can be easily concatenated with other items in + the list of loads and passed to ``KanesMethod.kanes_equations``. These + loads are also in the correct form to also be passed to the other + equations of motion method classes, e.g. ``LagrangesMethod``. + + Examples + ======== + + The below example shows how to generate the loads produced in an + actuator that follows an obstacle-set pathway between four points and + produces an expansile force ``F``. First, create a pair of reference + frames, ``A`` and ``B``, in which the four points ``pA``, ``pB``, + ``pC``, and ``pD`` will be located. The first two points in frame ``A`` + and the second two in frame ``B``. Frame ``B`` will also be oriented + such that it relates to ``A`` via a rotation of ``q`` about an axis + ``N.z`` in a global frame (``N.z``, ``A.z``, and ``B.z`` are parallel). + + >>> from sympy.physics.mechanics import (ObstacleSetPathway, Point, + ... ReferenceFrame) + >>> from sympy.physics.vector import dynamicsymbols + >>> q = dynamicsymbols('q') + >>> N = ReferenceFrame('N') + >>> N = ReferenceFrame('N') + >>> A = N.orientnew('A', 'axis', (0, N.x)) + >>> B = A.orientnew('B', 'axis', (q, N.z)) + >>> pO = Point('pO') + >>> pA, pB, pC, pD = Point('pA'), Point('pB'), Point('pC'), Point('pD') + >>> pA.set_pos(pO, A.x) + >>> pB.set_pos(pO, -A.y) + >>> pC.set_pos(pO, B.y) + >>> pD.set_pos(pO, B.x) + >>> obstacle_set_pathway = ObstacleSetPathway(pA, pB, pC, pD) + + Now create a symbol ``F`` to describe the magnitude of the (expansile) + force that will be produced along the pathway. The list of loads that + ``KanesMethod`` requires can be produced by calling the pathway's + ``to_loads`` method with ``F`` passed as the only argument. + + >>> from sympy import Symbol + >>> F = Symbol('F') + >>> obstacle_set_pathway.to_loads(F) + [(pA, sqrt(2)*F/2*A.x + sqrt(2)*F/2*A.y), + (pB, - sqrt(2)*F/2*A.x - sqrt(2)*F/2*A.y), + (pB, - F/sqrt(2*cos(q(t)) + 2)*A.y - F/sqrt(2*cos(q(t)) + 2)*B.y), + (pC, F/sqrt(2*cos(q(t)) + 2)*A.y + F/sqrt(2*cos(q(t)) + 2)*B.y), + (pC, - sqrt(2)*F/2*B.x + sqrt(2)*F/2*B.y), + (pD, sqrt(2)*F/2*B.x - sqrt(2)*F/2*B.y)] + + Parameters + ========== + + force : Expr + The force acting along the length of the pathway. It is assumed + that this ``Expr`` represents an expansile force. + + """ + loads = [] + attachment_pairs = zip(self.attachments[:-1], self.attachments[1:]) + for attachment_pair in attachment_pairs: + relative_position = _point_pair_relative_position(*attachment_pair) + length = _point_pair_length(*attachment_pair) + loads.extend([ + Force(attachment_pair[0], -force*relative_position/length), + Force(attachment_pair[1], force*relative_position/length), + ]) + return loads + + +class WrappingPathway(PathwayBase): + """Pathway that wraps a geometry object. + + Explanation + =========== + + A wrapping pathway interacts with a geometry object and forms a path that + wraps smoothly along its surface. The wrapping pathway along the geometry + object will be the geodesic that the geometry object defines based on the + two points. It will not interact with any other objects in the system, i.e. + a ``WrappingPathway`` will intersect other objects to ensure that the path + between its two ends (its attachments) is the shortest possible. + + To explain the sign conventions used for pathway length, extension + velocity, and direction of applied forces, we can ignore the geometry with + which the wrapping pathway interacts. A wrapping pathway is made up of two + points that can move relative to each other, and a pair of equal and + opposite forces acting on the points. If the positive time-varying + Euclidean distance between the two points is defined, then the "extension + velocity" is the time derivative of this distance. The extension velocity + is positive when the two points are moving away from each other and + negative when moving closer to each other. The direction for the force + acting on either point is determined by constructing a unit vector directed + from the other point to this point. This establishes a sign convention such + that a positive force magnitude tends to push the points apart. The + following diagram shows the positive force sense and the distance between + the points:: + + P Q + o<--- F --->o + | | + |<--l(t)--->| + + Examples + ======== + + >>> from sympy.physics.mechanics import WrappingPathway + + To construct a wrapping pathway, like other pathways, a pair of points must + be passed, followed by an instance of a wrapping geometry class as a + keyword argument. We'll use a cylinder with radius ``r`` and its axis + parallel to ``N.x`` passing through a point ``pO``. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import Point, ReferenceFrame, WrappingCylinder + >>> r = symbols('r') + >>> N = ReferenceFrame('N') + >>> pA, pB, pO = Point('pA'), Point('pB'), Point('pO') + >>> cylinder = WrappingCylinder(r, pO, N.x) + >>> wrapping_pathway = WrappingPathway(pA, pB, cylinder) + >>> wrapping_pathway + WrappingPathway(pA, pB, geometry=WrappingCylinder(radius=r, point=pO, + axis=N.x)) + + Parameters + ========== + + attachment_1 : Point + First of the pair of ``Point`` objects between which the wrapping + pathway spans. + attachment_2 : Point + Second of the pair of ``Point`` objects between which the wrapping + pathway spans. + geometry : WrappingGeometryBase + Geometry about which the pathway wraps. + + """ + + def __init__(self, attachment_1, attachment_2, geometry): + """Initializer for ``WrappingPathway``. + + Parameters + ========== + + attachment_1 : Point + First of the pair of ``Point`` objects between which the wrapping + pathway spans. + attachment_2 : Point + Second of the pair of ``Point`` objects between which the wrapping + pathway spans. + geometry : WrappingGeometryBase + Geometry about which the pathway wraps. + The geometry about which the pathway wraps. + + """ + super().__init__(attachment_1, attachment_2) + self.geometry = geometry + + @property + def geometry(self): + """Geometry around which the pathway wraps.""" + return self._geometry + + @geometry.setter + def geometry(self, geometry): + if hasattr(self, '_geometry'): + msg = ( + f'Can\'t set attribute `geometry` to {repr(geometry)} as it ' + f'is immutable.' + ) + raise AttributeError(msg) + if not isinstance(geometry, WrappingGeometryBase): + msg = ( + f'Value {repr(geometry)} passed to `geometry` was of type ' + f'{type(geometry)}, must be {WrappingGeometryBase}.' + ) + raise TypeError(msg) + self._geometry = geometry + + @property + def length(self): + """Exact analytical expression for the pathway's length.""" + return self.geometry.geodesic_length(*self.attachments) + + @property + def extension_velocity(self): + """Exact analytical expression for the pathway's extension velocity.""" + return self.length.diff(dynamicsymbols._t) + + def to_loads(self, force): + """Loads required by the equations of motion method classes. + + Explanation + =========== + + ``KanesMethod`` requires a list of ``Point``-``Vector`` tuples to be + passed to the ``loads`` parameters of its ``kanes_equations`` method + when constructing the equations of motion. This method acts as a + utility to produce the correctly-structred pairs of points and vectors + required so that these can be easily concatenated with other items in + the list of loads and passed to ``KanesMethod.kanes_equations``. These + loads are also in the correct form to also be passed to the other + equations of motion method classes, e.g. ``LagrangesMethod``. + + Examples + ======== + + The below example shows how to generate the loads produced in an + actuator that produces an expansile force ``F`` while wrapping around a + cylinder. First, create a cylinder with radius ``r`` and an axis + parallel to the ``N.z`` direction of the global frame ``N`` that also + passes through a point ``pO``. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (Point, ReferenceFrame, + ... WrappingCylinder) + >>> N = ReferenceFrame('N') + >>> r = symbols('r', positive=True) + >>> pO = Point('pO') + >>> cylinder = WrappingCylinder(r, pO, N.z) + + Create the pathway of the actuator using the ``WrappingPathway`` class, + defined to span between two points ``pA`` and ``pB``. Both points lie + on the surface of the cylinder and the location of ``pB`` is defined + relative to ``pA`` by the dynamics symbol ``q``. + + >>> from sympy import cos, sin + >>> from sympy.physics.mechanics import WrappingPathway, dynamicsymbols + >>> q = dynamicsymbols('q') + >>> pA = Point('pA') + >>> pB = Point('pB') + >>> pA.set_pos(pO, r*N.x) + >>> pB.set_pos(pO, r*(cos(q)*N.x + sin(q)*N.y)) + >>> pB.pos_from(pA) + (r*cos(q(t)) - r)*N.x + r*sin(q(t))*N.y + >>> pathway = WrappingPathway(pA, pB, cylinder) + + Now create a symbol ``F`` to describe the magnitude of the (expansile) + force that will be produced along the pathway. The list of loads that + ``KanesMethod`` requires can be produced by calling the pathway's + ``to_loads`` method with ``F`` passed as the only argument. + + >>> F = symbols('F') + >>> loads = pathway.to_loads(F) + >>> [load.__class__(load.location, load.vector.simplify()) for load in loads] + [(pA, F*N.y), (pB, F*sin(q(t))*N.x - F*cos(q(t))*N.y), + (pO, - F*sin(q(t))*N.x + F*(cos(q(t)) - 1)*N.y)] + + Parameters + ========== + + force : Expr + Magnitude of the force acting along the length of the pathway. It + is assumed that this ``Expr`` represents an expansile force. + + """ + pA, pB = self.attachments + pO = self.geometry.point + pA_force, pB_force = self.geometry.geodesic_end_vectors(pA, pB) + pO_force = -(pA_force + pB_force) + + loads = [ + Force(pA, force * pA_force), + Force(pB, force * pB_force), + Force(pO, force * pO_force), + ] + return loads + + def __repr__(self): + """Representation of a ``WrappingPathway``.""" + attachments = ', '.join(str(a) for a in self.attachments) + return ( + f'{self.__class__.__name__}({attachments}, ' + f'geometry={self.geometry})' + ) + + +def _point_pair_relative_position(point_1, point_2): + """The relative position between a pair of points.""" + return point_2.pos_from(point_1) + + +def _point_pair_length(point_1, point_2): + """The length of the direct linear path between two points.""" + return _point_pair_relative_position(point_1, point_2).magnitude() + + +def _point_pair_extension_velocity(point_1, point_2): + """The extension velocity of the direct linear path between two points.""" + return _point_pair_length(point_1, point_2).diff(dynamicsymbols._t) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/rigidbody.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/rigidbody.py new file mode 100644 index 0000000000000000000000000000000000000000..7cc61ff468f7f26d98209a48ca59ffa12a570490 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/rigidbody.py @@ -0,0 +1,314 @@ +from sympy import Symbol, S +from sympy.physics.vector import ReferenceFrame, Dyadic, Point, dot +from sympy.physics.mechanics.body_base import BodyBase +from sympy.physics.mechanics.inertia import inertia_of_point_mass, Inertia +from sympy.utilities.exceptions import sympy_deprecation_warning + +__all__ = ['RigidBody'] + + +class RigidBody(BodyBase): + """An idealized rigid body. + + Explanation + =========== + + This is essentially a container which holds the various components which + describe a rigid body: a name, mass, center of mass, reference frame, and + inertia. + + All of these need to be supplied on creation, but can be changed + afterwards. + + Attributes + ========== + + name : string + The body's name. + masscenter : Point + The point which represents the center of mass of the rigid body. + frame : ReferenceFrame + The ReferenceFrame which the rigid body is fixed in. + mass : Sympifyable + The body's mass. + inertia : (Dyadic, Point) + The body's inertia about a point; stored in a tuple as shown above. + potential_energy : Sympifyable + The potential energy of the RigidBody. + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.mechanics import ReferenceFrame, Point, RigidBody + >>> from sympy.physics.mechanics import outer + >>> m = Symbol('m') + >>> A = ReferenceFrame('A') + >>> P = Point('P') + >>> I = outer (A.x, A.x) + >>> inertia_tuple = (I, P) + >>> B = RigidBody('B', P, A, m, inertia_tuple) + >>> # Or you could change them afterwards + >>> m2 = Symbol('m2') + >>> B.mass = m2 + + """ + + def __init__(self, name, masscenter=None, frame=None, mass=None, + inertia=None): + super().__init__(name, masscenter, mass) + if frame is None: + frame = ReferenceFrame(f'{name}_frame') + self.frame = frame + if inertia is None: + ixx = Symbol(f'{name}_ixx') + iyy = Symbol(f'{name}_iyy') + izz = Symbol(f'{name}_izz') + izx = Symbol(f'{name}_izx') + ixy = Symbol(f'{name}_ixy') + iyz = Symbol(f'{name}_iyz') + inertia = Inertia.from_inertia_scalars(self.masscenter, self.frame, + ixx, iyy, izz, ixy, iyz, izx) + self.inertia = inertia + + def __repr__(self): + return (f'{self.__class__.__name__}({repr(self.name)}, masscenter=' + f'{repr(self.masscenter)}, frame={repr(self.frame)}, mass=' + f'{repr(self.mass)}, inertia={repr(self.inertia)})') + + @property + def frame(self): + """The ReferenceFrame fixed to the body.""" + return self._frame + + @frame.setter + def frame(self, F): + if not isinstance(F, ReferenceFrame): + raise TypeError("RigidBody frame must be a ReferenceFrame object.") + self._frame = F + + @property + def x(self): + """The basis Vector for the body, in the x direction. """ + return self.frame.x + + @property + def y(self): + """The basis Vector for the body, in the y direction. """ + return self.frame.y + + @property + def z(self): + """The basis Vector for the body, in the z direction. """ + return self.frame.z + + @property + def inertia(self): + """The body's inertia about a point; stored as (Dyadic, Point).""" + return self._inertia + + @inertia.setter + def inertia(self, I): + # check if I is of the form (Dyadic, Point) + if len(I) != 2 or not isinstance(I[0], Dyadic) or not isinstance(I[1], Point): + raise TypeError("RigidBody inertia must be a tuple of the form (Dyadic, Point).") + + self._inertia = Inertia(I[0], I[1]) + # have I S/O, want I S/S* + # I S/O = I S/S* + I S*/O; I S/S* = I S/O - I S*/O + # I_S/S* = I_S/O - I_S*/O + I_Ss_O = inertia_of_point_mass(self.mass, + self.masscenter.pos_from(I[1]), + self.frame) + self._central_inertia = I[0] - I_Ss_O + + @property + def central_inertia(self): + """The body's central inertia dyadic.""" + return self._central_inertia + + @central_inertia.setter + def central_inertia(self, I): + if not isinstance(I, Dyadic): + raise TypeError("RigidBody inertia must be a Dyadic object.") + self.inertia = Inertia(I, self.masscenter) + + def linear_momentum(self, frame): + """ Linear momentum of the rigid body. + + Explanation + =========== + + The linear momentum L, of a rigid body B, with respect to frame N is + given by: + + ``L = m * v`` + + where m is the mass of the rigid body, and v is the velocity of the mass + center of B in the frame N. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which linear momentum is desired. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, ReferenceFrame, outer + >>> from sympy.physics.mechanics import RigidBody, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> m, v = dynamicsymbols('m v') + >>> N = ReferenceFrame('N') + >>> P = Point('P') + >>> P.set_vel(N, v * N.x) + >>> I = outer (N.x, N.x) + >>> Inertia_tuple = (I, P) + >>> B = RigidBody('B', P, N, m, Inertia_tuple) + >>> B.linear_momentum(N) + m*v*N.x + + """ + + return self.mass * self.masscenter.vel(frame) + + def angular_momentum(self, point, frame): + """Returns the angular momentum of the rigid body about a point in the + given frame. + + Explanation + =========== + + The angular momentum H of a rigid body B about some point O in a frame N + is given by: + + ``H = dot(I, w) + cross(r, m * v)`` + + where I and m are the central inertia dyadic and mass of rigid body B, w + is the angular velocity of body B in the frame N, r is the position + vector from point O to the mass center of B, and v is the velocity of + the mass center in the frame N. + + Parameters + ========== + + point : Point + The point about which angular momentum is desired. + frame : ReferenceFrame + The frame in which angular momentum is desired. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, ReferenceFrame, outer + >>> from sympy.physics.mechanics import RigidBody, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> m, v, r, omega = dynamicsymbols('m v r omega') + >>> N = ReferenceFrame('N') + >>> b = ReferenceFrame('b') + >>> b.set_ang_vel(N, omega * b.x) + >>> P = Point('P') + >>> P.set_vel(N, 1 * N.x) + >>> I = outer(b.x, b.x) + >>> B = RigidBody('B', P, b, m, (I, P)) + >>> B.angular_momentum(P, N) + omega*b.x + + """ + I = self.central_inertia + w = self.frame.ang_vel_in(frame) + m = self.mass + r = self.masscenter.pos_from(point) + v = self.masscenter.vel(frame) + + return I.dot(w) + r.cross(m * v) + + def kinetic_energy(self, frame): + """Kinetic energy of the rigid body. + + Explanation + =========== + + The kinetic energy, T, of a rigid body, B, is given by: + + ``T = 1/2 * (dot(dot(I, w), w) + dot(m * v, v))`` + + where I and m are the central inertia dyadic and mass of rigid body B + respectively, w is the body's angular velocity, and v is the velocity of + the body's mass center in the supplied ReferenceFrame. + + Parameters + ========== + + frame : ReferenceFrame + The RigidBody's angular velocity and the velocity of it's mass + center are typically defined with respect to an inertial frame but + any relevant frame in which the velocities are known can be + supplied. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, ReferenceFrame, outer + >>> from sympy.physics.mechanics import RigidBody + >>> from sympy import symbols + >>> m, v, r, omega = symbols('m v r omega') + >>> N = ReferenceFrame('N') + >>> b = ReferenceFrame('b') + >>> b.set_ang_vel(N, omega * b.x) + >>> P = Point('P') + >>> P.set_vel(N, v * N.x) + >>> I = outer (b.x, b.x) + >>> inertia_tuple = (I, P) + >>> B = RigidBody('B', P, b, m, inertia_tuple) + >>> B.kinetic_energy(N) + m*v**2/2 + omega**2/2 + + """ + + rotational_KE = S.Half * dot( + self.frame.ang_vel_in(frame), + dot(self.central_inertia, self.frame.ang_vel_in(frame))) + translational_KE = S.Half * self.mass * dot(self.masscenter.vel(frame), + self.masscenter.vel(frame)) + return rotational_KE + translational_KE + + def set_potential_energy(self, scalar): + sympy_deprecation_warning( + """ +The sympy.physics.mechanics.RigidBody.set_potential_energy() +method is deprecated. Instead use + + B.potential_energy = scalar + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-set-potential-energy", + ) + self.potential_energy = scalar + + def parallel_axis(self, point, frame=None): + """Returns the inertia dyadic of the body with respect to another point. + + Parameters + ========== + + point : sympy.physics.vector.Point + The point to express the inertia dyadic about. + frame : sympy.physics.vector.ReferenceFrame + The reference frame used to construct the dyadic. + + Returns + ======= + + inertia : sympy.physics.vector.Dyadic + The inertia dyadic of the rigid body expressed about the provided + point. + + """ + if frame is None: + frame = self.frame + return self.central_inertia + inertia_of_point_mass( + self.mass, self.masscenter.pos_from(point), frame) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/system.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/system.py new file mode 100644 index 0000000000000000000000000000000000000000..c8e0657d7da54ca5aaad9b37b816235641968470 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/system.py @@ -0,0 +1,1553 @@ +from functools import wraps + +from sympy.core.basic import Basic +from sympy.matrices.immutable import ImmutableMatrix +from sympy.matrices.dense import Matrix, eye, zeros +from sympy.core.containers import OrderedSet +from sympy.physics.mechanics.actuator import ActuatorBase +from sympy.physics.mechanics.body_base import BodyBase +from sympy.physics.mechanics.functions import ( + Lagrangian, _validate_coordinates, find_dynamicsymbols) +from sympy.physics.mechanics.joint import Joint +from sympy.physics.mechanics.kane import KanesMethod +from sympy.physics.mechanics.lagrange import LagrangesMethod +from sympy.physics.mechanics.loads import _parse_load, gravity +from sympy.physics.mechanics.method import _Methods +from sympy.physics.mechanics.particle import Particle +from sympy.physics.vector import Point, ReferenceFrame, dynamicsymbols +from sympy.utilities.iterables import iterable +from sympy.utilities.misc import filldedent + +__all__ = ['SymbolicSystem', 'System'] + + +def _reset_eom_method(method): + """Decorator to reset the eom_method if a property is changed.""" + + @wraps(method) + def wrapper(self, *args, **kwargs): + self._eom_method = None + return method(self, *args, **kwargs) + + return wrapper + + +class System(_Methods): + """Class to define a multibody system and form its equations of motion. + + Explanation + =========== + + A ``System`` instance stores the different objects associated with a model, + including bodies, joints, constraints, and other relevant information. With + all the relationships between components defined, the ``System`` can be used + to form the equations of motion using a backend, such as ``KanesMethod``. + The ``System`` has been designed to be compatible with third-party + libraries for greater flexibility and integration with other tools. + + Attributes + ========== + + frame : ReferenceFrame + Inertial reference frame of the system. + fixed_point : Point + A fixed point in the inertial reference frame. + x : Vector + Unit vector fixed in the inertial reference frame. + y : Vector + Unit vector fixed in the inertial reference frame. + z : Vector + Unit vector fixed in the inertial reference frame. + q : ImmutableMatrix + Matrix of all the generalized coordinates, i.e. the independent + generalized coordinates stacked upon the dependent. + u : ImmutableMatrix + Matrix of all the generalized speeds, i.e. the independent generealized + speeds stacked upon the dependent. + q_ind : ImmutableMatrix + Matrix of the independent generalized coordinates. + q_dep : ImmutableMatrix + Matrix of the dependent generalized coordinates. + u_ind : ImmutableMatrix + Matrix of the independent generalized speeds. + u_dep : ImmutableMatrix + Matrix of the dependent generalized speeds. + u_aux : ImmutableMatrix + Matrix of auxiliary generalized speeds. + kdes : ImmutableMatrix + Matrix of the kinematical differential equations as expressions equated + to the zero matrix. + bodies : tuple of BodyBase subclasses + Tuple of all bodies that make up the system. + joints : tuple of Joint + Tuple of all joints that connect bodies in the system. + loads : tuple of LoadBase subclasses + Tuple of all loads that have been applied to the system. + actuators : tuple of ActuatorBase subclasses + Tuple of all actuators present in the system. + holonomic_constraints : ImmutableMatrix + Matrix with the holonomic constraints as expressions equated to the zero + matrix. + nonholonomic_constraints : ImmutableMatrix + Matrix with the nonholonomic constraints as expressions equated to the + zero matrix. + velocity_constraints : ImmutableMatrix + Matrix with the velocity constraints as expressions equated to the zero + matrix. These are by default derived as the time derivatives of the + holonomic constraints extended with the nonholonomic constraints. + eom_method : subclass of KanesMethod or LagrangesMethod + Backend for forming the equations of motion. + + Examples + ======== + + In the example below a cart with a pendulum is created. The cart moves along + the x axis of the rail and the pendulum rotates about the z axis. The length + of the pendulum is ``l`` with the pendulum represented as a particle. To + move the cart a time dependent force ``F`` is applied to the cart. + + We first need to import some functions and create some of our variables. + + >>> from sympy import symbols, simplify + >>> from sympy.physics.mechanics import ( + ... mechanics_printing, dynamicsymbols, RigidBody, Particle, + ... ReferenceFrame, PrismaticJoint, PinJoint, System) + >>> mechanics_printing(pretty_print=False) + >>> g, l = symbols('g l') + >>> F = dynamicsymbols('F') + + The next step is to create bodies. It is also useful to create a frame for + locating the particle with respect to the pin joint later on, as a particle + does not have a body-fixed frame. + + >>> rail = RigidBody('rail') + >>> cart = RigidBody('cart') + >>> bob = Particle('bob') + >>> bob_frame = ReferenceFrame('bob_frame') + + Initialize the system, with the rail as the Newtonian reference. The body is + also automatically added to the system. + + >>> system = System.from_newtonian(rail) + >>> print(system.bodies[0]) + rail + + Create the joints, while immediately also adding them to the system. + + >>> system.add_joints( + ... PrismaticJoint('slider', rail, cart, joint_axis=rail.x), + ... PinJoint('pin', cart, bob, joint_axis=cart.z, + ... child_interframe=bob_frame, + ... child_point=l * bob_frame.y) + ... ) + >>> system.joints + (PrismaticJoint: slider parent: rail child: cart, + PinJoint: pin parent: cart child: bob) + + While adding the joints, the associated generalized coordinates, generalized + speeds, kinematic differential equations and bodies are also added to the + system. + + >>> system.q + Matrix([ + [q_slider], + [ q_pin]]) + >>> system.u + Matrix([ + [u_slider], + [ u_pin]]) + >>> system.kdes + Matrix([ + [u_slider - q_slider'], + [ u_pin - q_pin']]) + >>> [body.name for body in system.bodies] + ['rail', 'cart', 'bob'] + + With the kinematics established, we can now apply gravity and the cart force + ``F``. + + >>> system.apply_uniform_gravity(-g * system.y) + >>> system.add_loads((cart.masscenter, F * rail.x)) + >>> system.loads + ((rail_masscenter, - g*rail_mass*rail_frame.y), + (cart_masscenter, - cart_mass*g*rail_frame.y), + (bob_masscenter, - bob_mass*g*rail_frame.y), + (cart_masscenter, F*rail_frame.x)) + + With the entire system defined, we can now form the equations of motion. + Before forming the equations of motion, one can also run some checks that + will try to identify some common errors. + + >>> system.validate_system() + >>> system.form_eoms() + Matrix([ + [bob_mass*l*u_pin**2*sin(q_pin) - bob_mass*l*cos(q_pin)*u_pin' + - (bob_mass + cart_mass)*u_slider' + F], + [ -bob_mass*g*l*sin(q_pin) - bob_mass*l**2*u_pin' + - bob_mass*l*cos(q_pin)*u_slider']]) + >>> simplify(system.mass_matrix) + Matrix([ + [ bob_mass + cart_mass, bob_mass*l*cos(q_pin)], + [bob_mass*l*cos(q_pin), bob_mass*l**2]]) + >>> system.forcing + Matrix([ + [bob_mass*l*u_pin**2*sin(q_pin) + F], + [ -bob_mass*g*l*sin(q_pin)]]) + + The complexity of the above example can be increased if we add a constraint + to prevent the particle from moving in the horizontal (x) direction. This + can be done by adding a holonomic constraint. After which we should also + redefine what our (in)dependent generalized coordinates and speeds are. + + >>> system.add_holonomic_constraints( + ... bob.masscenter.pos_from(rail.masscenter).dot(system.x) + ... ) + >>> system.q_ind = system.get_joint('pin').coordinates + >>> system.q_dep = system.get_joint('slider').coordinates + >>> system.u_ind = system.get_joint('pin').speeds + >>> system.u_dep = system.get_joint('slider').speeds + + With the updated system the equations of motion can be formed again. + + >>> system.validate_system() + >>> system.form_eoms() + Matrix([[-bob_mass*g*l*sin(q_pin) + - bob_mass*l**2*u_pin' + - bob_mass*l*cos(q_pin)*u_slider' + - l*(bob_mass*l*u_pin**2*sin(q_pin) + - bob_mass*l*cos(q_pin)*u_pin' + - (bob_mass + cart_mass)*u_slider')*cos(q_pin) + - l*F*cos(q_pin)]]) + >>> simplify(system.mass_matrix) + Matrix([ + [bob_mass*l**2*sin(q_pin)**2, -cart_mass*l*cos(q_pin)], + [ l*cos(q_pin), 1]]) + >>> simplify(system.forcing) + Matrix([ + [-l*(bob_mass*g*sin(q_pin) + bob_mass*l*u_pin**2*sin(2*q_pin)/2 + + F*cos(q_pin))], + [ + l*u_pin**2*sin(q_pin)]]) + + """ + + def __init__(self, frame=None, fixed_point=None): + """Initialize the system. + + Parameters + ========== + + frame : ReferenceFrame, optional + The inertial frame of the system. If none is supplied, a new frame + will be created. + fixed_point : Point, optional + A fixed point in the inertial reference frame. If none is supplied, + a new fixed_point will be created. + + """ + if frame is None: + frame = ReferenceFrame('inertial_frame') + elif not isinstance(frame, ReferenceFrame): + raise TypeError('Frame must be an instance of ReferenceFrame.') + self._frame = frame + if fixed_point is None: + fixed_point = Point('inertial_point') + elif not isinstance(fixed_point, Point): + raise TypeError('Fixed point must be an instance of Point.') + self._fixed_point = fixed_point + self._fixed_point.set_vel(self._frame, 0) + self._q_ind = ImmutableMatrix(1, 0, []).T + self._q_dep = ImmutableMatrix(1, 0, []).T + self._u_ind = ImmutableMatrix(1, 0, []).T + self._u_dep = ImmutableMatrix(1, 0, []).T + self._u_aux = ImmutableMatrix(1, 0, []).T + self._kdes = ImmutableMatrix(1, 0, []).T + self._hol_coneqs = ImmutableMatrix(1, 0, []).T + self._nonhol_coneqs = ImmutableMatrix(1, 0, []).T + self._vel_constrs = None + self._bodies = [] + self._joints = [] + self._loads = [] + self._actuators = [] + self._eom_method = None + + @classmethod + def from_newtonian(cls, newtonian): + """Constructs the system with respect to a Newtonian body.""" + if isinstance(newtonian, Particle): + raise TypeError('A Particle has no frame so cannot act as ' + 'the Newtonian.') + system = cls(frame=newtonian.frame, fixed_point=newtonian.masscenter) + system.add_bodies(newtonian) + return system + + @property + def fixed_point(self): + """Fixed point in the inertial reference frame.""" + return self._fixed_point + + @property + def frame(self): + """Inertial reference frame of the system.""" + return self._frame + + @property + def x(self): + """Unit vector fixed in the inertial reference frame.""" + return self._frame.x + + @property + def y(self): + """Unit vector fixed in the inertial reference frame.""" + return self._frame.y + + @property + def z(self): + """Unit vector fixed in the inertial reference frame.""" + return self._frame.z + + @property + def bodies(self): + """Tuple of all bodies that have been added to the system.""" + return tuple(self._bodies) + + @bodies.setter + @_reset_eom_method + def bodies(self, bodies): + bodies = self._objects_to_list(bodies) + self._check_objects(bodies, [], BodyBase, 'Bodies', 'bodies') + self._bodies = bodies + + @property + def joints(self): + """Tuple of all joints that have been added to the system.""" + return tuple(self._joints) + + @joints.setter + @_reset_eom_method + def joints(self, joints): + joints = self._objects_to_list(joints) + self._check_objects(joints, [], Joint, 'Joints', 'joints') + self._joints = [] + self.add_joints(*joints) + + @property + def loads(self): + """Tuple of loads that have been applied on the system.""" + return tuple(self._loads) + + @loads.setter + @_reset_eom_method + def loads(self, loads): + loads = self._objects_to_list(loads) + self._loads = [_parse_load(load) for load in loads] + + @property + def actuators(self): + """Tuple of actuators present in the system.""" + return tuple(self._actuators) + + @actuators.setter + @_reset_eom_method + def actuators(self, actuators): + actuators = self._objects_to_list(actuators) + self._check_objects(actuators, [], ActuatorBase, 'Actuators', + 'actuators') + self._actuators = actuators + + @property + def q(self): + """Matrix of all the generalized coordinates with the independent + stacked upon the dependent.""" + return self._q_ind.col_join(self._q_dep) + + @property + def u(self): + """Matrix of all the generalized speeds with the independent stacked + upon the dependent.""" + return self._u_ind.col_join(self._u_dep) + + @property + def q_ind(self): + """Matrix of the independent generalized coordinates.""" + return self._q_ind + + @q_ind.setter + @_reset_eom_method + def q_ind(self, q_ind): + self._q_ind, self._q_dep = self._parse_coordinates( + self._objects_to_list(q_ind), True, [], self.q_dep, 'coordinates') + + @property + def q_dep(self): + """Matrix of the dependent generalized coordinates.""" + return self._q_dep + + @q_dep.setter + @_reset_eom_method + def q_dep(self, q_dep): + self._q_ind, self._q_dep = self._parse_coordinates( + self._objects_to_list(q_dep), False, self.q_ind, [], 'coordinates') + + @property + def u_ind(self): + """Matrix of the independent generalized speeds.""" + return self._u_ind + + @u_ind.setter + @_reset_eom_method + def u_ind(self, u_ind): + self._u_ind, self._u_dep = self._parse_coordinates( + self._objects_to_list(u_ind), True, [], self.u_dep, 'speeds') + + @property + def u_dep(self): + """Matrix of the dependent generalized speeds.""" + return self._u_dep + + @u_dep.setter + @_reset_eom_method + def u_dep(self, u_dep): + self._u_ind, self._u_dep = self._parse_coordinates( + self._objects_to_list(u_dep), False, self.u_ind, [], 'speeds') + + @property + def u_aux(self): + """Matrix of auxiliary generalized speeds.""" + return self._u_aux + + @u_aux.setter + @_reset_eom_method + def u_aux(self, u_aux): + self._u_aux = self._parse_coordinates( + self._objects_to_list(u_aux), True, [], [], 'u_auxiliary')[0] + + @property + def kdes(self): + """Kinematical differential equations as expressions equated to the zero + matrix. These equations describe the coupling between the generalized + coordinates and the generalized speeds.""" + return self._kdes + + @kdes.setter + @_reset_eom_method + def kdes(self, kdes): + kdes = self._objects_to_list(kdes) + self._kdes = self._parse_expressions( + kdes, [], 'kinematic differential equations') + + @property + def holonomic_constraints(self): + """Matrix with the holonomic constraints as expressions equated to the + zero matrix.""" + return self._hol_coneqs + + @holonomic_constraints.setter + @_reset_eom_method + def holonomic_constraints(self, constraints): + constraints = self._objects_to_list(constraints) + self._hol_coneqs = self._parse_expressions( + constraints, [], 'holonomic constraints') + + @property + def nonholonomic_constraints(self): + """Matrix with the nonholonomic constraints as expressions equated to + the zero matrix.""" + return self._nonhol_coneqs + + @nonholonomic_constraints.setter + @_reset_eom_method + def nonholonomic_constraints(self, constraints): + constraints = self._objects_to_list(constraints) + self._nonhol_coneqs = self._parse_expressions( + constraints, [], 'nonholonomic constraints') + + @property + def velocity_constraints(self): + """Matrix with the velocity constraints as expressions equated to the + zero matrix. The velocity constraints are by default derived from the + holonomic and nonholonomic constraints unless they are explicitly set. + """ + if self._vel_constrs is None: + return self.holonomic_constraints.diff(dynamicsymbols._t).col_join( + self.nonholonomic_constraints) + return self._vel_constrs + + @velocity_constraints.setter + @_reset_eom_method + def velocity_constraints(self, constraints): + if constraints is None: + self._vel_constrs = None + return + constraints = self._objects_to_list(constraints) + self._vel_constrs = self._parse_expressions( + constraints, [], 'velocity constraints') + + @property + def eom_method(self): + """Backend for forming the equations of motion.""" + return self._eom_method + + @staticmethod + def _objects_to_list(lst): + """Helper to convert passed objects to a list.""" + if not iterable(lst): # Only one object + return [lst] + return list(lst[:]) # converts Matrix and tuple to flattened list + + @staticmethod + def _check_objects(objects, obj_lst, expected_type, obj_name, type_name): + """Helper to check the objects that are being added to the system. + + Explanation + =========== + This method checks that the objects that are being added to the system + are of the correct type and have not already been added. If any of the + objects are not of the correct type or have already been added, then + an error is raised. + + Parameters + ========== + objects : iterable + The objects that would be added to the system. + obj_lst : list + The list of objects that are already in the system. + expected_type : type + The type that the objects should be. + obj_name : str + The name of the category of objects. This string is used to + formulate the error message for the user. + type_name : str + The name of the type that the objects should be. This string is used + to formulate the error message for the user. + + """ + seen = set(obj_lst) + duplicates = set() + wrong_types = set() + for obj in objects: + if not isinstance(obj, expected_type): + wrong_types.add(obj) + if obj in seen: + duplicates.add(obj) + else: + seen.add(obj) + if wrong_types: + raise TypeError(f'{obj_name} {wrong_types} are not {type_name}.') + if duplicates: + raise ValueError(f'{obj_name} {duplicates} have already been added ' + f'to the system.') + + def _parse_coordinates(self, new_coords, independent, old_coords_ind, + old_coords_dep, coord_type='coordinates'): + """Helper to parse coordinates and speeds.""" + # Construct lists of the independent and dependent coordinates + coords_ind, coords_dep = old_coords_ind[:], old_coords_dep[:] + if not iterable(independent): + independent = [independent] * len(new_coords) + for coord, indep in zip(new_coords, independent): + if indep: + coords_ind.append(coord) + else: + coords_dep.append(coord) + # Check types and duplicates + current = {'coordinates': self.q_ind[:] + self.q_dep[:], + 'speeds': self.u_ind[:] + self.u_dep[:], + 'u_auxiliary': self._u_aux[:], + coord_type: coords_ind + coords_dep} + _validate_coordinates(**current) + return (ImmutableMatrix(1, len(coords_ind), coords_ind).T, + ImmutableMatrix(1, len(coords_dep), coords_dep).T) + + @staticmethod + def _parse_expressions(new_expressions, old_expressions, name, + check_negatives=False): + """Helper to parse expressions like constraints.""" + old_expressions = old_expressions[:] + new_expressions = list(new_expressions) # Converts a possible tuple + if check_negatives: + check_exprs = old_expressions + [-expr for expr in old_expressions] + else: + check_exprs = old_expressions + System._check_objects(new_expressions, check_exprs, Basic, name, + 'expressions') + for expr in new_expressions: + if expr == 0: + raise ValueError(f'Parsed {name} are zero.') + return ImmutableMatrix(1, len(old_expressions) + len(new_expressions), + old_expressions + new_expressions).T + + @_reset_eom_method + def add_coordinates(self, *coordinates, independent=True): + """Add generalized coordinate(s) to the system. + + Parameters + ========== + + *coordinates : dynamicsymbols + One or more generalized coordinates to be added to the system. + independent : bool or list of bool, optional + Boolean whether a coordinate is dependent or independent. The + default is True, so the coordinates are added as independent by + default. + + """ + self._q_ind, self._q_dep = self._parse_coordinates( + coordinates, independent, self.q_ind, self.q_dep, 'coordinates') + + @_reset_eom_method + def add_speeds(self, *speeds, independent=True): + """Add generalized speed(s) to the system. + + Parameters + ========== + + *speeds : dynamicsymbols + One or more generalized speeds to be added to the system. + independent : bool or list of bool, optional + Boolean whether a speed is dependent or independent. The default is + True, so the speeds are added as independent by default. + + """ + self._u_ind, self._u_dep = self._parse_coordinates( + speeds, independent, self.u_ind, self.u_dep, 'speeds') + + @_reset_eom_method + def add_auxiliary_speeds(self, *speeds): + """Add auxiliary speed(s) to the system. + + Parameters + ========== + + *speeds : dynamicsymbols + One or more auxiliary speeds to be added to the system. + + """ + self._u_aux = self._parse_coordinates( + speeds, True, self._u_aux, [], 'u_auxiliary')[0] + + @_reset_eom_method + def add_kdes(self, *kdes): + """Add kinematic differential equation(s) to the system. + + Parameters + ========== + + *kdes : Expr + One or more kinematic differential equations. + + """ + self._kdes = self._parse_expressions( + kdes, self.kdes, 'kinematic differential equations', + check_negatives=True) + + @_reset_eom_method + def add_holonomic_constraints(self, *constraints): + """Add holonomic constraint(s) to the system. + + Parameters + ========== + + *constraints : Expr + One or more holonomic constraints, which are expressions that should + be zero. + + """ + self._hol_coneqs = self._parse_expressions( + constraints, self._hol_coneqs, 'holonomic constraints', + check_negatives=True) + + @_reset_eom_method + def add_nonholonomic_constraints(self, *constraints): + """Add nonholonomic constraint(s) to the system. + + Parameters + ========== + + *constraints : Expr + One or more nonholonomic constraints, which are expressions that + should be zero. + + """ + self._nonhol_coneqs = self._parse_expressions( + constraints, self._nonhol_coneqs, 'nonholonomic constraints', + check_negatives=True) + + @_reset_eom_method + def add_bodies(self, *bodies): + """Add body(ies) to the system. + + Parameters + ========== + + bodies : Particle or RigidBody + One or more bodies. + + """ + self._check_objects(bodies, self.bodies, BodyBase, 'Bodies', 'bodies') + self._bodies.extend(bodies) + + @_reset_eom_method + def add_loads(self, *loads): + """Add load(s) to the system. + + Parameters + ========== + + *loads : Force or Torque + One or more loads. + + """ + loads = [_parse_load(load) for load in loads] # Checks the loads + self._loads.extend(loads) + + @_reset_eom_method + def apply_uniform_gravity(self, acceleration): + """Apply uniform gravity to all bodies in the system by adding loads. + + Parameters + ========== + + acceleration : Vector + The acceleration due to gravity. + + """ + self.add_loads(*gravity(acceleration, *self.bodies)) + + @_reset_eom_method + def add_actuators(self, *actuators): + """Add actuator(s) to the system. + + Parameters + ========== + + *actuators : subclass of ActuatorBase + One or more actuators. + + """ + self._check_objects(actuators, self.actuators, ActuatorBase, + 'Actuators', 'actuators') + self._actuators.extend(actuators) + + @_reset_eom_method + def add_joints(self, *joints): + """Add joint(s) to the system. + + Explanation + =========== + + This methods adds one or more joints to the system including its + associated objects, i.e. generalized coordinates, generalized speeds, + kinematic differential equations and the bodies. + + Parameters + ========== + + *joints : subclass of Joint + One or more joints. + + Notes + ===== + + For the generalized coordinates, generalized speeds and bodies it is + checked whether they are already known by the system instance. If they + are, then they are not added. The kinematic differential equations are + however always added to the system, so you should not also manually add + those on beforehand. + + """ + self._check_objects(joints, self.joints, Joint, 'Joints', 'joints') + self._joints.extend(joints) + coordinates, speeds, kdes, bodies = (OrderedSet() for _ in range(4)) + for joint in joints: + coordinates.update(joint.coordinates) + speeds.update(joint.speeds) + kdes.update(joint.kdes) + bodies.update((joint.parent, joint.child)) + coordinates = coordinates.difference(self.q) + speeds = speeds.difference(self.u) + kdes = kdes.difference(self.kdes[:] + (-self.kdes)[:]) + bodies = bodies.difference(self.bodies) + self.add_coordinates(*tuple(coordinates)) + self.add_speeds(*tuple(speeds)) + self.add_kdes(*(kde for kde in tuple(kdes) if not kde == 0)) + self.add_bodies(*tuple(bodies)) + + def get_body(self, name): + """Retrieve a body from the system by name. + + Parameters + ========== + + name : str + The name of the body to retrieve. + + Returns + ======= + + RigidBody or Particle + The body with the given name, or None if no such body exists. + + """ + for body in self._bodies: + if body.name == name: + return body + + def get_joint(self, name): + """Retrieve a joint from the system by name. + + Parameters + ========== + + name : str + The name of the joint to retrieve. + + Returns + ======= + + subclass of Joint + The joint with the given name, or None if no such joint exists. + + """ + for joint in self._joints: + if joint.name == name: + return joint + + def _form_eoms(self): + return self.form_eoms() + + def form_eoms(self, eom_method=KanesMethod, **kwargs): + """Form the equations of motion of the system. + + Parameters + ========== + + eom_method : subclass of KanesMethod or LagrangesMethod + Backend class to be used for forming the equations of motion. The + default is ``KanesMethod``. + + Returns + ======== + + ImmutableMatrix + Vector of equations of motions. + + Examples + ======== + + This is a simple example for a one degree of freedom translational + spring-mass-damper. + + >>> from sympy import S, symbols + >>> from sympy.physics.mechanics import ( + ... LagrangesMethod, dynamicsymbols, PrismaticJoint, Particle, + ... RigidBody, System) + >>> q = dynamicsymbols('q') + >>> qd = dynamicsymbols('q', 1) + >>> m, k, b = symbols('m k b') + >>> wall = RigidBody('W') + >>> system = System.from_newtonian(wall) + >>> bob = Particle('P', mass=m) + >>> bob.potential_energy = S.Half * k * q**2 + >>> system.add_joints(PrismaticJoint('J', wall, bob, q, qd)) + >>> system.add_loads((bob.masscenter, b * qd * system.x)) + >>> system.form_eoms(LagrangesMethod) + Matrix([[-b*Derivative(q(t), t) + k*q(t) + m*Derivative(q(t), (t, 2))]]) + + We can also solve for the states using the 'rhs' method. + + >>> system.rhs() + Matrix([ + [ Derivative(q(t), t)], + [(b*Derivative(q(t), t) - k*q(t))/m]]) + + """ + # KanesMethod does not accept empty iterables + loads = self.loads + tuple( + load for act in self.actuators for load in act.to_loads()) + loads = loads if loads else None + if issubclass(eom_method, KanesMethod): + disallowed_kwargs = { + "frame", "q_ind", "u_ind", "kd_eqs", "q_dependent", + "u_dependent", "u_auxiliary", "configuration_constraints", + "velocity_constraints", "forcelist", "bodies"} + wrong_kwargs = disallowed_kwargs.intersection(kwargs) + if wrong_kwargs: + raise ValueError( + f"The following keyword arguments are not allowed to be " + f"overwritten in {eom_method.__name__}: {wrong_kwargs}.") + kwargs = {"frame": self.frame, "q_ind": self.q_ind, + "u_ind": self.u_ind, "kd_eqs": self.kdes, + "q_dependent": self.q_dep, "u_dependent": self.u_dep, + "configuration_constraints": self.holonomic_constraints, + "velocity_constraints": self.velocity_constraints, + "u_auxiliary": self.u_aux, + "forcelist": loads, "bodies": self.bodies, + "explicit_kinematics": False, **kwargs} + self._eom_method = eom_method(**kwargs) + elif issubclass(eom_method, LagrangesMethod): + disallowed_kwargs = { + "frame", "qs", "forcelist", "bodies", "hol_coneqs", + "nonhol_coneqs", "Lagrangian"} + wrong_kwargs = disallowed_kwargs.intersection(kwargs) + if wrong_kwargs: + raise ValueError( + f"The following keyword arguments are not allowed to be " + f"overwritten in {eom_method.__name__}: {wrong_kwargs}.") + kwargs = {"frame": self.frame, "qs": self.q, "forcelist": loads, + "bodies": self.bodies, + "hol_coneqs": self.holonomic_constraints, + "nonhol_coneqs": self.nonholonomic_constraints, **kwargs} + if "Lagrangian" not in kwargs: + kwargs["Lagrangian"] = Lagrangian(kwargs["frame"], + *kwargs["bodies"]) + self._eom_method = eom_method(**kwargs) + else: + raise NotImplementedError(f'{eom_method} has not been implemented.') + return self.eom_method._form_eoms() + + def rhs(self, inv_method=None): + """Compute the equations of motion in the explicit form. + + Parameters + ========== + + inv_method : str + The specific sympy inverse matrix calculation method to use. For a + list of valid methods, see + :meth:`~sympy.matrices.matrixbase.MatrixBase.inv` + + Returns + ======== + + ImmutableMatrix + Equations of motion in the explicit form. + + See Also + ======== + + sympy.physics.mechanics.kane.KanesMethod.rhs: + KanesMethod's ``rhs`` function. + sympy.physics.mechanics.lagrange.LagrangesMethod.rhs: + LagrangesMethod's ``rhs`` function. + + """ + return self.eom_method.rhs(inv_method=inv_method) + + @property + def mass_matrix(self): + r"""The mass matrix of the system. + + Explanation + =========== + + The mass matrix $M_d$ and the forcing vector $f_d$ of a system describe + the system's dynamics according to the following equations: + + .. math:: + M_d \dot{u} = f_d + + where $\dot{u}$ is the time derivative of the generalized speeds. + + """ + return self.eom_method.mass_matrix + + @property + def mass_matrix_full(self): + r"""The mass matrix of the system, augmented by the kinematic + differential equations in explicit or implicit form. + + Explanation + =========== + + The full mass matrix $M_m$ and the full forcing vector $f_m$ of a system + describe the dynamics and kinematics according to the following + equation: + + .. math:: + M_m \dot{x} = f_m + + where $x$ is the state vector stacking $q$ and $u$. + + """ + return self.eom_method.mass_matrix_full + + @property + def forcing(self): + """The forcing vector of the system.""" + return self.eom_method.forcing + + @property + def forcing_full(self): + """The forcing vector of the system, augmented by the kinematic + differential equations in explicit or implicit form.""" + return self.eom_method.forcing_full + + def validate_system(self, eom_method=KanesMethod, check_duplicates=False): + """Validates the system using some basic checks. + + Explanation + =========== + + This method validates the system based on the following checks: + + - The number of dependent generalized coordinates should equal the + number of holonomic constraints. + - All generalized coordinates defined by the joints should also be known + to the system. + - If ``KanesMethod`` is used as a ``eom_method``: + - All generalized speeds and kinematic differential equations + defined by the joints should also be known to the system. + - The number of dependent generalized speeds should equal the number + of velocity constraints. + - The number of generalized coordinates should be less than or equal + to the number of generalized speeds. + - The number of generalized coordinates should equal the number of + kinematic differential equations. + - If ``LagrangesMethod`` is used as ``eom_method``: + - There should not be any generalized speeds that are not + derivatives of the generalized coordinates (this includes the + generalized speeds defined by the joints). + + Parameters + ========== + + eom_method : subclass of KanesMethod or LagrangesMethod + Backend class that will be used for forming the equations of motion. + There are different checks for the different backends. The default + is ``KanesMethod``. + check_duplicates : bool + Boolean whether the system should be checked for duplicate + definitions. The default is False, because duplicates are already + checked when adding objects to the system. + + Notes + ===== + + This method is not guaranteed to be backwards compatible as it may + improve over time. The method can become both more and less strict in + certain areas. However a well-defined system should always pass all + these tests. + + """ + msgs = [] + # Save some data in variables + n_hc = self.holonomic_constraints.shape[0] + n_vc = self.velocity_constraints.shape[0] + n_q_dep, n_u_dep = self.q_dep.shape[0], self.u_dep.shape[0] + q_set, u_set = set(self.q), set(self.u) + n_q, n_u = len(q_set), len(u_set) + # Check number of holonomic constraints + if n_q_dep != n_hc: + msgs.append(filldedent(f""" + The number of dependent generalized coordinates {n_q_dep} should be + equal to the number of holonomic constraints {n_hc}.""")) + # Check if all joint coordinates and speeds are present + missing_q = set() + for joint in self.joints: + missing_q.update(set(joint.coordinates).difference(q_set)) + if missing_q: + msgs.append(filldedent(f""" + The generalized coordinates {missing_q} used in joints are not added + to the system.""")) + # Method dependent checks + if issubclass(eom_method, KanesMethod): + n_kdes = len(self.kdes) + missing_kdes, missing_u = set(), set() + for joint in self.joints: + missing_u.update(set(joint.speeds).difference(u_set)) + missing_kdes.update(set(joint.kdes).difference( + self.kdes[:] + (-self.kdes)[:])) + if missing_u: + msgs.append(filldedent(f""" + The generalized speeds {missing_u} used in joints are not added + to the system.""")) + if missing_kdes: + msgs.append(filldedent(f""" + The kinematic differential equations {missing_kdes} used in + joints are not added to the system.""")) + if n_u_dep != n_vc: + msgs.append(filldedent(f""" + The number of dependent generalized speeds {n_u_dep} should be + equal to the number of velocity constraints {n_vc}.""")) + if n_q > n_u: + msgs.append(filldedent(f""" + The number of generalized coordinates {n_q} should be less than + or equal to the number of generalized speeds {n_u}.""")) + if n_u != n_kdes: + msgs.append(filldedent(f""" + The number of generalized speeds {n_u} should be equal to the + number of kinematic differential equations {n_kdes}.""")) + elif issubclass(eom_method, LagrangesMethod): + not_qdots = set(self.u).difference(self.q.diff(dynamicsymbols._t)) + for joint in self.joints: + not_qdots.update(set( + joint.speeds).difference(self.q.diff(dynamicsymbols._t))) + if not_qdots: + msgs.append(filldedent(f""" + The generalized speeds {not_qdots} are not supported by this + method. Only derivatives of the generalized coordinates are + supported. If these symbols are used in your expressions, then + this will result in wrong equations of motion.""")) + if self.u_aux: + msgs.append(filldedent(f""" + This method does not support auxiliary speeds. If these symbols + are used in your expressions, then this will result in wrong + equations of motion. The auxiliary speeds are {self.u_aux}.""")) + else: + raise NotImplementedError(f'{eom_method} has not been implemented.') + if check_duplicates: # Should be redundant + duplicates_to_check = [('generalized coordinates', self.q), + ('generalized speeds', self.u), + ('auxiliary speeds', self.u_aux), + ('bodies', self.bodies), + ('joints', self.joints)] + for name, lst in duplicates_to_check: + seen = set() + duplicates = {x for x in lst if x in seen or seen.add(x)} + if duplicates: + msgs.append(filldedent(f""" + The {name} {duplicates} exist multiple times within the + system.""")) + if msgs: + raise ValueError('\n'.join(msgs)) + + +class SymbolicSystem: + """SymbolicSystem is a class that contains all the information about a + system in a symbolic format such as the equations of motions and the bodies + and loads in the system. + + There are three ways that the equations of motion can be described for + Symbolic System: + + + [1] Explicit form where the kinematics and dynamics are combined + x' = F_1(x, t, r, p) + + [2] Implicit form where the kinematics and dynamics are combined + M_2(x, p) x' = F_2(x, t, r, p) + + [3] Implicit form where the kinematics and dynamics are separate + M_3(q, p) u' = F_3(q, u, t, r, p) + q' = G(q, u, t, r, p) + + where + + x : states, e.g. [q, u] + t : time + r : specified (exogenous) inputs + p : constants + q : generalized coordinates + u : generalized speeds + F_1 : right hand side of the combined equations in explicit form + F_2 : right hand side of the combined equations in implicit form + F_3 : right hand side of the dynamical equations in implicit form + M_2 : mass matrix of the combined equations in implicit form + M_3 : mass matrix of the dynamical equations in implicit form + G : right hand side of the kinematical differential equations + + Parameters + ========== + + coord_states : ordered iterable of functions of time + This input will either be a collection of the coordinates or states + of the system depending on whether or not the speeds are also + given. If speeds are specified this input will be assumed to + be the coordinates otherwise this input will be assumed to + be the states. + + right_hand_side : Matrix + This variable is the right hand side of the equations of motion in + any of the forms. The specific form will be assumed depending on + whether a mass matrix or coordinate derivatives are given. + + speeds : ordered iterable of functions of time, optional + This is a collection of the generalized speeds of the system. If + given it will be assumed that the first argument (coord_states) + will represent the generalized coordinates of the system. + + mass_matrix : Matrix, optional + The matrix of the implicit forms of the equations of motion (forms + [2] and [3]). The distinction between the forms is determined by + whether or not the coordinate derivatives are passed in. If + they are given form [3] will be assumed otherwise form [2] is + assumed. + + coordinate_derivatives : Matrix, optional + The right hand side of the kinematical equations in explicit form. + If given it will be assumed that the equations of motion are being + entered in form [3]. + + alg_con : Iterable, optional + The indexes of the rows in the equations of motion that contain + algebraic constraints instead of differential equations. If the + equations are input in form [3], it will be assumed the indexes are + referencing the mass_matrix/right_hand_side combination and not the + coordinate_derivatives. + + output_eqns : Dictionary, optional + Any output equations that are desired to be tracked are stored in a + dictionary where the key corresponds to the name given for the + specific equation and the value is the equation itself in symbolic + form + + coord_idxs : Iterable, optional + If coord_states corresponds to the states rather than the + coordinates this variable will tell SymbolicSystem which indexes of + the states correspond to generalized coordinates. + + speed_idxs : Iterable, optional + If coord_states corresponds to the states rather than the + coordinates this variable will tell SymbolicSystem which indexes of + the states correspond to generalized speeds. + + bodies : iterable of Body/Rigidbody objects, optional + Iterable containing the bodies of the system + + loads : iterable of load instances (described below), optional + Iterable containing the loads of the system where forces are given + by (point of application, force vector) and torques are given by + (reference frame acting upon, torque vector). Ex [(point, force), + (ref_frame, torque)] + + Attributes + ========== + + coordinates : Matrix, shape(n, 1) + This is a matrix containing the generalized coordinates of the system + + speeds : Matrix, shape(m, 1) + This is a matrix containing the generalized speeds of the system + + states : Matrix, shape(o, 1) + This is a matrix containing the state variables of the system + + alg_con : List + This list contains the indices of the algebraic constraints in the + combined equations of motion. The presence of these constraints + requires that a DAE solver be used instead of an ODE solver. + If the system is given in form [3] the alg_con variable will be + adjusted such that it is a representation of the combined kinematics + and dynamics thus make sure it always matches the mass matrix + entered. + + dyn_implicit_mat : Matrix, shape(m, m) + This is the M matrix in form [3] of the equations of motion (the mass + matrix or generalized inertia matrix of the dynamical equations of + motion in implicit form). + + dyn_implicit_rhs : Matrix, shape(m, 1) + This is the F vector in form [3] of the equations of motion (the right + hand side of the dynamical equations of motion in implicit form). + + comb_implicit_mat : Matrix, shape(o, o) + This is the M matrix in form [2] of the equations of motion. + This matrix contains a block diagonal structure where the top + left block (the first rows) represent the matrix in the + implicit form of the kinematical equations and the bottom right + block (the last rows) represent the matrix in the implicit form + of the dynamical equations. + + comb_implicit_rhs : Matrix, shape(o, 1) + This is the F vector in form [2] of the equations of motion. The top + part of the vector represents the right hand side of the implicit form + of the kinemaical equations and the bottom of the vector represents the + right hand side of the implicit form of the dynamical equations of + motion. + + comb_explicit_rhs : Matrix, shape(o, 1) + This vector represents the right hand side of the combined equations of + motion in explicit form (form [1] from above). + + kin_explicit_rhs : Matrix, shape(m, 1) + This is the right hand side of the explicit form of the kinematical + equations of motion as can be seen in form [3] (the G matrix). + + output_eqns : Dictionary + If output equations were given they are stored in a dictionary where + the key corresponds to the name given for the specific equation and + the value is the equation itself in symbolic form + + bodies : Tuple + If the bodies in the system were given they are stored in a tuple for + future access + + loads : Tuple + If the loads in the system were given they are stored in a tuple for + future access. This includes forces and torques where forces are given + by (point of application, force vector) and torques are given by + (reference frame acted upon, torque vector). + + Example + ======= + + As a simple example, the dynamics of a simple pendulum will be input into a + SymbolicSystem object manually. First some imports will be needed and then + symbols will be set up for the length of the pendulum (l), mass at the end + of the pendulum (m), and a constant for gravity (g). :: + + >>> from sympy import Matrix, sin, symbols + >>> from sympy.physics.mechanics import dynamicsymbols, SymbolicSystem + >>> l, m, g = symbols('l m g') + + The system will be defined by an angle of theta from the vertical and a + generalized speed of omega will be used where omega = theta_dot. :: + + >>> theta, omega = dynamicsymbols('theta omega') + + Now the equations of motion are ready to be formed and passed to the + SymbolicSystem object. :: + + >>> kin_explicit_rhs = Matrix([omega]) + >>> dyn_implicit_mat = Matrix([l**2 * m]) + >>> dyn_implicit_rhs = Matrix([-g * l * m * sin(theta)]) + >>> symsystem = SymbolicSystem([theta], dyn_implicit_rhs, [omega], + ... dyn_implicit_mat) + + Notes + ===== + + m : number of generalized speeds + n : number of generalized coordinates + o : number of states + + """ + + def __init__(self, coord_states, right_hand_side, speeds=None, + mass_matrix=None, coordinate_derivatives=None, alg_con=None, + output_eqns={}, coord_idxs=None, speed_idxs=None, bodies=None, + loads=None): + """Initializes a SymbolicSystem object""" + + # Extract information on speeds, coordinates and states + if speeds is None: + self._states = Matrix(coord_states) + + if coord_idxs is None: + self._coordinates = None + else: + coords = [coord_states[i] for i in coord_idxs] + self._coordinates = Matrix(coords) + + if speed_idxs is None: + self._speeds = None + else: + speeds_inter = [coord_states[i] for i in speed_idxs] + self._speeds = Matrix(speeds_inter) + else: + self._coordinates = Matrix(coord_states) + self._speeds = Matrix(speeds) + self._states = self._coordinates.col_join(self._speeds) + + # Extract equations of motion form + if coordinate_derivatives is not None: + self._kin_explicit_rhs = coordinate_derivatives + self._dyn_implicit_rhs = right_hand_side + self._dyn_implicit_mat = mass_matrix + self._comb_implicit_rhs = None + self._comb_implicit_mat = None + self._comb_explicit_rhs = None + elif mass_matrix is not None: + self._kin_explicit_rhs = None + self._dyn_implicit_rhs = None + self._dyn_implicit_mat = None + self._comb_implicit_rhs = right_hand_side + self._comb_implicit_mat = mass_matrix + self._comb_explicit_rhs = None + else: + self._kin_explicit_rhs = None + self._dyn_implicit_rhs = None + self._dyn_implicit_mat = None + self._comb_implicit_rhs = None + self._comb_implicit_mat = None + self._comb_explicit_rhs = right_hand_side + + # Set the remainder of the inputs as instance attributes + if alg_con is not None and coordinate_derivatives is not None: + alg_con = [i + len(coordinate_derivatives) for i in alg_con] + self._alg_con = alg_con + self.output_eqns = output_eqns + + # Change the body and loads iterables to tuples if they are not tuples + # already + if not isinstance(bodies, tuple) and bodies is not None: + bodies = tuple(bodies) + if not isinstance(loads, tuple) and loads is not None: + loads = tuple(loads) + self._bodies = bodies + self._loads = loads + + @property + def coordinates(self): + """Returns the column matrix of the generalized coordinates""" + if self._coordinates is None: + raise AttributeError("The coordinates were not specified.") + else: + return self._coordinates + + @property + def speeds(self): + """Returns the column matrix of generalized speeds""" + if self._speeds is None: + raise AttributeError("The speeds were not specified.") + else: + return self._speeds + + @property + def states(self): + """Returns the column matrix of the state variables""" + return self._states + + @property + def alg_con(self): + """Returns a list with the indices of the rows containing algebraic + constraints in the combined form of the equations of motion""" + return self._alg_con + + @property + def dyn_implicit_mat(self): + """Returns the matrix, M, corresponding to the dynamic equations in + implicit form, M x' = F, where the kinematical equations are not + included""" + if self._dyn_implicit_mat is None: + raise AttributeError("dyn_implicit_mat is not specified for " + "equations of motion form [1] or [2].") + else: + return self._dyn_implicit_mat + + @property + def dyn_implicit_rhs(self): + """Returns the column matrix, F, corresponding to the dynamic equations + in implicit form, M x' = F, where the kinematical equations are not + included""" + if self._dyn_implicit_rhs is None: + raise AttributeError("dyn_implicit_rhs is not specified for " + "equations of motion form [1] or [2].") + else: + return self._dyn_implicit_rhs + + @property + def comb_implicit_mat(self): + """Returns the matrix, M, corresponding to the equations of motion in + implicit form (form [2]), M x' = F, where the kinematical equations are + included""" + if self._comb_implicit_mat is None: + if self._dyn_implicit_mat is not None: + num_kin_eqns = len(self._kin_explicit_rhs) + num_dyn_eqns = len(self._dyn_implicit_rhs) + zeros1 = zeros(num_kin_eqns, num_dyn_eqns) + zeros2 = zeros(num_dyn_eqns, num_kin_eqns) + inter1 = eye(num_kin_eqns).row_join(zeros1) + inter2 = zeros2.row_join(self._dyn_implicit_mat) + self._comb_implicit_mat = inter1.col_join(inter2) + return self._comb_implicit_mat + else: + raise AttributeError("comb_implicit_mat is not specified for " + "equations of motion form [1].") + else: + return self._comb_implicit_mat + + @property + def comb_implicit_rhs(self): + """Returns the column matrix, F, corresponding to the equations of + motion in implicit form (form [2]), M x' = F, where the kinematical + equations are included""" + if self._comb_implicit_rhs is None: + if self._dyn_implicit_rhs is not None: + kin_inter = self._kin_explicit_rhs + dyn_inter = self._dyn_implicit_rhs + self._comb_implicit_rhs = kin_inter.col_join(dyn_inter) + return self._comb_implicit_rhs + else: + raise AttributeError("comb_implicit_mat is not specified for " + "equations of motion in form [1].") + else: + return self._comb_implicit_rhs + + def compute_explicit_form(self): + """If the explicit right hand side of the combined equations of motion + is to provided upon initialization, this method will calculate it. This + calculation can potentially take awhile to compute.""" + if self._comb_explicit_rhs is not None: + raise AttributeError("comb_explicit_rhs is already formed.") + + inter1 = getattr(self, 'kin_explicit_rhs', None) + if inter1 is not None: + inter2 = self._dyn_implicit_mat.LUsolve(self._dyn_implicit_rhs) + out = inter1.col_join(inter2) + else: + out = self._comb_implicit_mat.LUsolve(self._comb_implicit_rhs) + + self._comb_explicit_rhs = out + + @property + def comb_explicit_rhs(self): + """Returns the right hand side of the equations of motion in explicit + form, x' = F, where the kinematical equations are included""" + if self._comb_explicit_rhs is None: + raise AttributeError("Please run .combute_explicit_form before " + "attempting to access comb_explicit_rhs.") + else: + return self._comb_explicit_rhs + + @property + def kin_explicit_rhs(self): + """Returns the right hand side of the kinematical equations in explicit + form, q' = G""" + if self._kin_explicit_rhs is None: + raise AttributeError("kin_explicit_rhs is not specified for " + "equations of motion form [1] or [2].") + else: + return self._kin_explicit_rhs + + def dynamic_symbols(self): + """Returns a column matrix containing all of the symbols in the system + that depend on time""" + # Create a list of all of the expressions in the equations of motion + if self._comb_explicit_rhs is None: + eom_expressions = (self.comb_implicit_mat[:] + + self.comb_implicit_rhs[:]) + else: + eom_expressions = (self._comb_explicit_rhs[:]) + + functions_of_time = set() + for expr in eom_expressions: + functions_of_time = functions_of_time.union( + find_dynamicsymbols(expr)) + functions_of_time = functions_of_time.union(self._states) + + return tuple(functions_of_time) + + def constant_symbols(self): + """Returns a column matrix containing all of the symbols in the system + that do not depend on time""" + # Create a list of all of the expressions in the equations of motion + if self._comb_explicit_rhs is None: + eom_expressions = (self.comb_implicit_mat[:] + + self.comb_implicit_rhs[:]) + else: + eom_expressions = (self._comb_explicit_rhs[:]) + + constants = set() + for expr in eom_expressions: + constants = constants.union(expr.free_symbols) + constants.remove(dynamicsymbols._t) + + return tuple(constants) + + @property + def bodies(self): + """Returns the bodies in the system""" + if self._bodies is None: + raise AttributeError("bodies were not specified for the system.") + else: + return self._bodies + + @property + def loads(self): + """Returns the loads in the system""" + if self._loads is None: + raise AttributeError("loads were not specified for the system.") + else: + return self._loads diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/wrapping_geometry.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/wrapping_geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..47ed3c1c463499b024afb9e31cfa2ecd77534132 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/wrapping_geometry.py @@ -0,0 +1,641 @@ +"""Geometry objects for use by wrapping pathways.""" + +from abc import ABC, abstractmethod + +from sympy import Integer, acos, pi, sqrt, sympify, tan +from sympy.core.relational import Eq +from sympy.functions.elementary.trigonometric import atan2 +from sympy.polys.polytools import cancel +from sympy.physics.vector import Vector, dot +from sympy.simplify.simplify import trigsimp + + +__all__ = [ + 'WrappingGeometryBase', + 'WrappingCylinder', + 'WrappingSphere', +] + + +class WrappingGeometryBase(ABC): + """Abstract base class for all geometry classes to inherit from. + + Notes + ===== + + Instances of this class cannot be directly instantiated by users. However, + it can be used to created custom geometry types through subclassing. + + """ + + @property + @abstractmethod + def point(cls): + """The point with which the geometry is associated.""" + pass + + @abstractmethod + def point_on_surface(self, point): + """Returns ``True`` if a point is on the geometry's surface. + + Parameters + ========== + point : Point + The point for which it's to be ascertained if it's on the + geometry's surface or not. + + """ + pass + + @abstractmethod + def geodesic_length(self, point_1, point_2): + """Returns the shortest distance between two points on a geometry's + surface. + + Parameters + ========== + + point_1 : Point + The point from which the geodesic length should be calculated. + point_2 : Point + The point to which the geodesic length should be calculated. + + """ + pass + + @abstractmethod + def geodesic_end_vectors(self, point_1, point_2): + """The vectors parallel to the geodesic at the two end points. + + Parameters + ========== + + point_1 : Point + The point from which the geodesic originates. + point_2 : Point + The point at which the geodesic terminates. + + """ + pass + + def __repr__(self): + """Default representation of a geometry object.""" + return f'{self.__class__.__name__}()' + + +class WrappingSphere(WrappingGeometryBase): + """A solid spherical object. + + Explanation + =========== + + A wrapping geometry that allows for circular arcs to be defined between + pairs of points. These paths are always geodetic (the shortest possible). + + Examples + ======== + + To create a ``WrappingSphere`` instance, a ``Symbol`` denoting its radius + and ``Point`` at which its center will be located are needed: + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import Point, WrappingSphere + >>> r = symbols('r') + >>> pO = Point('pO') + + A sphere with radius ``r`` centered on ``pO`` can be instantiated with: + + >>> WrappingSphere(r, pO) + WrappingSphere(radius=r, point=pO) + + Parameters + ========== + + radius : Symbol + Radius of the sphere. This symbol must represent a value that is + positive and constant, i.e. it cannot be a dynamic symbol, nor can it + be an expression. + point : Point + A point at which the sphere is centered. + + See Also + ======== + + WrappingCylinder: Cylindrical geometry where the wrapping direction can be + defined. + + """ + + def __init__(self, radius, point): + """Initializer for ``WrappingSphere``. + + Parameters + ========== + + radius : Symbol + The radius of the sphere. + point : Point + A point on which the sphere is centered. + + """ + self.radius = radius + self.point = point + + @property + def radius(self): + """Radius of the sphere.""" + return self._radius + + @radius.setter + def radius(self, radius): + self._radius = radius + + @property + def point(self): + """A point on which the sphere is centered.""" + return self._point + + @point.setter + def point(self, point): + self._point = point + + def point_on_surface(self, point): + """Returns ``True`` if a point is on the sphere's surface. + + Parameters + ========== + + point : Point + The point for which it's to be ascertained if it's on the sphere's + surface or not. This point's position relative to the sphere's + center must be a simple expression involving the radius of the + sphere, otherwise this check will likely not work. + + """ + point_vector = point.pos_from(self.point) + if isinstance(point_vector, Vector): + point_radius_squared = dot(point_vector, point_vector) + else: + point_radius_squared = point_vector**2 + return Eq(point_radius_squared, self.radius**2) == True + + def geodesic_length(self, point_1, point_2): + r"""Returns the shortest distance between two points on the sphere's + surface. + + Explanation + =========== + + The geodesic length, i.e. the shortest arc along the surface of a + sphere, connecting two points can be calculated using the formula: + + .. math:: + + l = \arccos\left(\mathbf{v}_1 \cdot \mathbf{v}_2\right) + + where $\mathbf{v}_1$ and $\mathbf{v}_2$ are the unit vectors from the + sphere's center to the first and second points on the sphere's surface + respectively. Note that the actual path that the geodesic will take is + undefined when the two points are directly opposite one another. + + Examples + ======== + + A geodesic length can only be calculated between two points on the + sphere's surface. Firstly, a ``WrappingSphere`` instance must be + created along with two points that will lie on its surface: + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (Point, ReferenceFrame, + ... WrappingSphere) + >>> N = ReferenceFrame('N') + >>> r = symbols('r') + >>> pO = Point('pO') + >>> pO.set_vel(N, 0) + >>> sphere = WrappingSphere(r, pO) + >>> p1 = Point('p1') + >>> p2 = Point('p2') + + Let's assume that ``p1`` lies at a distance of ``r`` in the ``N.x`` + direction from ``pO`` and that ``p2`` is located on the sphere's + surface in the ``N.y + N.z`` direction from ``pO``. These positions can + be set with: + + >>> p1.set_pos(pO, r*N.x) + >>> p1.pos_from(pO) + r*N.x + >>> p2.set_pos(pO, r*(N.y + N.z).normalize()) + >>> p2.pos_from(pO) + sqrt(2)*r/2*N.y + sqrt(2)*r/2*N.z + + The geodesic length, which is in this case is a quarter of the sphere's + circumference, can be calculated using the ``geodesic_length`` method: + + >>> sphere.geodesic_length(p1, p2) + pi*r/2 + + If the ``geodesic_length`` method is passed an argument, the ``Point`` + that doesn't lie on the sphere's surface then a ``ValueError`` is + raised because it's not possible to calculate a value in this case. + + Parameters + ========== + + point_1 : Point + Point from which the geodesic length should be calculated. + point_2 : Point + Point to which the geodesic length should be calculated. + + """ + for point in (point_1, point_2): + if not self.point_on_surface(point): + msg = ( + f'Geodesic length cannot be calculated as point {point} ' + f'with radius {point.pos_from(self.point).magnitude()} ' + f'from the sphere\'s center {self.point} does not lie on ' + f'the surface of {self} with radius {self.radius}.' + ) + raise ValueError(msg) + point_1_vector = point_1.pos_from(self.point).normalize() + point_2_vector = point_2.pos_from(self.point).normalize() + central_angle = acos(point_2_vector.dot(point_1_vector)) + geodesic_length = self.radius*central_angle + return geodesic_length + + def geodesic_end_vectors(self, point_1, point_2): + """The vectors parallel to the geodesic at the two end points. + + Parameters + ========== + + point_1 : Point + The point from which the geodesic originates. + point_2 : Point + The point at which the geodesic terminates. + + """ + pA, pB = point_1, point_2 + pO = self.point + pA_vec = pA.pos_from(pO) + pB_vec = pB.pos_from(pO) + + if pA_vec.cross(pB_vec) == 0: + msg = ( + f'Can\'t compute geodesic end vectors for the pair of points ' + f'{pA} and {pB} on a sphere {self} as they are diametrically ' + f'opposed, thus the geodesic is not defined.' + ) + raise ValueError(msg) + + return ( + pA_vec.cross(pB.pos_from(pA)).cross(pA_vec).normalize(), + pB_vec.cross(pA.pos_from(pB)).cross(pB_vec).normalize(), + ) + + def __repr__(self): + """Representation of a ``WrappingSphere``.""" + return ( + f'{self.__class__.__name__}(radius={self.radius}, ' + f'point={self.point})' + ) + + +class WrappingCylinder(WrappingGeometryBase): + """A solid (infinite) cylindrical object. + + Explanation + =========== + + A wrapping geometry that allows for circular arcs to be defined between + pairs of points. These paths are always geodetic (the shortest possible) in + the sense that they will be a straight line on the unwrapped cylinder's + surface. However, it is also possible for a direction to be specified, i.e. + paths can be influenced such that they either wrap along the shortest side + or the longest side of the cylinder. To define these directions, rotations + are in the positive direction following the right-hand rule. + + Examples + ======== + + To create a ``WrappingCylinder`` instance, a ``Symbol`` denoting its + radius, a ``Vector`` defining its axis, and a ``Point`` through which its + axis passes are needed: + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (Point, ReferenceFrame, + ... WrappingCylinder) + >>> N = ReferenceFrame('N') + >>> r = symbols('r') + >>> pO = Point('pO') + >>> ax = N.x + + A cylinder with radius ``r``, and axis parallel to ``N.x`` passing through + ``pO`` can be instantiated with: + + >>> WrappingCylinder(r, pO, ax) + WrappingCylinder(radius=r, point=pO, axis=N.x) + + Parameters + ========== + + radius : Symbol + The radius of the cylinder. + point : Point + A point through which the cylinder's axis passes. + axis : Vector + The axis along which the cylinder is aligned. + + See Also + ======== + + WrappingSphere: Spherical geometry where the wrapping direction is always + geodetic. + + """ + + def __init__(self, radius, point, axis): + """Initializer for ``WrappingCylinder``. + + Parameters + ========== + + radius : Symbol + The radius of the cylinder. This symbol must represent a value that + is positive and constant, i.e. it cannot be a dynamic symbol. + point : Point + A point through which the cylinder's axis passes. + axis : Vector + The axis along which the cylinder is aligned. + + """ + self.radius = radius + self.point = point + self.axis = axis + + @property + def radius(self): + """Radius of the cylinder.""" + return self._radius + + @radius.setter + def radius(self, radius): + self._radius = radius + + @property + def point(self): + """A point through which the cylinder's axis passes.""" + return self._point + + @point.setter + def point(self, point): + self._point = point + + @property + def axis(self): + """Axis along which the cylinder is aligned.""" + return self._axis + + @axis.setter + def axis(self, axis): + self._axis = axis.normalize() + + def point_on_surface(self, point): + """Returns ``True`` if a point is on the cylinder's surface. + + Parameters + ========== + + point : Point + The point for which it's to be ascertained if it's on the + cylinder's surface or not. This point's position relative to the + cylinder's axis must be a simple expression involving the radius of + the sphere, otherwise this check will likely not work. + + """ + relative_position = point.pos_from(self.point) + parallel = relative_position.dot(self.axis) * self.axis + point_vector = relative_position - parallel + if isinstance(point_vector, Vector): + point_radius_squared = dot(point_vector, point_vector) + else: + point_radius_squared = point_vector**2 + return Eq(trigsimp(point_radius_squared), self.radius**2) == True + + def geodesic_length(self, point_1, point_2): + """The shortest distance between two points on a geometry's surface. + + Explanation + =========== + + The geodesic length, i.e. the shortest arc along the surface of a + cylinder, connecting two points. It can be calculated using Pythagoras' + theorem. The first short side is the distance between the two points on + the cylinder's surface parallel to the cylinder's axis. The second + short side is the arc of a circle between the two points of the + cylinder's surface perpendicular to the cylinder's axis. The resulting + hypotenuse is the geodesic length. + + Examples + ======== + + A geodesic length can only be calculated between two points on the + cylinder's surface. Firstly, a ``WrappingCylinder`` instance must be + created along with two points that will lie on its surface: + + >>> from sympy import symbols, cos, sin + >>> from sympy.physics.mechanics import (Point, ReferenceFrame, + ... WrappingCylinder, dynamicsymbols) + >>> N = ReferenceFrame('N') + >>> r = symbols('r') + >>> pO = Point('pO') + >>> pO.set_vel(N, 0) + >>> cylinder = WrappingCylinder(r, pO, N.x) + >>> p1 = Point('p1') + >>> p2 = Point('p2') + + Let's assume that ``p1`` is located at ``N.x + r*N.y`` relative to + ``pO`` and that ``p2`` is located at ``r*(cos(q)*N.y + sin(q)*N.z)`` + relative to ``pO``, where ``q(t)`` is a generalized coordinate + specifying the angle rotated around the ``N.x`` axis according to the + right-hand rule where ``N.y`` is zero. These positions can be set with: + + >>> q = dynamicsymbols('q') + >>> p1.set_pos(pO, N.x + r*N.y) + >>> p1.pos_from(pO) + N.x + r*N.y + >>> p2.set_pos(pO, r*(cos(q)*N.y + sin(q)*N.z).normalize()) + >>> p2.pos_from(pO).simplify() + r*cos(q(t))*N.y + r*sin(q(t))*N.z + + The geodesic length, which is in this case a is the hypotenuse of a + right triangle where the other two side lengths are ``1`` (parallel to + the cylinder's axis) and ``r*q(t)`` (parallel to the cylinder's cross + section), can be calculated using the ``geodesic_length`` method: + + >>> cylinder.geodesic_length(p1, p2).simplify() + sqrt(r**2*q(t)**2 + 1) + + If the ``geodesic_length`` method is passed an argument ``Point`` that + doesn't lie on the sphere's surface then a ``ValueError`` is raised + because it's not possible to calculate a value in this case. + + Parameters + ========== + + point_1 : Point + Point from which the geodesic length should be calculated. + point_2 : Point + Point to which the geodesic length should be calculated. + + """ + for point in (point_1, point_2): + if not self.point_on_surface(point): + msg = ( + f'Geodesic length cannot be calculated as point {point} ' + f'with radius {point.pos_from(self.point).magnitude()} ' + f'from the cylinder\'s center {self.point} does not lie on ' + f'the surface of {self} with radius {self.radius} and axis ' + f'{self.axis}.' + ) + raise ValueError(msg) + + relative_position = point_2.pos_from(point_1) + parallel_length = relative_position.dot(self.axis) + + point_1_relative_position = point_1.pos_from(self.point) + point_1_perpendicular_vector = ( + point_1_relative_position + - point_1_relative_position.dot(self.axis)*self.axis + ).normalize() + + point_2_relative_position = point_2.pos_from(self.point) + point_2_perpendicular_vector = ( + point_2_relative_position + - point_2_relative_position.dot(self.axis)*self.axis + ).normalize() + + central_angle = _directional_atan( + cancel(point_1_perpendicular_vector + .cross(point_2_perpendicular_vector) + .dot(self.axis)), + cancel(point_1_perpendicular_vector.dot(point_2_perpendicular_vector)), + ) + + planar_arc_length = self.radius*central_angle + geodesic_length = sqrt(parallel_length**2 + planar_arc_length**2) + return geodesic_length + + def geodesic_end_vectors(self, point_1, point_2): + """The vectors parallel to the geodesic at the two end points. + + Parameters + ========== + + point_1 : Point + The point from which the geodesic originates. + point_2 : Point + The point at which the geodesic terminates. + + """ + point_1_from_origin_point = point_1.pos_from(self.point) + point_2_from_origin_point = point_2.pos_from(self.point) + + if point_1_from_origin_point == point_2_from_origin_point: + msg = ( + f'Cannot compute geodesic end vectors for coincident points ' + f'{point_1} and {point_2} as no geodesic exists.' + ) + raise ValueError(msg) + + point_1_parallel = point_1_from_origin_point.dot(self.axis) * self.axis + point_2_parallel = point_2_from_origin_point.dot(self.axis) * self.axis + point_1_normal = (point_1_from_origin_point - point_1_parallel) + point_2_normal = (point_2_from_origin_point - point_2_parallel) + + if point_1_normal == point_2_normal: + point_1_perpendicular = Vector(0) + point_2_perpendicular = Vector(0) + else: + point_1_perpendicular = self.axis.cross(point_1_normal).normalize() + point_2_perpendicular = -self.axis.cross(point_2_normal).normalize() + + geodesic_length = self.geodesic_length(point_1, point_2) + relative_position = point_2.pos_from(point_1) + parallel_length = relative_position.dot(self.axis) + planar_arc_length = sqrt(geodesic_length**2 - parallel_length**2) + + point_1_vector = ( + planar_arc_length * point_1_perpendicular + + parallel_length * self.axis + ).normalize() + point_2_vector = ( + planar_arc_length * point_2_perpendicular + - parallel_length * self.axis + ).normalize() + + return (point_1_vector, point_2_vector) + + def __repr__(self): + """Representation of a ``WrappingCylinder``.""" + return ( + f'{self.__class__.__name__}(radius={self.radius}, ' + f'point={self.point}, axis={self.axis})' + ) + + +def _directional_atan(numerator, denominator): + """Compute atan in a directional sense as required for geodesics. + + Explanation + =========== + + To be able to control the direction of the geodesic length along the + surface of a cylinder a dedicated arctangent function is needed that + properly handles the directionality of different case. This function + ensures that the central angle is always positive but shifting the case + where ``atan2`` would return a negative angle to be centered around + ``2*pi``. + + Notes + ===== + + This function only handles very specific cases, i.e. the ones that are + expected to be encountered when calculating symbolic geodesics on uniformly + curved surfaces. As such, ``NotImplemented`` errors can be raised in many + cases. This function is named with a leader underscore to indicate that it + only aims to provide very specific functionality within the private scope + of this module. + + """ + + if numerator.is_number and denominator.is_number: + angle = atan2(numerator, denominator) + if angle < 0: + angle += 2 * pi + elif numerator.is_number: + msg = ( + f'Cannot compute a directional atan when the numerator {numerator} ' + f'is numeric and the denominator {denominator} is symbolic.' + ) + raise NotImplementedError(msg) + elif denominator.is_number: + msg = ( + f'Cannot compute a directional atan when the numerator {numerator} ' + f'is symbolic and the denominator {denominator} is numeric.' + ) + raise NotImplementedError(msg) + else: + ratio = sympify(trigsimp(numerator / denominator)) + if isinstance(ratio, tan): + angle = ratio.args[0] + elif ( + ratio.is_Mul + and ratio.args[0] == Integer(-1) + and isinstance(ratio.args[1], tan) + ): + angle = 2 * pi - ratio.args[1].args[0] + else: + msg = f'Cannot compute a directional atan for the value {ratio}.' + raise NotImplementedError(msg) + + return angle diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/paulialgebra.py b/.venv/lib/python3.13/site-packages/sympy/physics/paulialgebra.py new file mode 100644 index 0000000000000000000000000000000000000000..300957354ff34907035aa1d1a48b00276230a1e5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/paulialgebra.py @@ -0,0 +1,231 @@ +""" +This module implements Pauli algebra by subclassing Symbol. Only algebraic +properties of Pauli matrices are used (we do not use the Matrix class). + +See the documentation to the class Pauli for examples. + +References +========== + +.. [1] https://en.wikipedia.org/wiki/Pauli_matrices +""" + +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.numbers import I +from sympy.core.power import Pow +from sympy.core.symbol import Symbol +from sympy.physics.quantum import TensorProduct + +__all__ = ['evaluate_pauli_product'] + + +def delta(i, j): + """ + Returns 1 if ``i == j``, else 0. + + This is used in the multiplication of Pauli matrices. + + Examples + ======== + + >>> from sympy.physics.paulialgebra import delta + >>> delta(1, 1) + 1 + >>> delta(2, 3) + 0 + """ + if i == j: + return 1 + else: + return 0 + + +def epsilon(i, j, k): + """ + Return 1 if i,j,k is equal to (1,2,3), (2,3,1), or (3,1,2); + -1 if ``i``,``j``,``k`` is equal to (1,3,2), (3,2,1), or (2,1,3); + else return 0. + + This is used in the multiplication of Pauli matrices. + + Examples + ======== + + >>> from sympy.physics.paulialgebra import epsilon + >>> epsilon(1, 2, 3) + 1 + >>> epsilon(1, 3, 2) + -1 + """ + if (i, j, k) in ((1, 2, 3), (2, 3, 1), (3, 1, 2)): + return 1 + elif (i, j, k) in ((1, 3, 2), (3, 2, 1), (2, 1, 3)): + return -1 + else: + return 0 + + +class Pauli(Symbol): + """ + The class representing algebraic properties of Pauli matrices. + + Explanation + =========== + + The symbol used to display the Pauli matrices can be changed with an + optional parameter ``label="sigma"``. Pauli matrices with different + ``label`` attributes cannot multiply together. + + If the left multiplication of symbol or number with Pauli matrix is needed, + please use parentheses to separate Pauli and symbolic multiplication + (for example: 2*I*(Pauli(3)*Pauli(2))). + + Another variant is to use evaluate_pauli_product function to evaluate + the product of Pauli matrices and other symbols (with commutative + multiply rules). + + See Also + ======== + + evaluate_pauli_product + + Examples + ======== + + >>> from sympy.physics.paulialgebra import Pauli + >>> Pauli(1) + sigma1 + >>> Pauli(1)*Pauli(2) + I*sigma3 + >>> Pauli(1)*Pauli(1) + 1 + >>> Pauli(3)**4 + 1 + >>> Pauli(1)*Pauli(2)*Pauli(3) + I + + >>> from sympy.physics.paulialgebra import Pauli + >>> Pauli(1, label="tau") + tau1 + >>> Pauli(1)*Pauli(2, label="tau") + sigma1*tau2 + >>> Pauli(1, label="tau")*Pauli(2, label="tau") + I*tau3 + + >>> from sympy import I + >>> I*(Pauli(2)*Pauli(3)) + -sigma1 + + >>> from sympy.physics.paulialgebra import evaluate_pauli_product + >>> f = I*Pauli(2)*Pauli(3) + >>> f + I*sigma2*sigma3 + >>> evaluate_pauli_product(f) + -sigma1 + """ + + __slots__ = ("i", "label") + + def __new__(cls, i, label="sigma"): + if i not in [1, 2, 3]: + raise IndexError("Invalid Pauli index") + obj = Symbol.__new__(cls, "%s%d" %(label,i), commutative=False, hermitian=True) + obj.i = i + obj.label = label + return obj + + def __getnewargs_ex__(self): + return (self.i, self.label), {} + + def _hashable_content(self): + return (self.i, self.label) + + # FIXME don't work for -I*Pauli(2)*Pauli(3) + def __mul__(self, other): + if isinstance(other, Pauli): + j = self.i + k = other.i + jlab = self.label + klab = other.label + + if jlab == klab: + return delta(j, k) \ + + I*epsilon(j, k, 1)*Pauli(1,jlab) \ + + I*epsilon(j, k, 2)*Pauli(2,jlab) \ + + I*epsilon(j, k, 3)*Pauli(3,jlab) + return super().__mul__(other) + + def _eval_power(b, e): + if e.is_Integer and e.is_positive: + return super().__pow__(int(e) % 2) + + +def evaluate_pauli_product(arg): + '''Help function to evaluate Pauli matrices product + with symbolic objects. + + Parameters + ========== + + arg: symbolic expression that contains Paulimatrices + + Examples + ======== + + >>> from sympy.physics.paulialgebra import Pauli, evaluate_pauli_product + >>> from sympy import I + >>> evaluate_pauli_product(I*Pauli(1)*Pauli(2)) + -sigma3 + + >>> from sympy.abc import x + >>> evaluate_pauli_product(x**2*Pauli(2)*Pauli(1)) + -I*x**2*sigma3 + ''' + start = arg + end = arg + + if isinstance(arg, Pow) and isinstance(arg.args[0], Pauli): + if arg.args[1].is_odd: + return arg.args[0] + else: + return 1 + + if isinstance(arg, Add): + return Add(*[evaluate_pauli_product(part) for part in arg.args]) + + if isinstance(arg, TensorProduct): + return TensorProduct(*[evaluate_pauli_product(part) for part in arg.args]) + + elif not(isinstance(arg, Mul)): + return arg + + while not start == end or start == arg and end == arg: + start = end + + tmp = start.as_coeff_mul() + sigma_product = 1 + com_product = 1 + keeper = 1 + + for el in tmp[1]: + if isinstance(el, Pauli): + sigma_product *= el + elif not el.is_commutative: + if isinstance(el, Pow) and isinstance(el.args[0], Pauli): + if el.args[1].is_odd: + sigma_product *= el.args[0] + elif isinstance(el, TensorProduct): + keeper = keeper*sigma_product*\ + TensorProduct( + *[evaluate_pauli_product(part) for part in el.args] + ) + sigma_product = 1 + else: + keeper = keeper*sigma_product*el + sigma_product = 1 + else: + com_product *= el + end = tmp[0]*keeper*sigma_product*com_product + if end == arg: break + return end diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/pring.py b/.venv/lib/python3.13/site-packages/sympy/physics/pring.py new file mode 100644 index 0000000000000000000000000000000000000000..325f4ff98a8c9fc428b4e332153af533f4d199ca --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/pring.py @@ -0,0 +1,94 @@ +from sympy.core.numbers import (I, pi) +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.quantum.constants import hbar + + +def wavefunction(n, x): + """ + Returns the wavefunction for particle on ring. + + Parameters + ========== + + n : The quantum number. + Here ``n`` can be positive as well as negative + which can be used to describe the direction of motion of particle. + x : + The angle. + + Examples + ======== + + >>> from sympy.physics.pring import wavefunction + >>> from sympy import Symbol, integrate, pi + >>> x=Symbol("x") + >>> wavefunction(1, x) + sqrt(2)*exp(I*x)/(2*sqrt(pi)) + >>> wavefunction(2, x) + sqrt(2)*exp(2*I*x)/(2*sqrt(pi)) + >>> wavefunction(3, x) + sqrt(2)*exp(3*I*x)/(2*sqrt(pi)) + + The normalization of the wavefunction is: + + >>> integrate(wavefunction(2, x)*wavefunction(-2, x), (x, 0, 2*pi)) + 1 + >>> integrate(wavefunction(4, x)*wavefunction(-4, x), (x, 0, 2*pi)) + 1 + + References + ========== + + .. [1] Atkins, Peter W.; Friedman, Ronald (2005). Molecular Quantum + Mechanics (4th ed.). Pages 71-73. + + """ + # sympify arguments + n, x = S(n), S(x) + return exp(n * I * x) / sqrt(2 * pi) + + +def energy(n, m, r): + """ + Returns the energy of the state corresponding to quantum number ``n``. + + E=(n**2 * (hcross)**2) / (2 * m * r**2) + + Parameters + ========== + + n : + The quantum number. + m : + Mass of the particle. + r : + Radius of circle. + + Examples + ======== + + >>> from sympy.physics.pring import energy + >>> from sympy import Symbol + >>> m=Symbol("m") + >>> r=Symbol("r") + >>> energy(1, m, r) + hbar**2/(2*m*r**2) + >>> energy(2, m, r) + 2*hbar**2/(m*r**2) + >>> energy(-2, 2.0, 3.0) + 0.111111111111111*hbar**2 + + References + ========== + + .. [1] Atkins, Peter W.; Friedman, Ronald (2005). Molecular Quantum + Mechanics (4th ed.). Pages 71-73. + + """ + n, m, r = S(n), S(m), S(r) + if n.is_integer: + return (n**2 * hbar**2) / (2 * m * r**2) + else: + raise ValueError("'n' must be integer") diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/qho_1d.py b/.venv/lib/python3.13/site-packages/sympy/physics/qho_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..f418e0e954656923fbfa64cea2145581ddf65aea --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/qho_1d.py @@ -0,0 +1,88 @@ +from sympy.core import S, pi, Rational +from sympy.functions import hermite, sqrt, exp, factorial, Abs +from sympy.physics.quantum.constants import hbar + + +def psi_n(n, x, m, omega): + """ + Returns the wavefunction psi_{n} for the One-dimensional harmonic oscillator. + + Parameters + ========== + + n : + the "nodal" quantum number. Corresponds to the number of nodes in the + wavefunction. ``n >= 0`` + x : + x coordinate. + m : + Mass of the particle. + omega : + Angular frequency of the oscillator. + + Examples + ======== + + >>> from sympy.physics.qho_1d import psi_n + >>> from sympy.abc import m, x, omega + >>> psi_n(0, x, m, omega) + (m*omega)**(1/4)*exp(-m*omega*x**2/(2*hbar))/(hbar**(1/4)*pi**(1/4)) + + """ + + # sympify arguments + n, x, m, omega = map(S, [n, x, m, omega]) + nu = m * omega / hbar + # normalization coefficient + C = (nu/pi)**Rational(1, 4) * sqrt(1/(2**n*factorial(n))) + + return C * exp(-nu* x**2 /2) * hermite(n, sqrt(nu)*x) + + +def E_n(n, omega): + """ + Returns the Energy of the One-dimensional harmonic oscillator. + + Parameters + ========== + + n : + The "nodal" quantum number. + omega : + The harmonic oscillator angular frequency. + + Notes + ===== + + The unit of the returned value matches the unit of hw, since the energy is + calculated as: + + E_n = hbar * omega*(n + 1/2) + + Examples + ======== + + >>> from sympy.physics.qho_1d import E_n + >>> from sympy.abc import x, omega + >>> E_n(x, omega) + hbar*omega*(x + 1/2) + """ + + return hbar * omega * (n + S.Half) + + +def coherent_state(n, alpha): + """ + Returns for the coherent states of 1D harmonic oscillator. + See https://en.wikipedia.org/wiki/Coherent_states + + Parameters + ========== + + n : + The "nodal" quantum number. + alpha : + The eigen value of annihilation operator. + """ + + return exp(- Abs(alpha)**2/2)*(alpha**n)/sqrt(factorial(n)) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/secondquant.py b/.venv/lib/python3.13/site-packages/sympy/physics/secondquant.py new file mode 100644 index 0000000000000000000000000000000000000000..189e8e8b50c785759b03f19f28285f7988cfca75 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/secondquant.py @@ -0,0 +1,3125 @@ +""" +Second quantization operators and states for bosons. + +This follow the formulation of Fetter and Welecka, "Quantum Theory +of Many-Particle Systems." +""" +from collections import defaultdict + +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.cache import cacheit +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.function import Function +from sympy.core.mul import Mul +from sympy.core.numbers import I +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import Dummy, Symbol +from sympy.core.sympify import sympify +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.matrices.dense import zeros +from sympy.printing.str import StrPrinter +from sympy.utilities.iterables import has_dups + +__all__ = [ + 'Dagger', + 'KroneckerDelta', + 'BosonicOperator', + 'AnnihilateBoson', + 'CreateBoson', + 'AnnihilateFermion', + 'CreateFermion', + 'FockState', + 'FockStateBra', + 'FockStateKet', + 'FockStateBosonKet', + 'FockStateBosonBra', + 'FockStateFermionKet', + 'FockStateFermionBra', + 'BBra', + 'BKet', + 'FBra', + 'FKet', + 'F', + 'Fd', + 'B', + 'Bd', + 'apply_operators', + 'InnerProduct', + 'BosonicBasis', + 'VarBosonicBasis', + 'FixedBosonicBasis', + 'Commutator', + 'matrix_rep', + 'contraction', + 'wicks', + 'NO', + 'evaluate_deltas', + 'AntiSymmetricTensor', + 'substitute_dummies', + 'PermutationOperator', + 'simplify_index_permutations', +] + + +class SecondQuantizationError(Exception): + pass + + +class AppliesOnlyToSymbolicIndex(SecondQuantizationError): + pass + + +class ContractionAppliesOnlyToFermions(SecondQuantizationError): + pass + + +class ViolationOfPauliPrinciple(SecondQuantizationError): + pass + + +class SubstitutionOfAmbigousOperatorFailed(SecondQuantizationError): + pass + + +class WicksTheoremDoesNotApply(SecondQuantizationError): + pass + + +class Dagger(Expr): + """ + Hermitian conjugate of creation/annihilation operators. + + Examples + ======== + + >>> from sympy import I + >>> from sympy.physics.secondquant import Dagger, B, Bd + >>> Dagger(2*I) + -2*I + >>> Dagger(B(0)) + CreateBoson(0) + >>> Dagger(Bd(0)) + AnnihilateBoson(0) + + """ + + def __new__(cls, arg): + arg = sympify(arg) + r = cls.eval(arg) + if isinstance(r, Basic): + return r + obj = Basic.__new__(cls, arg) + return obj + + @classmethod + def eval(cls, arg): + """ + Evaluates the Dagger instance. + + Examples + ======== + + >>> from sympy import I + >>> from sympy.physics.secondquant import Dagger, B, Bd + >>> Dagger(2*I) + -2*I + >>> Dagger(B(0)) + CreateBoson(0) + >>> Dagger(Bd(0)) + AnnihilateBoson(0) + + The eval() method is called automatically. + + """ + dagger = getattr(arg, '_dagger_', None) + if dagger is not None: + return dagger() + if isinstance(arg, Symbol) and arg.is_commutative: + return conjugate(arg) + if isinstance(arg, Basic): + if arg.is_Add: + return Add(*tuple(map(Dagger, arg.args))) + if arg.is_Mul: + return Mul(*tuple(map(Dagger, reversed(arg.args)))) + if arg.is_Number: + return arg + if arg.is_Pow: + return Pow(Dagger(arg.args[0]), arg.args[1]) + if arg == I: + return -arg + if isinstance(arg, Function): + if all(a.is_commutative for a in arg.args): + return arg.func(*[Dagger(a) for a in arg.args]) + else: + return None + + def _dagger_(self): + return self.args[0] + + +class TensorSymbol(Expr): + + is_commutative = True + + +class AntiSymmetricTensor(TensorSymbol): + """Stores upper and lower indices in separate Tuple's. + + Each group of indices is assumed to be antisymmetric. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.secondquant import AntiSymmetricTensor + >>> i, j = symbols('i j', below_fermi=True) + >>> a, b = symbols('a b', above_fermi=True) + >>> AntiSymmetricTensor('v', (a, i), (b, j)) + AntiSymmetricTensor(v, (a, i), (b, j)) + >>> AntiSymmetricTensor('v', (i, a), (b, j)) + -AntiSymmetricTensor(v, (a, i), (b, j)) + + As you can see, the indices are automatically sorted to a canonical form. + + """ + + def __new__(cls, symbol, upper, lower): + + try: + upper, signu = _sort_anticommuting_fermions( + upper, key=_sqkey_index) + lower, signl = _sort_anticommuting_fermions( + lower, key=_sqkey_index) + + except ViolationOfPauliPrinciple: + return S.Zero + + symbol = sympify(symbol) + upper = Tuple(*upper) + lower = Tuple(*lower) + + if (signu + signl) % 2: + return -TensorSymbol.__new__(cls, symbol, upper, lower) + else: + + return TensorSymbol.__new__(cls, symbol, upper, lower) + + def _latex(self, printer): + return "{%s^{%s}_{%s}}" % ( + self.symbol, + "".join([ printer._print(i) for i in self.args[1]]), + "".join([ printer._print(i) for i in self.args[2]]) + ) + + @property + def symbol(self): + """ + Returns the symbol of the tensor. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.secondquant import AntiSymmetricTensor + >>> i, j = symbols('i,j', below_fermi=True) + >>> a, b = symbols('a,b', above_fermi=True) + >>> AntiSymmetricTensor('v', (a, i), (b, j)) + AntiSymmetricTensor(v, (a, i), (b, j)) + >>> AntiSymmetricTensor('v', (a, i), (b, j)).symbol + v + + """ + return self.args[0] + + @property + def upper(self): + """ + Returns the upper indices. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.secondquant import AntiSymmetricTensor + >>> i, j = symbols('i,j', below_fermi=True) + >>> a, b = symbols('a,b', above_fermi=True) + >>> AntiSymmetricTensor('v', (a, i), (b, j)) + AntiSymmetricTensor(v, (a, i), (b, j)) + >>> AntiSymmetricTensor('v', (a, i), (b, j)).upper + (a, i) + + + """ + return self.args[1] + + @property + def lower(self): + """ + Returns the lower indices. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.secondquant import AntiSymmetricTensor + >>> i, j = symbols('i,j', below_fermi=True) + >>> a, b = symbols('a,b', above_fermi=True) + >>> AntiSymmetricTensor('v', (a, i), (b, j)) + AntiSymmetricTensor(v, (a, i), (b, j)) + >>> AntiSymmetricTensor('v', (a, i), (b, j)).lower + (b, j) + + """ + return self.args[2] + + def __str__(self): + return "%s(%s,%s)" % self.args + + +class SqOperator(Expr): + """ + Base class for Second Quantization operators. + """ + + op_symbol = 'sq' + + is_commutative = False + + def __new__(cls, k): + obj = Basic.__new__(cls, sympify(k)) + return obj + + @property + def state(self): + """ + Returns the state index related to this operator. + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import F, Fd, B, Bd + >>> p = Symbol('p') + >>> F(p).state + p + >>> Fd(p).state + p + >>> B(p).state + p + >>> Bd(p).state + p + + """ + return self.args[0] + + @property + def is_symbolic(self): + """ + Returns True if the state is a symbol (as opposed to a number). + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import F + >>> p = Symbol('p') + >>> F(p).is_symbolic + True + >>> F(1).is_symbolic + False + + """ + if self.state.is_Integer: + return False + else: + return True + + def __repr__(self): + return NotImplemented + + def __str__(self): + return "%s(%r)" % (self.op_symbol, self.state) + + def apply_operator(self, state): + """ + Applies an operator to itself. + """ + raise NotImplementedError('implement apply_operator in a subclass') + + +class BosonicOperator(SqOperator): + pass + + +class Annihilator(SqOperator): + pass + + +class Creator(SqOperator): + pass + + +class AnnihilateBoson(BosonicOperator, Annihilator): + """ + Bosonic annihilation operator. + + Examples + ======== + + >>> from sympy.physics.secondquant import B + >>> from sympy.abc import x + >>> B(x) + AnnihilateBoson(x) + """ + + op_symbol = 'b' + + def _dagger_(self): + return CreateBoson(self.state) + + def apply_operator(self, state): + """ + Apply state to self if self is not symbolic and state is a FockStateKet, else + multiply self by state. + + Examples + ======== + + >>> from sympy.physics.secondquant import B, BKet + >>> from sympy.abc import x, y, n + >>> B(x).apply_operator(y) + y*AnnihilateBoson(x) + >>> B(0).apply_operator(BKet((n,))) + sqrt(n)*FockStateBosonKet((n - 1,)) + + """ + if not self.is_symbolic and isinstance(state, FockStateKet): + element = self.state + amp = sqrt(state[element]) + return amp*state.down(element) + else: + return Mul(self, state) + + def __repr__(self): + return "AnnihilateBoson(%s)" % self.state + + def _latex(self, printer): + if self.state is S.Zero: + return "b_{0}" + else: + return "b_{%s}" % printer._print(self.state) + +class CreateBoson(BosonicOperator, Creator): + """ + Bosonic creation operator. + """ + + op_symbol = 'b+' + + def _dagger_(self): + return AnnihilateBoson(self.state) + + def apply_operator(self, state): + """ + Apply state to self if self is not symbolic and state is a FockStateKet, else + multiply self by state. + + Examples + ======== + + >>> from sympy.physics.secondquant import B, Dagger, BKet + >>> from sympy.abc import x, y, n + >>> Dagger(B(x)).apply_operator(y) + y*CreateBoson(x) + >>> B(0).apply_operator(BKet((n,))) + sqrt(n)*FockStateBosonKet((n - 1,)) + """ + if not self.is_symbolic and isinstance(state, FockStateKet): + element = self.state + amp = sqrt(state[element] + 1) + return amp*state.up(element) + else: + return Mul(self, state) + + def __repr__(self): + return "CreateBoson(%s)" % self.state + + def _latex(self, printer): + if self.state is S.Zero: + return "{b^\\dagger_{0}}" + else: + return "{b^\\dagger_{%s}}" % printer._print(self.state) + +B = AnnihilateBoson +Bd = CreateBoson + + +class FermionicOperator(SqOperator): + + @property + def is_restricted(self): + """ + Is this FermionicOperator restricted with respect to fermi level? + + Returns + ======= + + 1 : restricted to orbits above fermi + 0 : no restriction + -1 : restricted to orbits below fermi + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import F, Fd + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> F(a).is_restricted + 1 + >>> Fd(a).is_restricted + 1 + >>> F(i).is_restricted + -1 + >>> Fd(i).is_restricted + -1 + >>> F(p).is_restricted + 0 + >>> Fd(p).is_restricted + 0 + + """ + ass = self.args[0].assumptions0 + if ass.get("below_fermi"): + return -1 + if ass.get("above_fermi"): + return 1 + return 0 + + @property + def is_above_fermi(self): + """ + Does the index of this FermionicOperator allow values above fermi? + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import F + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> F(a).is_above_fermi + True + >>> F(i).is_above_fermi + False + >>> F(p).is_above_fermi + True + + Note + ==== + + The same applies to creation operators Fd + + """ + return not self.args[0].assumptions0.get("below_fermi") + + @property + def is_below_fermi(self): + """ + Does the index of this FermionicOperator allow values below fermi? + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import F + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> F(a).is_below_fermi + False + >>> F(i).is_below_fermi + True + >>> F(p).is_below_fermi + True + + The same applies to creation operators Fd + + """ + return not self.args[0].assumptions0.get("above_fermi") + + @property + def is_only_below_fermi(self): + """ + Is the index of this FermionicOperator restricted to values below fermi? + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import F + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> F(a).is_only_below_fermi + False + >>> F(i).is_only_below_fermi + True + >>> F(p).is_only_below_fermi + False + + The same applies to creation operators Fd + """ + return self.is_below_fermi and not self.is_above_fermi + + @property + def is_only_above_fermi(self): + """ + Is the index of this FermionicOperator restricted to values above fermi? + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import F + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> F(a).is_only_above_fermi + True + >>> F(i).is_only_above_fermi + False + >>> F(p).is_only_above_fermi + False + + The same applies to creation operators Fd + """ + return self.is_above_fermi and not self.is_below_fermi + + def _sortkey(self): + h = hash(self) + label = str(self.args[0]) + + if self.is_only_q_creator: + return 1, label, h + if self.is_only_q_annihilator: + return 4, label, h + if isinstance(self, Annihilator): + return 3, label, h + if isinstance(self, Creator): + return 2, label, h + + +class AnnihilateFermion(FermionicOperator, Annihilator): + """ + Fermionic annihilation operator. + """ + + op_symbol = 'f' + + def _dagger_(self): + return CreateFermion(self.state) + + def apply_operator(self, state): + """ + Apply state to self if self is not symbolic and state is a FockStateKet, else + multiply self by state. + + Examples + ======== + + >>> from sympy.physics.secondquant import B, Dagger, BKet + >>> from sympy.abc import x, y, n + >>> Dagger(B(x)).apply_operator(y) + y*CreateBoson(x) + >>> B(0).apply_operator(BKet((n,))) + sqrt(n)*FockStateBosonKet((n - 1,)) + """ + if isinstance(state, FockStateFermionKet): + element = self.state + return state.down(element) + + elif isinstance(state, Mul): + c_part, nc_part = state.args_cnc() + if isinstance(nc_part[0], FockStateFermionKet): + element = self.state + return Mul(*(c_part + [nc_part[0].down(element)] + nc_part[1:])) + else: + return Mul(self, state) + + else: + return Mul(self, state) + + @property + def is_q_creator(self): + """ + Can we create a quasi-particle? (create hole or create particle) + If so, would that be above or below the fermi surface? + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import F + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> F(a).is_q_creator + 0 + >>> F(i).is_q_creator + -1 + >>> F(p).is_q_creator + -1 + + """ + if self.is_below_fermi: + return -1 + return 0 + + @property + def is_q_annihilator(self): + """ + Can we destroy a quasi-particle? (annihilate hole or annihilate particle) + If so, would that be above or below the fermi surface? + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import F + >>> a = Symbol('a', above_fermi=1) + >>> i = Symbol('i', below_fermi=1) + >>> p = Symbol('p') + + >>> F(a).is_q_annihilator + 1 + >>> F(i).is_q_annihilator + 0 + >>> F(p).is_q_annihilator + 1 + + """ + if self.is_above_fermi: + return 1 + return 0 + + @property + def is_only_q_creator(self): + """ + Always create a quasi-particle? (create hole or create particle) + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import F + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> F(a).is_only_q_creator + False + >>> F(i).is_only_q_creator + True + >>> F(p).is_only_q_creator + False + + """ + return self.is_only_below_fermi + + @property + def is_only_q_annihilator(self): + """ + Always destroy a quasi-particle? (annihilate hole or annihilate particle) + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import F + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> F(a).is_only_q_annihilator + True + >>> F(i).is_only_q_annihilator + False + >>> F(p).is_only_q_annihilator + False + + """ + return self.is_only_above_fermi + + def __repr__(self): + return "AnnihilateFermion(%s)" % self.state + + def _latex(self, printer): + if self.state is S.Zero: + return "a_{0}" + else: + return "a_{%s}" % printer._print(self.state) + + +class CreateFermion(FermionicOperator, Creator): + """ + Fermionic creation operator. + """ + + op_symbol = 'f+' + + def _dagger_(self): + return AnnihilateFermion(self.state) + + def apply_operator(self, state): + """ + Apply state to self if self is not symbolic and state is a FockStateKet, else + multiply self by state. + + Examples + ======== + + >>> from sympy.physics.secondquant import B, Dagger, BKet + >>> from sympy.abc import x, y, n + >>> Dagger(B(x)).apply_operator(y) + y*CreateBoson(x) + >>> B(0).apply_operator(BKet((n,))) + sqrt(n)*FockStateBosonKet((n - 1,)) + """ + if isinstance(state, FockStateFermionKet): + element = self.state + return state.up(element) + + elif isinstance(state, Mul): + c_part, nc_part = state.args_cnc() + if isinstance(nc_part[0], FockStateFermionKet): + element = self.state + return Mul(*(c_part + [nc_part[0].up(element)] + nc_part[1:])) + + return Mul(self, state) + + @property + def is_q_creator(self): + """ + Can we create a quasi-particle? (create hole or create particle) + If so, would that be above or below the fermi surface? + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import Fd + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> Fd(a).is_q_creator + 1 + >>> Fd(i).is_q_creator + 0 + >>> Fd(p).is_q_creator + 1 + + """ + if self.is_above_fermi: + return 1 + return 0 + + @property + def is_q_annihilator(self): + """ + Can we destroy a quasi-particle? (annihilate hole or annihilate particle) + If so, would that be above or below the fermi surface? + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import Fd + >>> a = Symbol('a', above_fermi=1) + >>> i = Symbol('i', below_fermi=1) + >>> p = Symbol('p') + + >>> Fd(a).is_q_annihilator + 0 + >>> Fd(i).is_q_annihilator + -1 + >>> Fd(p).is_q_annihilator + -1 + + """ + if self.is_below_fermi: + return -1 + return 0 + + @property + def is_only_q_creator(self): + """ + Always create a quasi-particle? (create hole or create particle) + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import Fd + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> Fd(a).is_only_q_creator + True + >>> Fd(i).is_only_q_creator + False + >>> Fd(p).is_only_q_creator + False + + """ + return self.is_only_above_fermi + + @property + def is_only_q_annihilator(self): + """ + Always destroy a quasi-particle? (annihilate hole or annihilate particle) + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import Fd + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> Fd(a).is_only_q_annihilator + False + >>> Fd(i).is_only_q_annihilator + True + >>> Fd(p).is_only_q_annihilator + False + + """ + return self.is_only_below_fermi + + def __repr__(self): + return "CreateFermion(%s)" % self.state + + def _latex(self, printer): + if self.state is S.Zero: + return "{a^\\dagger_{0}}" + else: + return "{a^\\dagger_{%s}}" % printer._print(self.state) + +Fd = CreateFermion +F = AnnihilateFermion + + +class FockState(Expr): + """ + Many particle Fock state with a sequence of occupation numbers. + + Anywhere you can have a FockState, you can also have S.Zero. + All code must check for this! + + Base class to represent FockStates. + """ + is_commutative = False + + def __new__(cls, occupations): + """ + occupations is a list with two possible meanings: + + - For bosons it is a list of occupation numbers. + Element i is the number of particles in state i. + + - For fermions it is a list of occupied orbits. + Element 0 is the state that was occupied first, element i + is the i'th occupied state. + """ + occupations = list(map(sympify, occupations)) + obj = Basic.__new__(cls, Tuple(*occupations)) + return obj + + def __getitem__(self, i): + i = int(i) + return self.args[0][i] + + def __repr__(self): + return ("FockState(%r)") % (self.args) + + def __str__(self): + return "%s%r%s" % (getattr(self, 'lbracket', ""), self._labels(), getattr(self, 'rbracket', "")) + + def _labels(self): + return self.args[0] + + def __len__(self): + return len(self.args[0]) + + def _latex(self, printer): + return "%s%s%s" % (getattr(self, 'lbracket_latex', ""), printer._print(self._labels()), getattr(self, 'rbracket_latex', "")) + + +class BosonState(FockState): + """ + Base class for FockStateBoson(Ket/Bra). + """ + + def up(self, i): + """ + Performs the action of a creation operator. + + Examples + ======== + + >>> from sympy.physics.secondquant import BBra + >>> b = BBra([1, 2]) + >>> b + FockStateBosonBra((1, 2)) + >>> b.up(1) + FockStateBosonBra((1, 3)) + """ + i = int(i) + new_occs = list(self.args[0]) + new_occs[i] = new_occs[i] + S.One + return self.__class__(new_occs) + + def down(self, i): + """ + Performs the action of an annihilation operator. + + Examples + ======== + + >>> from sympy.physics.secondquant import BBra + >>> b = BBra([1, 2]) + >>> b + FockStateBosonBra((1, 2)) + >>> b.down(1) + FockStateBosonBra((1, 1)) + """ + i = int(i) + new_occs = list(self.args[0]) + if new_occs[i] == S.Zero: + return S.Zero + else: + new_occs[i] = new_occs[i] - S.One + return self.__class__(new_occs) + + +class FermionState(FockState): + """ + Base class for FockStateFermion(Ket/Bra). + """ + + fermi_level = 0 + + def __new__(cls, occupations, fermi_level=0): + occupations = list(map(sympify, occupations)) + if len(occupations) > 1: + try: + (occupations, sign) = _sort_anticommuting_fermions( + occupations, key=_sqkey_index) + except ViolationOfPauliPrinciple: + return S.Zero + else: + sign = 0 + + cls.fermi_level = fermi_level + + if cls._count_holes(occupations) > fermi_level: + return S.Zero + + if sign % 2: + return S.NegativeOne*FockState.__new__(cls, occupations) + else: + return FockState.__new__(cls, occupations) + + def up(self, i): + """ + Performs the action of a creation operator. + + Explanation + =========== + + If below fermi we try to remove a hole, + if above fermi we try to create a particle. + + If general index p we return ``Kronecker(p,i)*self`` + where ``i`` is a new symbol with restriction above or below. + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import FKet + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> FKet([]).up(a) + FockStateFermionKet((a,)) + + A creator acting on vacuum below fermi vanishes + + >>> FKet([]).up(i) + 0 + + + """ + present = i in self.args[0] + + if self._only_above_fermi(i): + if present: + return S.Zero + else: + return self._add_orbit(i) + elif self._only_below_fermi(i): + if present: + return self._remove_orbit(i) + else: + return S.Zero + else: + if present: + hole = Dummy("i", below_fermi=True) + return KroneckerDelta(i, hole)*self._remove_orbit(i) + else: + particle = Dummy("a", above_fermi=True) + return KroneckerDelta(i, particle)*self._add_orbit(i) + + def down(self, i): + """ + Performs the action of an annihilation operator. + + Explanation + =========== + + If below fermi we try to create a hole, + If above fermi we try to remove a particle. + + If general index p we return ``Kronecker(p,i)*self`` + where ``i`` is a new symbol with restriction above or below. + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import FKet + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + An annihilator acting on vacuum above fermi vanishes + + >>> FKet([]).down(a) + 0 + + Also below fermi, it vanishes, unless we specify a fermi level > 0 + + >>> FKet([]).down(i) + 0 + >>> FKet([],4).down(i) + FockStateFermionKet((i,)) + + """ + present = i in self.args[0] + + if self._only_above_fermi(i): + if present: + return self._remove_orbit(i) + else: + return S.Zero + + elif self._only_below_fermi(i): + if present: + return S.Zero + else: + return self._add_orbit(i) + else: + if present: + hole = Dummy("i", below_fermi=True) + return KroneckerDelta(i, hole)*self._add_orbit(i) + else: + particle = Dummy("a", above_fermi=True) + return KroneckerDelta(i, particle)*self._remove_orbit(i) + + @classmethod + def _only_below_fermi(cls, i): + """ + Tests if given orbit is only below fermi surface. + + If nothing can be concluded we return a conservative False. + """ + if i.is_number: + return i <= cls.fermi_level + if i.assumptions0.get('below_fermi'): + return True + return False + + @classmethod + def _only_above_fermi(cls, i): + """ + Tests if given orbit is only above fermi surface. + + If fermi level has not been set we return True. + If nothing can be concluded we return a conservative False. + """ + if i.is_number: + return i > cls.fermi_level + if i.assumptions0.get('above_fermi'): + return True + return not cls.fermi_level + + def _remove_orbit(self, i): + """ + Removes particle/fills hole in orbit i. No input tests performed here. + """ + new_occs = list(self.args[0]) + pos = new_occs.index(i) + del new_occs[pos] + if (pos) % 2: + return S.NegativeOne*self.__class__(new_occs, self.fermi_level) + else: + return self.__class__(new_occs, self.fermi_level) + + def _add_orbit(self, i): + """ + Adds particle/creates hole in orbit i. No input tests performed here. + """ + return self.__class__((i,) + self.args[0], self.fermi_level) + + @classmethod + def _count_holes(cls, occupations): + """ + Returns the number of identified hole states in occupations list. + """ + return len([i for i in occupations if cls._only_below_fermi(i)]) + + def _negate_holes(self, occupations): + """ + Returns the occupations list where states below the fermi level have negative labels. + + For symbolic state labels, no sign is included. + """ + return tuple([-i if self._only_below_fermi(i) and i.is_number else i for i in occupations]) + + def __repr__(self): + if self.fermi_level: + return "FockStateKet(%r, fermi_level=%s)" % (self.args[0], self.fermi_level) + else: + return "FockStateKet(%r)" % (self.args[0],) + + def _labels(self): + return self._negate_holes(self.args[0]) + + +class FockStateKet(FockState): + """ + Representation of a ket. + """ + lbracket = '|' + rbracket = '>' + lbracket_latex = r'\left|' + rbracket_latex = r'\right\rangle' + + +class FockStateBra(FockState): + """ + Representation of a bra. + """ + lbracket = '<' + rbracket = '|' + lbracket_latex = r'\left\langle' + rbracket_latex = r'\right|' + + def __mul__(self, other): + if isinstance(other, FockStateKet): + return InnerProduct(self, other) + else: + return Expr.__mul__(self, other) + + +class FockStateBosonKet(BosonState, FockStateKet): + """ + Many particle Fock state with a sequence of occupation numbers. + + Occupation numbers can be any integer >= 0. + + Examples + ======== + + >>> from sympy.physics.secondquant import BKet + >>> BKet([1, 2]) + FockStateBosonKet((1, 2)) + """ + def _dagger_(self): + return FockStateBosonBra(*self.args) + + +class FockStateBosonBra(BosonState, FockStateBra): + """ + Describes a collection of BosonBra particles. + + Examples + ======== + + >>> from sympy.physics.secondquant import BBra + >>> BBra([1, 2]) + FockStateBosonBra((1, 2)) + """ + def _dagger_(self): + return FockStateBosonKet(*self.args) + + +class FockStateFermionKet(FermionState, FockStateKet): + """ + Many-particle Fock state with a sequence of occupied orbits. + + Explanation + =========== + + Each state can only have one particle, so we choose to store a list of + occupied orbits rather than a tuple with occupation numbers (zeros and ones). + + states below fermi level are holes, and are represented by negative labels + in the occupation list. + + For symbolic state labels, the fermi_level caps the number of allowed hole- + states. + + Examples + ======== + + >>> from sympy.physics.secondquant import FKet + >>> FKet([1, 2]) + FockStateFermionKet((1, 2)) + """ + def _dagger_(self): + return FockStateFermionBra(*self.args) + + +class FockStateFermionBra(FermionState, FockStateBra): + """ + See Also + ======== + + FockStateFermionKet + + Examples + ======== + + >>> from sympy.physics.secondquant import FBra + >>> FBra([1, 2]) + FockStateFermionBra((1, 2)) + """ + def _dagger_(self): + return FockStateFermionKet(*self.args) + +BBra = FockStateBosonBra +BKet = FockStateBosonKet +FBra = FockStateFermionBra +FKet = FockStateFermionKet + + +def _apply_Mul(m): + """ + Take a Mul instance with operators and apply them to states. + + Explanation + =========== + + This method applies all operators with integer state labels + to the actual states. For symbolic state labels, nothing is done. + When inner products of FockStates are encountered (like ), + they are converted to instances of InnerProduct. + + This does not currently work on double inner products like, + . + + If the argument is not a Mul, it is simply returned as is. + """ + if not isinstance(m, Mul): + return m + c_part, nc_part = m.args_cnc() + n_nc = len(nc_part) + if n_nc in (0, 1): + return m + else: + last = nc_part[-1] + next_to_last = nc_part[-2] + if isinstance(last, FockStateKet): + if isinstance(next_to_last, SqOperator): + if next_to_last.is_symbolic: + return m + else: + result = next_to_last.apply_operator(last) + if result == 0: + return S.Zero + else: + return _apply_Mul(Mul(*(c_part + nc_part[:-2] + [result]))) + elif isinstance(next_to_last, Pow): + if isinstance(next_to_last.base, SqOperator) and \ + next_to_last.exp.is_Integer: + if next_to_last.base.is_symbolic: + return m + else: + result = last + for i in range(next_to_last.exp): + result = next_to_last.base.apply_operator(result) + if result == 0: + break + if result == 0: + return S.Zero + else: + return _apply_Mul(Mul(*(c_part + nc_part[:-2] + [result]))) + else: + return m + elif isinstance(next_to_last, FockStateBra): + result = InnerProduct(next_to_last, last) + if result == 0: + return S.Zero + else: + return _apply_Mul(Mul(*(c_part + nc_part[:-2] + [result]))) + else: + return m + else: + return m + + +def apply_operators(e): + """ + Take a SymPy expression with operators and states and apply the operators. + + Examples + ======== + + >>> from sympy.physics.secondquant import apply_operators + >>> from sympy import sympify + >>> apply_operators(sympify(3)+4) + 7 + """ + e = e.expand() + muls = e.atoms(Mul) + subs_list = [(m, _apply_Mul(m)) for m in iter(muls)] + return e.subs(subs_list) + + +class InnerProduct(Basic): + """ + An unevaluated inner product between a bra and ket. + + Explanation + =========== + + Currently this class just reduces things to a product of + Kronecker Deltas. In the future, we could introduce abstract + states like ``|a>`` and ``|b>``, and leave the inner product unevaluated as + ````. + + """ + is_commutative = True + + def __new__(cls, bra, ket): + if not isinstance(bra, FockStateBra): + raise TypeError("must be a bra") + if not isinstance(ket, FockStateKet): + raise TypeError("must be a ket") + return cls.eval(bra, ket) + + @classmethod + def eval(cls, bra, ket): + result = S.One + for i, j in zip(bra.args[0], ket.args[0]): + result *= KroneckerDelta(i, j) + if result == 0: + break + return result + + @property + def bra(self): + """Returns the bra part of the state""" + return self.args[0] + + @property + def ket(self): + """Returns the ket part of the state""" + return self.args[1] + + def __repr__(self): + sbra = repr(self.bra) + sket = repr(self.ket) + return "%s|%s" % (sbra[:-1], sket[1:]) + + def __str__(self): + return self.__repr__() + + +def matrix_rep(op, basis): + """ + Find the representation of an operator in a basis. + + Examples + ======== + + >>> from sympy.physics.secondquant import VarBosonicBasis, B, matrix_rep + >>> b = VarBosonicBasis(5) + >>> o = B(0) + >>> matrix_rep(o, b) + Matrix([ + [0, 1, 0, 0, 0], + [0, 0, sqrt(2), 0, 0], + [0, 0, 0, sqrt(3), 0], + [0, 0, 0, 0, 2], + [0, 0, 0, 0, 0]]) + """ + a = zeros(len(basis)) + for i in range(len(basis)): + for j in range(len(basis)): + a[i, j] = apply_operators(Dagger(basis[i])*op*basis[j]) + return a + + +class BosonicBasis: + """ + Base class for a basis set of bosonic Fock states. + """ + pass + + +class VarBosonicBasis: + """ + A single state, variable particle number basis set. + + Examples + ======== + + >>> from sympy.physics.secondquant import VarBosonicBasis + >>> b = VarBosonicBasis(5) + >>> b + [FockState((0,)), FockState((1,)), FockState((2,)), + FockState((3,)), FockState((4,))] + """ + + def __init__(self, n_max): + self.n_max = n_max + self._build_states() + + def _build_states(self): + self.basis = [] + for i in range(self.n_max): + self.basis.append(FockStateBosonKet([i])) + self.n_basis = len(self.basis) + + def index(self, state): + """ + Returns the index of state in basis. + + Examples + ======== + + >>> from sympy.physics.secondquant import VarBosonicBasis + >>> b = VarBosonicBasis(3) + >>> state = b.state(1) + >>> b + [FockState((0,)), FockState((1,)), FockState((2,))] + >>> state + FockStateBosonKet((1,)) + >>> b.index(state) + 1 + """ + return self.basis.index(state) + + def state(self, i): + """ + The state of a single basis. + + Examples + ======== + + >>> from sympy.physics.secondquant import VarBosonicBasis + >>> b = VarBosonicBasis(5) + >>> b.state(3) + FockStateBosonKet((3,)) + """ + return self.basis[i] + + def __getitem__(self, i): + return self.state(i) + + def __len__(self): + return len(self.basis) + + def __repr__(self): + return repr(self.basis) + + +class FixedBosonicBasis(BosonicBasis): + """ + Fixed particle number basis set. + + Examples + ======== + + >>> from sympy.physics.secondquant import FixedBosonicBasis + >>> b = FixedBosonicBasis(2, 2) + >>> state = b.state(1) + >>> b + [FockState((2, 0)), FockState((1, 1)), FockState((0, 2))] + >>> state + FockStateBosonKet((1, 1)) + >>> b.index(state) + 1 + """ + def __init__(self, n_particles, n_levels): + self.n_particles = n_particles + self.n_levels = n_levels + self._build_particle_locations() + self._build_states() + + def _build_particle_locations(self): + tup = ["i%i" % i for i in range(self.n_particles)] + first_loop = "for i0 in range(%i)" % self.n_levels + other_loops = '' + for cur, prev in zip(tup[1:], tup): + temp = "for %s in range(%s + 1) " % (cur, prev) + other_loops = other_loops + temp + tup_string = "(%s)" % ", ".join(tup) + list_comp = "[%s %s %s]" % (tup_string, first_loop, other_loops) + result = eval(list_comp) + if self.n_particles == 1: + result = [(item,) for item in result] + self.particle_locations = result + + def _build_states(self): + self.basis = [] + for tuple_of_indices in self.particle_locations: + occ_numbers = self.n_levels*[0] + for level in tuple_of_indices: + occ_numbers[level] += 1 + self.basis.append(FockStateBosonKet(occ_numbers)) + self.n_basis = len(self.basis) + + def index(self, state): + """Returns the index of state in basis. + + Examples + ======== + + >>> from sympy.physics.secondquant import FixedBosonicBasis + >>> b = FixedBosonicBasis(2, 3) + >>> b.index(b.state(3)) + 3 + """ + return self.basis.index(state) + + def state(self, i): + """Returns the state that lies at index i of the basis + + Examples + ======== + + >>> from sympy.physics.secondquant import FixedBosonicBasis + >>> b = FixedBosonicBasis(2, 3) + >>> b.state(3) + FockStateBosonKet((1, 0, 1)) + """ + return self.basis[i] + + def __getitem__(self, i): + return self.state(i) + + def __len__(self): + return len(self.basis) + + def __repr__(self): + return repr(self.basis) + + +class Commutator(Function): + """ + The Commutator: [A, B] = A*B - B*A + + The arguments are ordered according to .__cmp__() + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.secondquant import Commutator + >>> A, B = symbols('A,B', commutative=False) + >>> Commutator(B, A) + -Commutator(A, B) + + Evaluate the commutator with .doit() + + >>> comm = Commutator(A,B); comm + Commutator(A, B) + >>> comm.doit() + A*B - B*A + + + For two second quantization operators the commutator is evaluated + immediately: + + >>> from sympy.physics.secondquant import Fd, F + >>> a = symbols('a', above_fermi=True) + >>> i = symbols('i', below_fermi=True) + >>> p,q = symbols('p,q') + + >>> Commutator(Fd(a),Fd(i)) + 2*NO(CreateFermion(a)*CreateFermion(i)) + + But for more complicated expressions, the evaluation is triggered by + a call to .doit() + + >>> comm = Commutator(Fd(p)*Fd(q),F(i)); comm + Commutator(CreateFermion(p)*CreateFermion(q), AnnihilateFermion(i)) + >>> comm.doit(wicks=True) + -KroneckerDelta(i, p)*CreateFermion(q) + + KroneckerDelta(i, q)*CreateFermion(p) + + """ + + is_commutative = False + + @classmethod + def eval(cls, a, b): + """ + The Commutator [A,B] is on canonical form if A < B. + + Examples + ======== + + >>> from sympy.physics.secondquant import Commutator, F, Fd + >>> from sympy.abc import x + >>> c1 = Commutator(F(x), Fd(x)) + >>> c2 = Commutator(Fd(x), F(x)) + >>> Commutator.eval(c1, c2) + 0 + """ + if not (a and b): + return S.Zero + if a == b: + return S.Zero + if a.is_commutative or b.is_commutative: + return S.Zero + + # + # [A+B,C] -> [A,C] + [B,C] + # + a = a.expand() + if isinstance(a, Add): + return Add(*[cls(term, b) for term in a.args]) + b = b.expand() + if isinstance(b, Add): + return Add(*[cls(a, term) for term in b.args]) + + # + # [xA,yB] -> xy*[A,B] + # + ca, nca = a.args_cnc() + cb, ncb = b.args_cnc() + c_part = list(ca) + list(cb) + if c_part: + return Mul(Mul(*c_part), cls(Mul._from_args(nca), Mul._from_args(ncb))) + + # + # single second quantization operators + # + if isinstance(a, BosonicOperator) and isinstance(b, BosonicOperator): + if isinstance(b, CreateBoson) and isinstance(a, AnnihilateBoson): + return KroneckerDelta(a.state, b.state) + if isinstance(a, CreateBoson) and isinstance(b, AnnihilateBoson): + return S.NegativeOne*KroneckerDelta(a.state, b.state) + else: + return S.Zero + if isinstance(a, FermionicOperator) and isinstance(b, FermionicOperator): + return wicks(a*b) - wicks(b*a) + + # + # Canonical ordering of arguments + # + if a.sort_key() > b.sort_key(): + return S.NegativeOne*cls(b, a) + + def doit(self, **hints): + """ + Enables the computation of complex expressions. + + Examples + ======== + + >>> from sympy.physics.secondquant import Commutator, F, Fd + >>> from sympy import symbols + >>> i, j = symbols('i,j', below_fermi=True) + >>> a, b = symbols('a,b', above_fermi=True) + >>> c = Commutator(Fd(a)*F(i),Fd(b)*F(j)) + >>> c.doit(wicks=True) + 0 + """ + a = self.args[0] + b = self.args[1] + + if hints.get("wicks"): + a = a.doit(**hints) + b = b.doit(**hints) + try: + return wicks(a*b) - wicks(b*a) + except ContractionAppliesOnlyToFermions: + pass + except WicksTheoremDoesNotApply: + pass + + return (a*b - b*a).doit(**hints) + + def __repr__(self): + return "Commutator(%s,%s)" % (self.args[0], self.args[1]) + + def __str__(self): + return "[%s,%s]" % (self.args[0], self.args[1]) + + def _latex(self, printer): + return "\\left[%s,%s\\right]" % tuple([ + printer._print(arg) for arg in self.args]) + + +class NO(Expr): + """ + This Object is used to represent normal ordering brackets. + + i.e. {abcd} sometimes written :abcd: + + Explanation + =========== + + Applying the function NO(arg) to an argument means that all operators in + the argument will be assumed to anticommute, and have vanishing + contractions. This allows an immediate reordering to canonical form + upon object creation. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.secondquant import NO, F, Fd + >>> p,q = symbols('p,q') + >>> NO(Fd(p)*F(q)) + NO(CreateFermion(p)*AnnihilateFermion(q)) + >>> NO(F(q)*Fd(p)) + -NO(CreateFermion(p)*AnnihilateFermion(q)) + + + Note + ==== + + If you want to generate a normal ordered equivalent of an expression, you + should use the function wicks(). This class only indicates that all + operators inside the brackets anticommute, and have vanishing contractions. + Nothing more, nothing less. + + """ + is_commutative = False + + def __new__(cls, arg): + """ + Use anticommutation to get canonical form of operators. + + Explanation + =========== + + Employ associativity of normal ordered product: {ab{cd}} = {abcd} + but note that {ab}{cd} /= {abcd}. + + We also employ distributivity: {ab + cd} = {ab} + {cd}. + + Canonical form also implies expand() {ab(c+d)} = {abc} + {abd}. + + """ + + # {ab + cd} = {ab} + {cd} + arg = sympify(arg) + arg = arg.expand() + if arg.is_Add: + return Add(*[ cls(term) for term in arg.args]) + + if arg.is_Mul: + + # take coefficient outside of normal ordering brackets + c_part, seq = arg.args_cnc() + if c_part: + coeff = Mul(*c_part) + if not seq: + return coeff + else: + coeff = S.One + + # {ab{cd}} = {abcd} + newseq = [] + foundit = False + for fac in seq: + if isinstance(fac, NO): + newseq.extend(fac.args) + foundit = True + else: + newseq.append(fac) + if foundit: + return coeff*cls(Mul(*newseq)) + + # We assume that the user don't mix B and F operators + if isinstance(seq[0], BosonicOperator): + raise NotImplementedError + + try: + newseq, sign = _sort_anticommuting_fermions(seq) + except ViolationOfPauliPrinciple: + return S.Zero + + if sign % 2: + return (S.NegativeOne*coeff)*cls(Mul(*newseq)) + elif sign: + return coeff*cls(Mul(*newseq)) + else: + pass # since sign==0, no permutations was necessary + + # if we couldn't do anything with Mul object, we just + # mark it as normal ordered + if coeff != S.One: + return coeff*cls(Mul(*newseq)) + return Expr.__new__(cls, Mul(*newseq)) + + if isinstance(arg, NO): + return arg + + # if object was not Mul or Add, normal ordering does not apply + return arg + + @property + def has_q_creators(self): + """ + Return 0 if the leftmost argument of the first argument is a not a + q_creator, else 1 if it is above fermi or -1 if it is below fermi. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.secondquant import NO, F, Fd + + >>> a = symbols('a', above_fermi=True) + >>> i = symbols('i', below_fermi=True) + >>> NO(Fd(a)*Fd(i)).has_q_creators + 1 + >>> NO(F(i)*F(a)).has_q_creators + -1 + >>> NO(Fd(i)*F(a)).has_q_creators #doctest: +SKIP + 0 + + """ + return self.args[0].args[0].is_q_creator + + @property + def has_q_annihilators(self): + """ + Return 0 if the rightmost argument of the first argument is a not a + q_annihilator, else 1 if it is above fermi or -1 if it is below fermi. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.secondquant import NO, F, Fd + + >>> a = symbols('a', above_fermi=True) + >>> i = symbols('i', below_fermi=True) + >>> NO(Fd(a)*Fd(i)).has_q_annihilators + -1 + >>> NO(F(i)*F(a)).has_q_annihilators + 1 + >>> NO(Fd(a)*F(i)).has_q_annihilators + 0 + + """ + return self.args[0].args[-1].is_q_annihilator + + def doit(self, **hints): + """ + Either removes the brackets or enables complex computations + in its arguments. + + Examples + ======== + + >>> from sympy.physics.secondquant import NO, Fd, F + >>> from textwrap import fill + >>> from sympy import symbols, Dummy + >>> p,q = symbols('p,q', cls=Dummy) + >>> print(fill(str(NO(Fd(p)*F(q)).doit()))) + KroneckerDelta(_a, _p)*KroneckerDelta(_a, + _q)*CreateFermion(_a)*AnnihilateFermion(_a) + KroneckerDelta(_a, + _p)*KroneckerDelta(_i, _q)*CreateFermion(_a)*AnnihilateFermion(_i) - + KroneckerDelta(_a, _q)*KroneckerDelta(_i, + _p)*AnnihilateFermion(_a)*CreateFermion(_i) - KroneckerDelta(_i, + _p)*KroneckerDelta(_i, _q)*AnnihilateFermion(_i)*CreateFermion(_i) + """ + if hints.get("remove_brackets", True): + return self._remove_brackets() + else: + return self.__new__(type(self), self.args[0].doit(**hints)) + + def _remove_brackets(self): + """ + Returns the sorted string without normal order brackets. + + The returned string have the property that no nonzero + contractions exist. + """ + + # check if any creator is also an annihilator + subslist = [] + for i in self.iter_q_creators(): + if self[i].is_q_annihilator: + assume = self[i].state.assumptions0 + + # only operators with a dummy index can be split in two terms + if isinstance(self[i].state, Dummy): + + # create indices with fermi restriction + assume.pop("above_fermi", None) + assume["below_fermi"] = True + below = Dummy('i', **assume) + assume.pop("below_fermi", None) + assume["above_fermi"] = True + above = Dummy('a', **assume) + + cls = type(self[i]) + split = ( + self[i].__new__(cls, below) + * KroneckerDelta(below, self[i].state) + + self[i].__new__(cls, above) + * KroneckerDelta(above, self[i].state) + ) + subslist.append((self[i], split)) + else: + raise SubstitutionOfAmbigousOperatorFailed(self[i]) + if subslist: + result = NO(self.subs(subslist)) + if isinstance(result, Add): + return Add(*[term.doit() for term in result.args]) + else: + return self.args[0] + + def _expand_operators(self): + """ + Returns a sum of NO objects that contain no ambiguous q-operators. + + Explanation + =========== + + If an index q has range both above and below fermi, the operator F(q) + is ambiguous in the sense that it can be both a q-creator and a q-annihilator. + If q is dummy, it is assumed to be a summation variable and this method + rewrites it into a sum of NO terms with unambiguous operators: + + {Fd(p)*F(q)} = {Fd(a)*F(b)} + {Fd(a)*F(i)} + {Fd(j)*F(b)} -{F(i)*Fd(j)} + + where a,b are above and i,j are below fermi level. + """ + return NO(self._remove_brackets) + + def __getitem__(self, i): + if isinstance(i, slice): + indices = i.indices(len(self)) + return [self.args[0].args[i] for i in range(*indices)] + else: + return self.args[0].args[i] + + def __len__(self): + return len(self.args[0].args) + + def iter_q_annihilators(self): + """ + Iterates over the annihilation operators. + + Examples + ======== + + >>> from sympy import symbols + >>> i, j = symbols('i j', below_fermi=True) + >>> a, b = symbols('a b', above_fermi=True) + >>> from sympy.physics.secondquant import NO, F, Fd + >>> no = NO(Fd(a)*F(i)*F(b)*Fd(j)) + + >>> no.iter_q_creators() + + >>> list(no.iter_q_creators()) + [0, 1] + >>> list(no.iter_q_annihilators()) + [3, 2] + + """ + ops = self.args[0].args + iter = range(len(ops) - 1, -1, -1) + for i in iter: + if ops[i].is_q_annihilator: + yield i + else: + break + + def iter_q_creators(self): + """ + Iterates over the creation operators. + + Examples + ======== + + >>> from sympy import symbols + >>> i, j = symbols('i j', below_fermi=True) + >>> a, b = symbols('a b', above_fermi=True) + >>> from sympy.physics.secondquant import NO, F, Fd + >>> no = NO(Fd(a)*F(i)*F(b)*Fd(j)) + + >>> no.iter_q_creators() + + >>> list(no.iter_q_creators()) + [0, 1] + >>> list(no.iter_q_annihilators()) + [3, 2] + + """ + + ops = self.args[0].args + iter = range(0, len(ops)) + for i in iter: + if ops[i].is_q_creator: + yield i + else: + break + + def get_subNO(self, i): + """ + Returns a NO() without FermionicOperator at index i. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.secondquant import F, NO + >>> p, q, r = symbols('p,q,r') + + >>> NO(F(p)*F(q)*F(r)).get_subNO(1) + NO(AnnihilateFermion(p)*AnnihilateFermion(r)) + + """ + arg0 = self.args[0] # it's a Mul by definition of how it's created + mul = arg0._new_rawargs(*(arg0.args[:i] + arg0.args[i + 1:])) + return NO(mul) + + def _latex(self, printer): + return "\\left\\{%s\\right\\}" % printer._print(self.args[0]) + + def __repr__(self): + return "NO(%s)" % self.args[0] + + def __str__(self): + return ":%s:" % self.args[0] + + +def contraction(a, b): + """ + Calculates contraction of Fermionic operators a and b. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.secondquant import F, Fd, contraction + >>> p, q = symbols('p,q') + >>> a, b = symbols('a,b', above_fermi=True) + >>> i, j = symbols('i,j', below_fermi=True) + + A contraction is non-zero only if a quasi-creator is to the right of a + quasi-annihilator: + + >>> contraction(F(a),Fd(b)) + KroneckerDelta(a, b) + >>> contraction(Fd(i),F(j)) + KroneckerDelta(i, j) + + For general indices a non-zero result restricts the indices to below/above + the fermi surface: + + >>> contraction(Fd(p),F(q)) + KroneckerDelta(_i, q)*KroneckerDelta(p, q) + >>> contraction(F(p),Fd(q)) + KroneckerDelta(_a, q)*KroneckerDelta(p, q) + + Two creators or two annihilators always vanishes: + + >>> contraction(F(p),F(q)) + 0 + >>> contraction(Fd(p),Fd(q)) + 0 + + """ + if isinstance(b, FermionicOperator) and isinstance(a, FermionicOperator): + if isinstance(a, AnnihilateFermion) and isinstance(b, CreateFermion): + if b.state.assumptions0.get("below_fermi"): + return S.Zero + if a.state.assumptions0.get("below_fermi"): + return S.Zero + if b.state.assumptions0.get("above_fermi"): + return KroneckerDelta(a.state, b.state) + if a.state.assumptions0.get("above_fermi"): + return KroneckerDelta(a.state, b.state) + + return (KroneckerDelta(a.state, b.state)* + KroneckerDelta(b.state, Dummy('a', above_fermi=True))) + if isinstance(b, AnnihilateFermion) and isinstance(a, CreateFermion): + if b.state.assumptions0.get("above_fermi"): + return S.Zero + if a.state.assumptions0.get("above_fermi"): + return S.Zero + if b.state.assumptions0.get("below_fermi"): + return KroneckerDelta(a.state, b.state) + if a.state.assumptions0.get("below_fermi"): + return KroneckerDelta(a.state, b.state) + + return (KroneckerDelta(a.state, b.state)* + KroneckerDelta(b.state, Dummy('i', below_fermi=True))) + + # vanish if 2xAnnihilator or 2xCreator + return S.Zero + + else: + #not fermion operators + t = ( isinstance(i, FermionicOperator) for i in (a, b) ) + raise ContractionAppliesOnlyToFermions(*t) + + +def _sqkey_operator(sq_operator): + """Generates key for canonical sorting of SQ operators.""" + return sq_operator._sortkey() + +def _sqkey_index(index): + """Key for sorting of indices. + + particle < hole < general + + FIXME: This is a bottle-neck, can we do it faster? + """ + h = hash(index) + label = str(index) + if isinstance(index, Dummy): + if index.assumptions0.get('above_fermi'): + return (20, label, h) + elif index.assumptions0.get('below_fermi'): + return (21, label, h) + else: + return (22, label, h) + + if index.assumptions0.get('above_fermi'): + return (10, label, h) + elif index.assumptions0.get('below_fermi'): + return (11, label, h) + else: + return (12, label, h) + + + +def _sort_anticommuting_fermions(string1, key=_sqkey_operator): + """Sort fermionic operators to canonical order, assuming all pairs anticommute. + + Explanation + =========== + + Uses a bidirectional bubble sort. Items in string1 are not referenced + so in principle they may be any comparable objects. The sorting depends on the + operators '>' and '=='. + + If the Pauli principle is violated, an exception is raised. + + Returns + ======= + + tuple (sorted_str, sign) + + sorted_str: list containing the sorted operators + sign: int telling how many times the sign should be changed + (if sign==0 the string was already sorted) + """ + + verified = False + sign = 0 + rng = list(range(len(string1) - 1)) + rev = list(range(len(string1) - 3, -1, -1)) + + keys = list(map(key, string1)) + key_val = dict(list(zip(keys, string1))) + + while not verified: + verified = True + for i in rng: + left = keys[i] + right = keys[i + 1] + if left == right: + raise ViolationOfPauliPrinciple([left, right]) + if left > right: + verified = False + keys[i:i + 2] = [right, left] + sign = sign + 1 + if verified: + break + for i in rev: + left = keys[i] + right = keys[i + 1] + if left == right: + raise ViolationOfPauliPrinciple([left, right]) + if left > right: + verified = False + keys[i:i + 2] = [right, left] + sign = sign + 1 + string1 = [ key_val[k] for k in keys ] + return (string1, sign) + + +def evaluate_deltas(e): + """ + We evaluate KroneckerDelta symbols in the expression assuming Einstein summation. + + Explanation + =========== + + If one index is repeated it is summed over and in effect substituted with + the other one. If both indices are repeated we substitute according to what + is the preferred index. this is determined by + KroneckerDelta.preferred_index and KroneckerDelta.killable_index. + + In case there are no possible substitutions or if a substitution would + imply a loss of information, nothing is done. + + In case an index appears in more than one KroneckerDelta, the resulting + substitution depends on the order of the factors. Since the ordering is platform + dependent, the literal expression resulting from this function may be hard to + predict. + + Examples + ======== + + We assume the following: + + >>> from sympy import symbols, Function, Dummy, KroneckerDelta + >>> from sympy.physics.secondquant import evaluate_deltas + >>> i,j = symbols('i j', below_fermi=True, cls=Dummy) + >>> a,b = symbols('a b', above_fermi=True, cls=Dummy) + >>> p,q = symbols('p q', cls=Dummy) + >>> f = Function('f') + >>> t = Function('t') + + The order of preference for these indices according to KroneckerDelta is + (a, b, i, j, p, q). + + Trivial cases: + + >>> evaluate_deltas(KroneckerDelta(i,j)*f(i)) # d_ij f(i) -> f(j) + f(_j) + >>> evaluate_deltas(KroneckerDelta(i,j)*f(j)) # d_ij f(j) -> f(i) + f(_i) + >>> evaluate_deltas(KroneckerDelta(i,p)*f(p)) # d_ip f(p) -> f(i) + f(_i) + >>> evaluate_deltas(KroneckerDelta(q,p)*f(p)) # d_qp f(p) -> f(q) + f(_q) + >>> evaluate_deltas(KroneckerDelta(q,p)*f(q)) # d_qp f(q) -> f(p) + f(_p) + + More interesting cases: + + >>> evaluate_deltas(KroneckerDelta(i,p)*t(a,i)*f(p,q)) + f(_i, _q)*t(_a, _i) + >>> evaluate_deltas(KroneckerDelta(a,p)*t(a,i)*f(p,q)) + f(_a, _q)*t(_a, _i) + >>> evaluate_deltas(KroneckerDelta(p,q)*f(p,q)) + f(_p, _p) + + Finally, here are some cases where nothing is done, because that would + imply a loss of information: + + >>> evaluate_deltas(KroneckerDelta(i,p)*f(q)) + f(_q)*KroneckerDelta(_i, _p) + >>> evaluate_deltas(KroneckerDelta(i,p)*f(i)) + f(_i)*KroneckerDelta(_i, _p) + """ + + # We treat Deltas only in mul objects + # for general function objects we don't evaluate KroneckerDeltas in arguments, + # but here we hard code exceptions to this rule + accepted_functions = ( + Add, + ) + if isinstance(e, accepted_functions): + return e.func(*[evaluate_deltas(arg) for arg in e.args]) + + elif isinstance(e, Mul): + # find all occurrences of delta function and count each index present in + # expression. + deltas = [] + indices = {} + for i in e.args: + for s in i.free_symbols: + if s in indices: + indices[s] += 1 + else: + indices[s] = 0 # geek counting simplifies logic below + if isinstance(i, KroneckerDelta): + deltas.append(i) + + for d in deltas: + # If we do something, and there are more deltas, we should recurse + # to treat the resulting expression properly + if d.killable_index.is_Symbol and indices[d.killable_index]: + e = e.subs(d.killable_index, d.preferred_index) + if len(deltas) > 1: + return evaluate_deltas(e) + elif (d.preferred_index.is_Symbol and indices[d.preferred_index] + and d.indices_contain_equal_information): + e = e.subs(d.preferred_index, d.killable_index) + if len(deltas) > 1: + return evaluate_deltas(e) + else: + pass + + return e + # nothing to do, maybe we hit a Symbol or a number + else: + return e + + +def substitute_dummies(expr, new_indices=False, pretty_indices={}): + """ + Collect terms by substitution of dummy variables. + + Explanation + =========== + + This routine allows simplification of Add expressions containing terms + which differ only due to dummy variables. + + The idea is to substitute all dummy variables consistently depending on + the structure of the term. For each term, we obtain a sequence of all + dummy variables, where the order is determined by the index range, what + factors the index belongs to and its position in each factor. See + _get_ordered_dummies() for more information about the sorting of dummies. + The index sequence is then substituted consistently in each term. + + Examples + ======== + + >>> from sympy import symbols, Function, Dummy + >>> from sympy.physics.secondquant import substitute_dummies + >>> a,b,c,d = symbols('a b c d', above_fermi=True, cls=Dummy) + >>> i,j = symbols('i j', below_fermi=True, cls=Dummy) + >>> f = Function('f') + + >>> expr = f(a,b) + f(c,d); expr + f(_a, _b) + f(_c, _d) + + Since a, b, c and d are equivalent summation indices, the expression can be + simplified to a single term (for which the dummy indices are still summed over) + + >>> substitute_dummies(expr) + 2*f(_a, _b) + + + Controlling output: + + By default the dummy symbols that are already present in the expression + will be reused in a different permutation. However, if new_indices=True, + new dummies will be generated and inserted. The keyword 'pretty_indices' + can be used to control this generation of new symbols. + + By default the new dummies will be generated on the form i_1, i_2, a_1, + etc. If you supply a dictionary with key:value pairs in the form: + + { index_group: string_of_letters } + + The letters will be used as labels for the new dummy symbols. The + index_groups must be one of 'above', 'below' or 'general'. + + >>> expr = f(a,b,i,j) + >>> my_dummies = { 'above':'st', 'below':'uv' } + >>> substitute_dummies(expr, new_indices=True, pretty_indices=my_dummies) + f(_s, _t, _u, _v) + + If we run out of letters, or if there is no keyword for some index_group + the default dummy generator will be used as a fallback: + + >>> p,q = symbols('p q', cls=Dummy) # general indices + >>> expr = f(p,q) + >>> substitute_dummies(expr, new_indices=True, pretty_indices=my_dummies) + f(_p_0, _p_1) + + """ + + # setup the replacing dummies + if new_indices: + letters_above = pretty_indices.get('above', "") + letters_below = pretty_indices.get('below', "") + letters_general = pretty_indices.get('general', "") + len_above = len(letters_above) + len_below = len(letters_below) + len_general = len(letters_general) + + def _i(number): + try: + return letters_below[number] + except IndexError: + return 'i_' + str(number - len_below) + + def _a(number): + try: + return letters_above[number] + except IndexError: + return 'a_' + str(number - len_above) + + def _p(number): + try: + return letters_general[number] + except IndexError: + return 'p_' + str(number - len_general) + + aboves = [] + belows = [] + generals = [] + + dummies = expr.atoms(Dummy) + if not new_indices: + dummies = sorted(dummies, key=default_sort_key) + + # generate lists with the dummies we will insert + a = i = p = 0 + for d in dummies: + assum = d.assumptions0 + + if assum.get("above_fermi"): + if new_indices: + sym = _a(a) + a += 1 + l1 = aboves + elif assum.get("below_fermi"): + if new_indices: + sym = _i(i) + i += 1 + l1 = belows + else: + if new_indices: + sym = _p(p) + p += 1 + l1 = generals + + if new_indices: + l1.append(Dummy(sym, **assum)) + else: + l1.append(d) + + expr = expr.expand() + terms = Add.make_args(expr) + new_terms = [] + for term in terms: + i = iter(belows) + a = iter(aboves) + p = iter(generals) + ordered = _get_ordered_dummies(term) + subsdict = {} + for d in ordered: + if d.assumptions0.get('below_fermi'): + subsdict[d] = next(i) + elif d.assumptions0.get('above_fermi'): + subsdict[d] = next(a) + else: + subsdict[d] = next(p) + subslist = [] + final_subs = [] + for k, v in subsdict.items(): + if k == v: + continue + if v in subsdict: + # We check if the sequence of substitutions end quickly. In + # that case, we can avoid temporary symbols if we ensure the + # correct substitution order. + if subsdict[v] in subsdict: + # (x, y) -> (y, x), we need a temporary variable + x = Dummy('x') + subslist.append((k, x)) + final_subs.append((x, v)) + else: + # (x, y) -> (y, a), x->y must be done last + # but before temporary variables are resolved + final_subs.insert(0, (k, v)) + else: + subslist.append((k, v)) + subslist.extend(final_subs) + new_terms.append(term.subs(subslist)) + return Add(*new_terms) + + +class KeyPrinter(StrPrinter): + """Printer for which only equal objects are equal in print""" + def _print_Dummy(self, expr): + return "(%s_%i)" % (expr.name, expr.dummy_index) + + +def __kprint(expr): + p = KeyPrinter() + return p.doprint(expr) + + +def _get_ordered_dummies(mul, verbose=False): + """Returns all dummies in the mul sorted in canonical order. + + Explanation + =========== + + The purpose of the canonical ordering is that dummies can be substituted + consistently across terms with the result that equivalent terms can be + simplified. + + It is not possible to determine if two terms are equivalent based solely on + the dummy order. However, a consistent substitution guided by the ordered + dummies should lead to trivially (non-)equivalent terms, thereby revealing + the equivalence. This also means that if two terms have identical sequences of + dummies, the (non-)equivalence should already be apparent. + + Strategy + -------- + + The canonical order is given by an arbitrary sorting rule. A sort key + is determined for each dummy as a tuple that depends on all factors where + the index is present. The dummies are thereby sorted according to the + contraction structure of the term, instead of sorting based solely on the + dummy symbol itself. + + After all dummies in the term has been assigned a key, we check for identical + keys, i.e. unorderable dummies. If any are found, we call a specialized + method, _determine_ambiguous(), that will determine a unique order based + on recursive calls to _get_ordered_dummies(). + + Key description + --------------- + + A high level description of the sort key: + + 1. Range of the dummy index + 2. Relation to external (non-dummy) indices + 3. Position of the index in the first factor + 4. Position of the index in the second factor + + The sort key is a tuple with the following components: + + 1. A single character indicating the range of the dummy (above, below + or general.) + 2. A list of strings with fully masked string representations of all + factors where the dummy is present. By masked, we mean that dummies + are represented by a symbol to indicate either below fermi, above or + general. No other information is displayed about the dummies at + this point. The list is sorted stringwise. + 3. An integer number indicating the position of the index, in the first + factor as sorted in 2. + 4. An integer number indicating the position of the index, in the second + factor as sorted in 2. + + If a factor is either of type AntiSymmetricTensor or SqOperator, the index + position in items 3 and 4 is indicated as 'upper' or 'lower' only. + (Creation operators are considered upper and annihilation operators lower.) + + If the masked factors are identical, the two factors cannot be ordered + unambiguously in item 2. In this case, items 3, 4 are left out. If several + indices are contracted between the unorderable factors, it will be handled by + _determine_ambiguous() + + + """ + # setup dicts to avoid repeated calculations in key() + args = Mul.make_args(mul) + fac_dum = { fac: fac.atoms(Dummy) for fac in args } + fac_repr = { fac: __kprint(fac) for fac in args } + all_dums = set().union(*fac_dum.values()) + mask = {} + for d in all_dums: + if d.assumptions0.get('below_fermi'): + mask[d] = '0' + elif d.assumptions0.get('above_fermi'): + mask[d] = '1' + else: + mask[d] = '2' + dum_repr = {d: __kprint(d) for d in all_dums} + + def _key(d): + dumstruct = [ fac for fac in fac_dum if d in fac_dum[fac] ] + other_dums = set().union(*[fac_dum[fac] for fac in dumstruct]) + fac = dumstruct[-1] + if other_dums is fac_dum[fac]: + other_dums = fac_dum[fac].copy() + other_dums.remove(d) + masked_facs = [ fac_repr[fac] for fac in dumstruct ] + for d2 in other_dums: + masked_facs = [ fac.replace(dum_repr[d2], mask[d2]) + for fac in masked_facs ] + all_masked = [ fac.replace(dum_repr[d], mask[d]) + for fac in masked_facs ] + masked_facs = dict(list(zip(dumstruct, masked_facs))) + + # dummies for which the ordering cannot be determined + if has_dups(all_masked): + all_masked.sort() + return mask[d], tuple(all_masked) # positions are ambiguous + + # sort factors according to fully masked strings + keydict = dict(list(zip(dumstruct, all_masked))) + dumstruct.sort(key=lambda x: keydict[x]) + all_masked.sort() + + pos_val = [] + for fac in dumstruct: + if isinstance(fac, AntiSymmetricTensor): + if d in fac.upper: + pos_val.append('u') + if d in fac.lower: + pos_val.append('l') + elif isinstance(fac, Creator): + pos_val.append('u') + elif isinstance(fac, Annihilator): + pos_val.append('l') + elif isinstance(fac, NO): + ops = [ op for op in fac if op.has(d) ] + for op in ops: + if isinstance(op, Creator): + pos_val.append('u') + else: + pos_val.append('l') + else: + # fallback to position in string representation + facpos = -1 + while 1: + facpos = masked_facs[fac].find(dum_repr[d], facpos + 1) + if facpos == -1: + break + pos_val.append(facpos) + return (mask[d], tuple(all_masked), pos_val[0], pos_val[-1]) + dumkey = dict(list(zip(all_dums, list(map(_key, all_dums))))) + result = sorted(all_dums, key=lambda x: dumkey[x]) + if has_dups(iter(dumkey.values())): + # We have ambiguities + unordered = defaultdict(set) + for d, k in dumkey.items(): + unordered[k].add(d) + for k in [ k for k in unordered if len(unordered[k]) < 2 ]: + del unordered[k] + + unordered = [ unordered[k] for k in sorted(unordered) ] + result = _determine_ambiguous(mul, result, unordered) + return result + + +def _determine_ambiguous(term, ordered, ambiguous_groups): + # We encountered a term for which the dummy substitution is ambiguous. + # This happens for terms with 2 or more contractions between factors that + # cannot be uniquely ordered independent of summation indices. For + # example: + # + # Sum(p, q) v^{p, .}_{q, .}v^{q, .}_{p, .} + # + # Assuming that the indices represented by . are dummies with the + # same range, the factors cannot be ordered, and there is no + # way to determine a consistent ordering of p and q. + # + # The strategy employed here, is to relabel all unambiguous dummies with + # non-dummy symbols and call _get_ordered_dummies again. This procedure is + # applied to the entire term so there is a possibility that + # _determine_ambiguous() is called again from a deeper recursion level. + + # break recursion if there are no ordered dummies + all_ambiguous = set() + for dummies in ambiguous_groups: + all_ambiguous |= dummies + all_ordered = set(ordered) - all_ambiguous + if not all_ordered: + # FIXME: If we arrive here, there are no ordered dummies. A method to + # handle this needs to be implemented. In order to return something + # useful nevertheless, we choose arbitrarily the first dummy and + # determine the rest from this one. This method is dependent on the + # actual dummy labels which violates an assumption for the + # canonicalization procedure. A better implementation is needed. + group = [ d for d in ordered if d in ambiguous_groups[0] ] + d = group[0] + all_ordered.add(d) + ambiguous_groups[0].remove(d) + + stored_counter = _symbol_factory._counter + subslist = [] + for d in [ d for d in ordered if d in all_ordered ]: + nondum = _symbol_factory._next() + subslist.append((d, nondum)) + newterm = term.subs(subslist) + neworder = _get_ordered_dummies(newterm) + _symbol_factory._set_counter(stored_counter) + + # update ordered list with new information + for group in ambiguous_groups: + ordered_group = [ d for d in neworder if d in group ] + ordered_group.reverse() + result = [] + for d in ordered: + if d in group: + result.append(ordered_group.pop()) + else: + result.append(d) + ordered = result + return ordered + + +class _SymbolFactory: + def __init__(self, label): + self._counterVar = 0 + self._label = label + + def _set_counter(self, value): + """ + Sets counter to value. + """ + self._counterVar = value + + @property + def _counter(self): + """ + What counter is currently at. + """ + return self._counterVar + + def _next(self): + """ + Generates the next symbols and increments counter by 1. + """ + s = Symbol("%s%i" % (self._label, self._counterVar)) + self._counterVar += 1 + return s +_symbol_factory = _SymbolFactory('_]"]_') # most certainly a unique label + + +@cacheit +def _get_contractions(string1, keep_only_fully_contracted=False): + """ + Returns Add-object with contracted terms. + + Uses recursion to find all contractions. -- Internal helper function -- + + Will find nonzero contractions in string1 between indices given in + leftrange and rightrange. + + """ + + # Should we store current level of contraction? + if keep_only_fully_contracted and string1: + result = [] + else: + result = [NO(Mul(*string1))] + + for i in range(len(string1) - 1): + for j in range(i + 1, len(string1)): + + c = contraction(string1[i], string1[j]) + + if c: + sign = (j - i + 1) % 2 + if sign: + coeff = S.NegativeOne*c + else: + coeff = c + + # + # Call next level of recursion + # ============================ + # + # We now need to find more contractions among operators + # + # oplist = string1[:i]+ string1[i+1:j] + string1[j+1:] + # + # To prevent overcounting, we don't allow contractions + # we have already encountered. i.e. contractions between + # string1[:i] <---> string1[i+1:j] + # and string1[:i] <---> string1[j+1:]. + # + # This leaves the case: + oplist = string1[i + 1:j] + string1[j + 1:] + + if oplist: + + result.append(coeff*NO( + Mul(*string1[:i])*_get_contractions( oplist, + keep_only_fully_contracted=keep_only_fully_contracted))) + + else: + result.append(coeff*NO( Mul(*string1[:i]))) + + if keep_only_fully_contracted: + break # next iteration over i leaves leftmost operator string1[0] uncontracted + + return Add(*result) + + +def wicks(e, **kw_args): + """ + Returns the normal ordered equivalent of an expression using Wicks Theorem. + + Examples + ======== + + >>> from sympy import symbols, Dummy + >>> from sympy.physics.secondquant import wicks, F, Fd + >>> p, q, r = symbols('p,q,r') + >>> wicks(Fd(p)*F(q)) + KroneckerDelta(_i, q)*KroneckerDelta(p, q) + NO(CreateFermion(p)*AnnihilateFermion(q)) + + By default, the expression is expanded: + + >>> wicks(F(p)*(F(q)+F(r))) + NO(AnnihilateFermion(p)*AnnihilateFermion(q)) + NO(AnnihilateFermion(p)*AnnihilateFermion(r)) + + With the keyword 'keep_only_fully_contracted=True', only fully contracted + terms are returned. + + By request, the result can be simplified in the following order: + -- KroneckerDelta functions are evaluated + -- Dummy variables are substituted consistently across terms + + >>> p, q, r = symbols('p q r', cls=Dummy) + >>> wicks(Fd(p)*(F(q)+F(r)), keep_only_fully_contracted=True) + KroneckerDelta(_i, _q)*KroneckerDelta(_p, _q) + KroneckerDelta(_i, _r)*KroneckerDelta(_p, _r) + + """ + + if not e: + return S.Zero + + opts = { + 'simplify_kronecker_deltas': False, + 'expand': True, + 'simplify_dummies': False, + 'keep_only_fully_contracted': False + } + opts.update(kw_args) + + # check if we are already normally ordered + if isinstance(e, NO): + if opts['keep_only_fully_contracted']: + return S.Zero + else: + return e + elif isinstance(e, FermionicOperator): + if opts['keep_only_fully_contracted']: + return S.Zero + else: + return e + + # break up any NO-objects, and evaluate commutators + e = e.doit(wicks=True) + + # make sure we have only one term to consider + e = e.expand() + if isinstance(e, Add): + if opts['simplify_dummies']: + return substitute_dummies(Add(*[ wicks(term, **kw_args) for term in e.args])) + else: + return Add(*[ wicks(term, **kw_args) for term in e.args]) + + # For Mul-objects we can actually do something + if isinstance(e, Mul): + + # we don't want to mess around with commuting part of Mul + # so we factorize it out before starting recursion + c_part = [] + string1 = [] + for factor in e.args: + if factor.is_commutative: + c_part.append(factor) + else: + string1.append(factor) + n = len(string1) + + # catch trivial cases + if n == 0: + result = e + elif n == 1: + if opts['keep_only_fully_contracted']: + return S.Zero + else: + result = e + + else: # non-trivial + + if isinstance(string1[0], BosonicOperator): + raise NotImplementedError + + string1 = tuple(string1) + + # recursion over higher order contractions + result = _get_contractions(string1, + keep_only_fully_contracted=opts['keep_only_fully_contracted'] ) + result = Mul(*c_part)*result + + if opts['expand']: + result = result.expand() + if opts['simplify_kronecker_deltas']: + result = evaluate_deltas(result) + + return result + + # there was nothing to do + return e + + +class PermutationOperator(Expr): + """ + Represents the index permutation operator P(ij). + + P(ij)*f(i)*g(j) = f(i)*g(j) - f(j)*g(i) + """ + is_commutative = True + + def __new__(cls, i, j): + i, j = sorted(map(sympify, (i, j)), key=default_sort_key) + obj = Basic.__new__(cls, i, j) + return obj + + def get_permuted(self, expr): + """ + Returns -expr with permuted indices. + + Explanation + =========== + + >>> from sympy import symbols, Function + >>> from sympy.physics.secondquant import PermutationOperator + >>> p,q = symbols('p,q') + >>> f = Function('f') + >>> PermutationOperator(p,q).get_permuted(f(p,q)) + -f(q, p) + + """ + i = self.args[0] + j = self.args[1] + if expr.has(i) and expr.has(j): + tmp = Dummy() + expr = expr.subs(i, tmp) + expr = expr.subs(j, i) + expr = expr.subs(tmp, j) + return S.NegativeOne*expr + else: + return expr + + def _latex(self, printer): + return "P(%s%s)" % tuple(printer._print(i) for i in self.args) + + +def simplify_index_permutations(expr, permutation_operators): + """ + Performs simplification by introducing PermutationOperators where appropriate. + + Explanation + =========== + + Schematically: + [abij] - [abji] - [baij] + [baji] -> P(ab)*P(ij)*[abij] + + permutation_operators is a list of PermutationOperators to consider. + + If permutation_operators=[P(ab),P(ij)] we will try to introduce the + permutation operators P(ij) and P(ab) in the expression. If there are other + possible simplifications, we ignore them. + + >>> from sympy import symbols, Function + >>> from sympy.physics.secondquant import simplify_index_permutations + >>> from sympy.physics.secondquant import PermutationOperator + >>> p,q,r,s = symbols('p,q,r,s') + >>> f = Function('f') + >>> g = Function('g') + + >>> expr = f(p)*g(q) - f(q)*g(p); expr + f(p)*g(q) - f(q)*g(p) + >>> simplify_index_permutations(expr,[PermutationOperator(p,q)]) + f(p)*g(q)*PermutationOperator(p, q) + + >>> PermutList = [PermutationOperator(p,q),PermutationOperator(r,s)] + >>> expr = f(p,r)*g(q,s) - f(q,r)*g(p,s) + f(q,s)*g(p,r) - f(p,s)*g(q,r) + >>> simplify_index_permutations(expr,PermutList) + f(p, r)*g(q, s)*PermutationOperator(p, q)*PermutationOperator(r, s) + + """ + + def _get_indices(expr, ind): + """ + Collects indices recursively in predictable order. + """ + result = [] + for arg in expr.args: + if arg in ind: + result.append(arg) + else: + if arg.args: + result.extend(_get_indices(arg, ind)) + return result + + def _choose_one_to_keep(a, b, ind): + # we keep the one where indices in ind are in order ind[0] < ind[1] + return min(a, b, key=lambda x: default_sort_key(_get_indices(x, ind))) + + expr = expr.expand() + if isinstance(expr, Add): + terms = set(expr.args) + + for P in permutation_operators: + new_terms = set() + on_hold = set() + while terms: + term = terms.pop() + permuted = P.get_permuted(term) + if permuted in terms | on_hold: + try: + terms.remove(permuted) + except KeyError: + on_hold.remove(permuted) + keep = _choose_one_to_keep(term, permuted, P.args) + new_terms.add(P*keep) + else: + + # Some terms must get a second chance because the permuted + # term may already have canonical dummy ordering. Then + # substitute_dummies() does nothing. However, the other + # term, if it exists, will be able to match with us. + permuted1 = permuted + permuted = substitute_dummies(permuted) + if permuted1 == permuted: + on_hold.add(term) + elif permuted in terms | on_hold: + try: + terms.remove(permuted) + except KeyError: + on_hold.remove(permuted) + keep = _choose_one_to_keep(term, permuted, P.args) + new_terms.add(P*keep) + else: + new_terms.add(term) + terms = new_terms | on_hold + return Add(*terms) + return expr diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/sho.py b/.venv/lib/python3.13/site-packages/sympy/physics/sho.py new file mode 100644 index 0000000000000000000000000000000000000000..c55b31b3fa9fca4fa33a9f8e91c90c2174fe81a5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/sho.py @@ -0,0 +1,95 @@ +from sympy.core import S, pi, Rational +from sympy.functions import assoc_laguerre, sqrt, exp, factorial, factorial2 + + +def R_nl(n, l, nu, r): + """ + Returns the radial wavefunction R_{nl} for a 3d isotropic harmonic + oscillator. + + Parameters + ========== + + n : + The "nodal" quantum number. Corresponds to the number of nodes in + the wavefunction. ``n >= 0`` + l : + The quantum number for orbital angular momentum. + nu : + mass-scaled frequency: nu = m*omega/(2*hbar) where `m` is the mass + and `omega` the frequency of the oscillator. + (in atomic units ``nu == omega/2``) + r : + Radial coordinate. + + Examples + ======== + + >>> from sympy.physics.sho import R_nl + >>> from sympy.abc import r, nu, l + >>> R_nl(0, 0, 1, r) + 2*2**(3/4)*exp(-r**2)/pi**(1/4) + >>> R_nl(1, 0, 1, r) + 4*2**(1/4)*sqrt(3)*(3/2 - 2*r**2)*exp(-r**2)/(3*pi**(1/4)) + + l, nu and r may be symbolic: + + >>> R_nl(0, 0, nu, r) + 2*2**(3/4)*sqrt(nu**(3/2))*exp(-nu*r**2)/pi**(1/4) + >>> R_nl(0, l, 1, r) + r**l*sqrt(2**(l + 3/2)*2**(l + 2)/factorial2(2*l + 1))*exp(-r**2)/pi**(1/4) + + The normalization of the radial wavefunction is: + + >>> from sympy import Integral, oo + >>> Integral(R_nl(0, 0, 1, r)**2*r**2, (r, 0, oo)).n() + 1.00000000000000 + >>> Integral(R_nl(1, 0, 1, r)**2*r**2, (r, 0, oo)).n() + 1.00000000000000 + >>> Integral(R_nl(1, 1, 1, r)**2*r**2, (r, 0, oo)).n() + 1.00000000000000 + + """ + n, l, nu, r = map(S, [n, l, nu, r]) + + # formula uses n >= 1 (instead of nodal n >= 0) + n = n + 1 + C = sqrt( + ((2*nu)**(l + Rational(3, 2))*2**(n + l + 1)*factorial(n - 1))/ + (sqrt(pi)*(factorial2(2*n + 2*l - 1))) + ) + return C*r**(l)*exp(-nu*r**2)*assoc_laguerre(n - 1, l + S.Half, 2*nu*r**2) + + +def E_nl(n, l, hw): + """ + Returns the Energy of an isotropic harmonic oscillator. + + Parameters + ========== + + n : + The "nodal" quantum number. + l : + The orbital angular momentum. + hw : + The harmonic oscillator parameter. + + Notes + ===== + + The unit of the returned value matches the unit of hw, since the energy is + calculated as: + + E_nl = (2*n + l + 3/2)*hw + + Examples + ======== + + >>> from sympy.physics.sho import E_nl + >>> from sympy import symbols + >>> x, y, z = symbols('x, y, z') + >>> E_nl(x, y, z) + z*(2*x + y + 3/2) + """ + return (2*n + l + Rational(3, 2))*hw diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/physics/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_clebsch_gordan.py b/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_clebsch_gordan.py new file mode 100644 index 0000000000000000000000000000000000000000..e4313e3e412d6d1883efaf693c13e0f967daf9da --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_clebsch_gordan.py @@ -0,0 +1,223 @@ +from sympy.core.numbers import (I, pi, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.spherical_harmonics import Ynm +from sympy.matrices.dense import Matrix +from sympy.physics.wigner import (clebsch_gordan, wigner_9j, wigner_6j, gaunt, + real_gaunt, racah, dot_rot_grad_Ynm, wigner_3j, wigner_d_small, wigner_d) +from sympy.testing.pytest import raises, skip + +# for test cases, refer : https://en.wikipedia.org/wiki/Table_of_Clebsch%E2%80%93Gordan_coefficients + +def test_clebsch_gordan_docs(): + assert clebsch_gordan(Rational(3, 2), S.Half, 2, Rational(3, 2), S.Half, 2) == 1 + assert clebsch_gordan(Rational(3, 2), S.Half, 1, Rational(3, 2), Rational(-1, 2), 1) == sqrt(3)/2 + assert clebsch_gordan(Rational(3, 2), S.Half, 1, Rational(-1, 2), S.Half, 0) == -sqrt(2)/2 + + +def test_clebsch_gordan(): + # Argument order: (j_1, j_2, j, m_1, m_2, m) + + h = S.One + k = S.Half + l = Rational(3, 2) + i = Rational(-1, 2) + n = Rational(7, 2) + p = Rational(5, 2) + assert clebsch_gordan(k, k, 1, k, k, 1) == 1 + assert clebsch_gordan(k, k, 1, k, k, 0) == 0 + assert clebsch_gordan(k, k, 1, i, i, -1) == 1 + assert clebsch_gordan(k, k, 1, k, i, 0) == sqrt(2)/2 + assert clebsch_gordan(k, k, 0, k, i, 0) == sqrt(2)/2 + assert clebsch_gordan(k, k, 1, i, k, 0) == sqrt(2)/2 + assert clebsch_gordan(k, k, 0, i, k, 0) == -sqrt(2)/2 + assert clebsch_gordan(h, k, l, 1, k, l) == 1 + assert clebsch_gordan(h, k, l, 1, i, k) == 1/sqrt(3) + assert clebsch_gordan(h, k, k, 1, i, k) == sqrt(2)/sqrt(3) + assert clebsch_gordan(h, k, k, 0, k, k) == -1/sqrt(3) + assert clebsch_gordan(h, k, l, 0, k, k) == sqrt(2)/sqrt(3) + assert clebsch_gordan(h, h, S(2), 1, 1, S(2)) == 1 + assert clebsch_gordan(h, h, S(2), 1, 0, 1) == 1/sqrt(2) + assert clebsch_gordan(h, h, S(2), 0, 1, 1) == 1/sqrt(2) + assert clebsch_gordan(h, h, 1, 1, 0, 1) == 1/sqrt(2) + assert clebsch_gordan(h, h, 1, 0, 1, 1) == -1/sqrt(2) + assert clebsch_gordan(l, l, S(3), l, l, S(3)) == 1 + assert clebsch_gordan(l, l, S(2), l, k, S(2)) == 1/sqrt(2) + assert clebsch_gordan(l, l, S(3), l, k, S(2)) == 1/sqrt(2) + assert clebsch_gordan(S(2), S(2), S(4), S(2), S(2), S(4)) == 1 + assert clebsch_gordan(S(2), S(2), S(3), S(2), 1, S(3)) == 1/sqrt(2) + assert clebsch_gordan(S(2), S(2), S(3), 1, 1, S(2)) == 0 + assert clebsch_gordan(p, h, n, p, 1, n) == 1 + assert clebsch_gordan(p, h, p, p, 0, p) == sqrt(5)/sqrt(7) + assert clebsch_gordan(p, h, l, k, 1, l) == 1/sqrt(15) + + +def test_clebsch_gordan_numpy(): + try: + import numpy as np + except ImportError: + skip("numpy not installed") + assert clebsch_gordan(*np.zeros(6).astype(np.int64)) == 1 + assert wigner_3j(2, np.float64(6.0), 4.0, 0, 0, 0) == sqrt(715)/143 + assert wigner_3j(0, 0.5, 0.5, 0, 0.5, -0.5) == sqrt(2)/2 + raises(ValueError, lambda: wigner_3j(2.1, 6, 4, 0, 0, 0)) + + +def test_wigner(): + try: + import numpy as np + except ImportError: + skip("numpy not installed") + def tn(a, b): + return (a - b).n(64) < S('1e-64') + assert tn(wigner_9j(1, 1, 1, 1, 1, 1, 1, 1, 0, prec=64), Rational(1, 18)) + assert wigner_9j(3, 3, 2, 3, 3, 2, 3, 3, 2) == 3221*sqrt( + 70)/(246960*sqrt(105)) - 365/(3528*sqrt(70)*sqrt(105)) + assert wigner_6j(5, 5, 5, 5, 5, 5) == Rational(1, 52) + assert tn(wigner_6j(8, 8, 8, 8, 8, 8, prec=64), Rational(-12219, 965770)) + assert wigner_6j(1, 1, 1, 1.0, np.float64(1.0), 1) == Rational(1, 6) + assert wigner_6j(3.0, np.float32(3), 3.0, 3, 3, 3) == Rational(-1, 14) + # regression test for #8747 + half = S.Half + assert wigner_9j(0, 0, 0, 0, half, half, 0, half, half) == half + assert (wigner_9j(3, 5, 4, + 7 * half, 5 * half, 4, + 9 * half, 9 * half, 0) + == -sqrt(Rational(361, 205821000))) + assert (wigner_9j(1, 4, 3, + 5 * half, 4, 5 * half, + 5 * half, 2, 7 * half) + == -sqrt(Rational(3971, 373403520))) + assert (wigner_9j(4, 9 * half, 5 * half, + 2, 4, 4, + 5, 7 * half, 7 * half) + == -sqrt(Rational(3481, 5042614500))) + assert (wigner_9j(5, 5, 5.0, + np.float64(5.0), 5, 5, + 5, 5, 5) + == 0) + assert (wigner_9j(1.0, 2.0, 3.0, + 3, 2, 1, + 2, 1, 3) + == -4*sqrt(70)/11025) + + +def test_gaunt(): + def tn(a, b): + return (a - b).n(64) < S('1e-64') + assert gaunt(1, 0, 1, 1, 0, -1) == -1/(2*sqrt(pi)) + assert isinstance(gaunt(1, 1, 0, -1, 1, 0).args[0], Rational) + assert isinstance(gaunt(0, 1, 1, 0, -1, 1).args[0], Rational) + + assert tn(gaunt( + 10, 10, 12, 9, 3, -12, prec=64), (Rational(-98, 62031)) * sqrt(6279)/sqrt(pi)) + def gaunt_ref(l1, l2, l3, m1, m2, m3): + return ( + sqrt((2 * l1 + 1) * (2 * l2 + 1) * (2 * l3 + 1) / (4 * pi)) * + wigner_3j(l1, l2, l3, 0, 0, 0) * + wigner_3j(l1, l2, l3, m1, m2, m3) + ) + threshold = 1e-10 + l_max = 3 + l3_max = 24 + for l1 in range(l_max + 1): + for l2 in range(l_max + 1): + for l3 in range(l3_max + 1): + for m1 in range(-l1, l1 + 1): + for m2 in range(-l2, l2 + 1): + for m3 in range(-l3, l3 + 1): + args = l1, l2, l3, m1, m2, m3 + g = gaunt(*args) + g0 = gaunt_ref(*args) + assert abs(g - g0) < threshold + if m1 + m2 + m3 != 0: + assert abs(g) < threshold + if (l1 + l2 + l3) % 2: + assert abs(g) < threshold + assert gaunt(1, 1, 0, 0, 2, -2) is S.Zero + + +def test_realgaunt(): + # All non-zero values corresponding to l values from 0 to 2 + for l in range(3): + for m in range(-l, l+1): + assert real_gaunt(0, l, l, 0, m, m) == 1/(2*sqrt(pi)) + assert real_gaunt(1, 1, 2, 0, 0, 0) == sqrt(5)/(5*sqrt(pi)) + assert real_gaunt(1, 1, 2, 1, 1, 0) == -sqrt(5)/(10*sqrt(pi)) + assert real_gaunt(2, 2, 2, 0, 0, 0) == sqrt(5)/(7*sqrt(pi)) + assert real_gaunt(2, 2, 2, 0, 2, 2) == -sqrt(5)/(7*sqrt(pi)) + assert real_gaunt(2, 2, 2, -2, -2, 0) == -sqrt(5)/(7*sqrt(pi)) + assert real_gaunt(1, 1, 2, -1, 0, -1) == sqrt(15)/(10*sqrt(pi)) + assert real_gaunt(1, 1, 2, 0, 1, 1) == sqrt(15)/(10*sqrt(pi)) + assert real_gaunt(1, 1, 2, 1, 1, 2) == sqrt(15)/(10*sqrt(pi)) + assert real_gaunt(1, 1, 2, -1, 1, -2) == sqrt(15)/(10*sqrt(pi)) + assert real_gaunt(1, 1, 2, -1, -1, 2) == -sqrt(15)/(10*sqrt(pi)) + assert real_gaunt(2, 2, 2, 0, 1, 1) == sqrt(5)/(14*sqrt(pi)) + assert real_gaunt(2, 2, 2, 1, 1, 2) == sqrt(15)/(14*sqrt(pi)) + assert real_gaunt(2, 2, 2, -1, -1, 2) == -sqrt(15)/(14*sqrt(pi)) + + assert real_gaunt(-2, -2, -2, -2, -2, 0) is S.Zero # m test + assert real_gaunt(-2, 1, 0, 1, 1, 1) is S.Zero # l test + assert real_gaunt(-2, -1, -2, -1, -1, 0) is S.Zero # m and l test + assert real_gaunt(-2, -2, -2, -2, -2, -2) is S.Zero # m and k test + assert real_gaunt(-2, -1, -2, -1, -1, -1) is S.Zero # m, l and k test + + x = symbols('x', integer=True) + v = [0]*6 + for i in range(len(v)): + v[i] = x # non literal ints fail + raises(ValueError, lambda: real_gaunt(*v)) + v[i] = 0 + + +def test_racah(): + assert racah(3,3,3,3,3,3) == Rational(-1,14) + assert racah(2,2,2,2,2,2) == Rational(-3,70) + assert racah(7,8,7,1,7,7, prec=4).is_Float + assert racah(5.5,7.5,9.5,6.5,8,9) == -719*sqrt(598)/1158924 + assert abs(racah(5.5,7.5,9.5,6.5,8,9, prec=4) - (-0.01517)) < S('1e-4') + + +def test_dot_rota_grad_SH(): + theta, phi = symbols("theta phi") + assert dot_rot_grad_Ynm(1, 1, 1, 1, 1, 0) != \ + sqrt(30)*Ynm(2, 2, 1, 0)/(10*sqrt(pi)) + assert dot_rot_grad_Ynm(1, 1, 1, 1, 1, 0).doit() == \ + sqrt(30)*Ynm(2, 2, 1, 0)/(10*sqrt(pi)) + assert dot_rot_grad_Ynm(1, 5, 1, 1, 1, 2) != \ + 0 + assert dot_rot_grad_Ynm(1, 5, 1, 1, 1, 2).doit() == \ + 0 + assert dot_rot_grad_Ynm(3, 3, 3, 3, theta, phi).doit() == \ + 15*sqrt(3003)*Ynm(6, 6, theta, phi)/(143*sqrt(pi)) + assert dot_rot_grad_Ynm(3, 3, 1, 1, theta, phi).doit() == \ + sqrt(3)*Ynm(4, 4, theta, phi)/sqrt(pi) + assert dot_rot_grad_Ynm(3, 2, 2, 0, theta, phi).doit() == \ + 3*sqrt(55)*Ynm(5, 2, theta, phi)/(11*sqrt(pi)) + assert dot_rot_grad_Ynm(3, 2, 3, 2, theta, phi).doit().expand() == \ + -sqrt(70)*Ynm(4, 4, theta, phi)/(11*sqrt(pi)) + \ + 45*sqrt(182)*Ynm(6, 4, theta, phi)/(143*sqrt(pi)) + + +def test_wigner_d(): + half = S(1)/2 + assert wigner_d_small(half, 0) == Matrix([[1, 0], [0, 1]]) + assert wigner_d_small(half, pi/2) == Matrix([[1, 1], [-1, 1]])/sqrt(2) + assert wigner_d_small(half, pi) == Matrix([[0, 1], [-1, 0]]) + + alpha, beta, gamma = symbols("alpha, beta, gamma", real=True) + D = wigner_d(half, alpha, beta, gamma) + assert D[0, 0] == exp(I*alpha/2)*exp(I*gamma/2)*cos(beta/2) + assert D[0, 1] == exp(I*alpha/2)*exp(-I*gamma/2)*sin(beta/2) + assert D[1, 0] == -exp(-I*alpha/2)*exp(I*gamma/2)*sin(beta/2) + assert D[1, 1] == exp(-I*alpha/2)*exp(-I*gamma/2)*cos(beta/2) + + # Test Y_{n mi}(g*x)=\sum_{mj}D^n_{mi mj}*Y_{n mj}(x) + theta, phi = symbols("theta phi", real=True) + v = Matrix([Ynm(1, mj, theta, phi) for mj in range(1, -2, -1)]) + w = wigner_d(1, -pi/2, pi/2, -pi/2)@v.subs({theta: pi/4, phi: pi}) + w_ = v.subs({theta: pi/2, phi: pi/4}) + assert w.expand(func=True).as_real_imag() == w_.expand(func=True).as_real_imag() diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_hydrogen.py b/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_hydrogen.py new file mode 100644 index 0000000000000000000000000000000000000000..eb11744dd8e731f24fcd6f6be2a92ada4fffc554 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_hydrogen.py @@ -0,0 +1,126 @@ +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.integrals.integrals import integrate +from sympy.simplify.simplify import simplify +from sympy.physics.hydrogen import R_nl, E_nl, E_nl_dirac, Psi_nlm +from sympy.testing.pytest import raises + +n, r, Z = symbols('n r Z') + + +def feq(a, b, max_relative_error=1e-12, max_absolute_error=1e-12): + a = float(a) + b = float(b) + # if the numbers are close enough (absolutely), then they are equal + if abs(a - b) < max_absolute_error: + return True + # if not, they can still be equal if their relative error is small + if abs(b) > abs(a): + relative_error = abs((a - b)/b) + else: + relative_error = abs((a - b)/a) + return relative_error <= max_relative_error + + +def test_wavefunction(): + a = 1/Z + R = { + (1, 0): 2*sqrt(1/a**3) * exp(-r/a), + (2, 0): sqrt(1/(2*a**3)) * exp(-r/(2*a)) * (1 - r/(2*a)), + (2, 1): S.Half * sqrt(1/(6*a**3)) * exp(-r/(2*a)) * r/a, + (3, 0): Rational(2, 3) * sqrt(1/(3*a**3)) * exp(-r/(3*a)) * + (1 - 2*r/(3*a) + Rational(2, 27) * (r/a)**2), + (3, 1): Rational(4, 27) * sqrt(2/(3*a**3)) * exp(-r/(3*a)) * + (1 - r/(6*a)) * r/a, + (3, 2): Rational(2, 81) * sqrt(2/(15*a**3)) * exp(-r/(3*a)) * (r/a)**2, + (4, 0): Rational(1, 4) * sqrt(1/a**3) * exp(-r/(4*a)) * + (1 - 3*r/(4*a) + Rational(1, 8) * (r/a)**2 - Rational(1, 192) * (r/a)**3), + (4, 1): Rational(1, 16) * sqrt(5/(3*a**3)) * exp(-r/(4*a)) * + (1 - r/(4*a) + Rational(1, 80) * (r/a)**2) * (r/a), + (4, 2): Rational(1, 64) * sqrt(1/(5*a**3)) * exp(-r/(4*a)) * + (1 - r/(12*a)) * (r/a)**2, + (4, 3): Rational(1, 768) * sqrt(1/(35*a**3)) * exp(-r/(4*a)) * (r/a)**3, + } + for n, l in R: + assert simplify(R_nl(n, l, r, Z) - R[(n, l)]) == 0 + + +def test_norm(): + # Maximum "n" which is tested: + n_max = 2 # it works, but is slow, for n_max > 2 + for n in range(n_max + 1): + for l in range(n): + assert integrate(R_nl(n, l, r)**2 * r**2, (r, 0, oo)) == 1 + +def test_psi_nlm(): + r=S('r') + phi=S('phi') + theta=S('theta') + assert (Psi_nlm(1, 0, 0, r, phi, theta) == exp(-r) / sqrt(pi)) + assert (Psi_nlm(2, 1, -1, r, phi, theta)) == S.Half * exp(-r / (2)) * r \ + * (sin(theta) * exp(-I * phi) / (4 * sqrt(pi))) + assert (Psi_nlm(3, 2, 1, r, phi, theta, 2) == -sqrt(2) * sin(theta) \ + * exp(I * phi) * cos(theta) / (4 * sqrt(pi)) * S(2) / 81 \ + * sqrt(2 * 2 ** 3) * exp(-2 * r / (3)) * (r * 2) ** 2) + +def test_hydrogen_energies(): + assert E_nl(n, Z) == -Z**2/(2*n**2) + assert E_nl(n) == -1/(2*n**2) + + assert E_nl(1, 47) == -S(47)**2/(2*1**2) + assert E_nl(2, 47) == -S(47)**2/(2*2**2) + + assert E_nl(1) == -S.One/(2*1**2) + assert E_nl(2) == -S.One/(2*2**2) + assert E_nl(3) == -S.One/(2*3**2) + assert E_nl(4) == -S.One/(2*4**2) + assert E_nl(100) == -S.One/(2*100**2) + + raises(ValueError, lambda: E_nl(0)) + + +def test_hydrogen_energies_relat(): + # First test exact formulas for small "c" so that we get nice expressions: + assert E_nl_dirac(2, 0, Z=1, c=1) == 1/sqrt(2) - 1 + assert simplify(E_nl_dirac(2, 0, Z=1, c=2) - ( (8*sqrt(3) + 16) + / sqrt(16*sqrt(3) + 32) - 4)) == 0 + assert simplify(E_nl_dirac(2, 0, Z=1, c=3) - ( (54*sqrt(2) + 81) + / sqrt(108*sqrt(2) + 162) - 9)) == 0 + + # Now test for almost the correct speed of light, without floating point + # numbers: + assert simplify(E_nl_dirac(2, 0, Z=1, c=137) - ( (352275361 + 10285412 * + sqrt(1173)) / sqrt(704550722 + 20570824 * sqrt(1173)) - 18769)) == 0 + assert simplify(E_nl_dirac(2, 0, Z=82, c=137) - ( (352275361 + 2571353 * + sqrt(12045)) / sqrt(704550722 + 5142706*sqrt(12045)) - 18769)) == 0 + + # Test using exact speed of light, and compare against the nonrelativistic + # energies: + for n in range(1, 5): + for l in range(n): + assert feq(E_nl_dirac(n, l), E_nl(n), 1e-5, 1e-5) + if l > 0: + assert feq(E_nl_dirac(n, l, False), E_nl(n), 1e-5, 1e-5) + + Z = 2 + for n in range(1, 5): + for l in range(n): + assert feq(E_nl_dirac(n, l, Z=Z), E_nl(n, Z), 1e-4, 1e-4) + if l > 0: + assert feq(E_nl_dirac(n, l, False, Z), E_nl(n, Z), 1e-4, 1e-4) + + Z = 3 + for n in range(1, 5): + for l in range(n): + assert feq(E_nl_dirac(n, l, Z=Z), E_nl(n, Z), 1e-3, 1e-3) + if l > 0: + assert feq(E_nl_dirac(n, l, False, Z), E_nl(n, Z), 1e-3, 1e-3) + + # Test the exceptions: + raises(ValueError, lambda: E_nl_dirac(0, 0)) + raises(ValueError, lambda: E_nl_dirac(1, -1)) + raises(ValueError, lambda: E_nl_dirac(1, 0, False)) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_paulialgebra.py b/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_paulialgebra.py new file mode 100644 index 0000000000000000000000000000000000000000..f773470a1802f2864b79f56d38be1de030ff86dc --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_paulialgebra.py @@ -0,0 +1,57 @@ +from sympy.core.numbers import I +from sympy.core.symbol import symbols +from sympy.physics.paulialgebra import Pauli +from sympy.testing.pytest import XFAIL +from sympy.physics.quantum import TensorProduct + +sigma1 = Pauli(1) +sigma2 = Pauli(2) +sigma3 = Pauli(3) + +tau1 = symbols("tau1", commutative = False) + + +def test_Pauli(): + + assert sigma1 == sigma1 + assert sigma1 != sigma2 + + assert sigma1*sigma2 == I*sigma3 + assert sigma3*sigma1 == I*sigma2 + assert sigma2*sigma3 == I*sigma1 + + assert sigma1*sigma1 == 1 + assert sigma2*sigma2 == 1 + assert sigma3*sigma3 == 1 + + assert sigma1**0 == 1 + assert sigma1**1 == sigma1 + assert sigma1**2 == 1 + assert sigma1**3 == sigma1 + assert sigma1**4 == 1 + + assert sigma3**2 == 1 + + assert sigma1*2*sigma1 == 2 + + +def test_evaluate_pauli_product(): + from sympy.physics.paulialgebra import evaluate_pauli_product + + assert evaluate_pauli_product(I*sigma2*sigma3) == -sigma1 + + # Check issue 6471 + assert evaluate_pauli_product(-I*4*sigma1*sigma2) == 4*sigma3 + + assert evaluate_pauli_product( + 1 + I*sigma1*sigma2*sigma1*sigma2 + \ + I*sigma1*sigma2*tau1*sigma1*sigma3 + \ + ((tau1**2).subs(tau1, I*sigma1)) + \ + sigma3*((tau1**2).subs(tau1, I*sigma1)) + \ + TensorProduct(I*sigma1*sigma2*sigma1*sigma2, 1) + ) == 1 -I + I*sigma3*tau1*sigma2 - 1 - sigma3 - I*TensorProduct(1,1) + + +@XFAIL +def test_Pauli_should_work(): + assert sigma1*sigma3*sigma1 == -sigma3 diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_physics_matrices.py b/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_physics_matrices.py new file mode 100644 index 0000000000000000000000000000000000000000..14fa47668d0760826e0354c8cafae787a24256eb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_physics_matrices.py @@ -0,0 +1,84 @@ +from sympy.physics.matrices import msigma, mgamma, minkowski_tensor, pat_matrix, mdft +from sympy.core.numbers import (I, Rational) +from sympy.core.singleton import S +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import (Matrix, eye, zeros) +from sympy.testing.pytest import warns_deprecated_sympy + + +def test_parallel_axis_theorem(): + # This tests the parallel axis theorem matrix by comparing to test + # matrices. + + # First case, 1 in all directions. + mat1 = Matrix(((2, -1, -1), (-1, 2, -1), (-1, -1, 2))) + assert pat_matrix(1, 1, 1, 1) == mat1 + assert pat_matrix(2, 1, 1, 1) == 2*mat1 + + # Second case, 1 in x, 0 in all others + mat2 = Matrix(((0, 0, 0), (0, 1, 0), (0, 0, 1))) + assert pat_matrix(1, 1, 0, 0) == mat2 + assert pat_matrix(2, 1, 0, 0) == 2*mat2 + + # Third case, 1 in y, 0 in all others + mat3 = Matrix(((1, 0, 0), (0, 0, 0), (0, 0, 1))) + assert pat_matrix(1, 0, 1, 0) == mat3 + assert pat_matrix(2, 0, 1, 0) == 2*mat3 + + # Fourth case, 1 in z, 0 in all others + mat4 = Matrix(((1, 0, 0), (0, 1, 0), (0, 0, 0))) + assert pat_matrix(1, 0, 0, 1) == mat4 + assert pat_matrix(2, 0, 0, 1) == 2*mat4 + + +def test_Pauli(): + #this and the following test are testing both Pauli and Dirac matrices + #and also that the general Matrix class works correctly in a real world + #situation + sigma1 = msigma(1) + sigma2 = msigma(2) + sigma3 = msigma(3) + + assert sigma1 == sigma1 + assert sigma1 != sigma2 + + # sigma*I -> I*sigma (see #354) + assert sigma1*sigma2 == sigma3*I + assert sigma3*sigma1 == sigma2*I + assert sigma2*sigma3 == sigma1*I + + assert sigma1*sigma1 == eye(2) + assert sigma2*sigma2 == eye(2) + assert sigma3*sigma3 == eye(2) + + assert sigma1*2*sigma1 == 2*eye(2) + assert sigma1*sigma3*sigma1 == -sigma3 + + +def test_Dirac(): + gamma0 = mgamma(0) + gamma1 = mgamma(1) + gamma2 = mgamma(2) + gamma3 = mgamma(3) + gamma5 = mgamma(5) + + # gamma*I -> I*gamma (see #354) + assert gamma5 == gamma0 * gamma1 * gamma2 * gamma3 * I + assert gamma1 * gamma2 + gamma2 * gamma1 == zeros(4) + assert gamma0 * gamma0 == eye(4) * minkowski_tensor[0, 0] + assert gamma2 * gamma2 != eye(4) * minkowski_tensor[0, 0] + assert gamma2 * gamma2 == eye(4) * minkowski_tensor[2, 2] + + assert mgamma(5, True) == \ + mgamma(0, True)*mgamma(1, True)*mgamma(2, True)*mgamma(3, True)*I + +def test_mdft(): + with warns_deprecated_sympy(): + assert mdft(1) == Matrix([[1]]) + with warns_deprecated_sympy(): + assert mdft(2) == 1/sqrt(2)*Matrix([[1,1],[1,-1]]) + with warns_deprecated_sympy(): + assert mdft(4) == Matrix([[S.Half, S.Half, S.Half, S.Half], + [S.Half, -I/2, Rational(-1,2), I/2], + [S.Half, Rational(-1,2), S.Half, Rational(-1,2)], + [S.Half, I/2, Rational(-1,2), -I/2]]) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_pring.py b/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_pring.py new file mode 100644 index 0000000000000000000000000000000000000000..ed7398eac4a8bb1cd4af810825caf3fcefb5f18f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_pring.py @@ -0,0 +1,41 @@ +from sympy.physics.pring import wavefunction, energy +from sympy.core.numbers import (I, pi) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.integrals.integrals import integrate +from sympy.simplify.simplify import simplify +from sympy.abc import m, x, r +from sympy.physics.quantum.constants import hbar + + +def test_wavefunction(): + Psi = { + 0: (1/sqrt(2 * pi)), + 1: (1/sqrt(2 * pi)) * exp(I * x), + 2: (1/sqrt(2 * pi)) * exp(2 * I * x), + 3: (1/sqrt(2 * pi)) * exp(3 * I * x) + } + for n in Psi: + assert simplify(wavefunction(n, x) - Psi[n]) == 0 + + +def test_norm(n=1): + # Maximum "n" which is tested: + for i in range(n + 1): + assert integrate( + wavefunction(i, x) * wavefunction(-i, x), (x, 0, 2 * pi)) == 1 + + +def test_orthogonality(n=1): + # Maximum "n" which is tested: + for i in range(n + 1): + for j in range(i+1, n+1): + assert integrate( + wavefunction(i, x) * wavefunction(j, x), (x, 0, 2 * pi)) == 0 + + +def test_energy(n=1): + # Maximum "n" which is tested: + for i in range(n+1): + assert simplify( + energy(i, m, r) - ((i**2 * hbar**2) / (2 * m * r**2))) == 0 diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_qho_1d.py b/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_qho_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..34e52c9e3a721496fc61f7d2b31414db15caa7a8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_qho_1d.py @@ -0,0 +1,50 @@ +from sympy.core.numbers import (Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.integrals.integrals import integrate +from sympy.simplify.simplify import simplify +from sympy.abc import omega, m, x +from sympy.physics.qho_1d import psi_n, E_n, coherent_state +from sympy.physics.quantum.constants import hbar + +nu = m * omega / hbar + + +def test_wavefunction(): + Psi = { + 0: (nu/pi)**Rational(1, 4) * exp(-nu * x**2 /2), + 1: (nu/pi)**Rational(1, 4) * sqrt(2*nu) * x * exp(-nu * x**2 /2), + 2: (nu/pi)**Rational(1, 4) * (2 * nu * x**2 - 1)/sqrt(2) * exp(-nu * x**2 /2), + 3: (nu/pi)**Rational(1, 4) * sqrt(nu/3) * (2 * nu * x**3 - 3 * x) * exp(-nu * x**2 /2) + } + for n in Psi: + assert simplify(psi_n(n, x, m, omega) - Psi[n]) == 0 + + +def test_norm(n=1): + # Maximum "n" which is tested: + for i in range(n + 1): + assert integrate(psi_n(i, x, 1, 1)**2, (x, -oo, oo)) == 1 + + +def test_orthogonality(n=1): + # Maximum "n" which is tested: + for i in range(n + 1): + for j in range(i + 1, n + 1): + assert integrate( + psi_n(i, x, 1, 1)*psi_n(j, x, 1, 1), (x, -oo, oo)) == 0 + + +def test_energies(n=1): + # Maximum "n" which is tested: + for i in range(n + 1): + assert E_n(i, omega) == hbar * omega * (i + S.Half) + +def test_coherent_state(n=10): + # Maximum "n" which is tested: + # test whether coherent state is the eigenstate of annihilation operator + alpha = Symbol("alpha") + for i in range(n + 1): + assert simplify(sqrt(n + 1) * coherent_state(n + 1, alpha)) == simplify(alpha * coherent_state(n, alpha)) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_secondquant.py b/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_secondquant.py new file mode 100644 index 0000000000000000000000000000000000000000..e7f60fab05497aead65ad748460802c9c29740ce --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_secondquant.py @@ -0,0 +1,1301 @@ +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.exponential import exp +from sympy.physics.secondquant import ( + Dagger, Bd, VarBosonicBasis, BBra, B, BKet, FixedBosonicBasis, + matrix_rep, apply_operators, InnerProduct, Commutator, KroneckerDelta, + AnnihilateBoson, CreateBoson, BosonicOperator, + F, Fd, FKet, BosonState, CreateFermion, AnnihilateFermion, + evaluate_deltas, AntiSymmetricTensor, contraction, NO, wicks, + PermutationOperator, simplify_index_permutations, + _sort_anticommuting_fermions, _get_ordered_dummies, + substitute_dummies, FockStateBosonKet, + ContractionAppliesOnlyToFermions +) + +from sympy.concrete.summations import Sum +from sympy.core.function import (Function, expand) +from sympy.core.numbers import (I, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.printing.repr import srepr +from sympy.simplify.simplify import simplify + +from sympy.testing.pytest import slow, raises +from sympy.printing.latex import latex + + +def test_PermutationOperator(): + p, q, r, s = symbols('p,q,r,s') + f, g, h, i = map(Function, 'fghi') + P = PermutationOperator + assert P(p, q).get_permuted(f(p)*g(q)) == -f(q)*g(p) + assert P(p, q).get_permuted(f(p, q)) == -f(q, p) + assert P(p, q).get_permuted(f(p)) == f(p) + expr = (f(p)*g(q)*h(r)*i(s) + - f(q)*g(p)*h(r)*i(s) + - f(p)*g(q)*h(s)*i(r) + + f(q)*g(p)*h(s)*i(r)) + perms = [P(p, q), P(r, s)] + assert (simplify_index_permutations(expr, perms) == + P(p, q)*P(r, s)*f(p)*g(q)*h(r)*i(s)) + assert latex(P(p, q)) == 'P(pq)' + + p1, p2 = symbols('p1,p2') + assert latex(P(p1,p2) == 'P(p_{1}p_{2})') + +def test_index_permutations_with_dummies(): + a, b, c, d = symbols('a b c d') + p, q, r, s = symbols('p q r s', cls=Dummy) + f, g = map(Function, 'fg') + P = PermutationOperator + + # No dummy substitution necessary + expr = f(a, b, p, q) - f(b, a, p, q) + assert simplify_index_permutations( + expr, [P(a, b)]) == P(a, b)*f(a, b, p, q) + + # Cases where dummy substitution is needed + expected = P(a, b)*substitute_dummies(f(a, b, p, q)) + + expr = f(a, b, p, q) - f(b, a, q, p) + result = simplify_index_permutations(expr, [P(a, b)]) + assert expected == substitute_dummies(result) + + expr = f(a, b, q, p) - f(b, a, p, q) + result = simplify_index_permutations(expr, [P(a, b)]) + assert expected == substitute_dummies(result) + + # A case where nothing can be done + expr = f(a, b, q, p) - g(b, a, p, q) + result = simplify_index_permutations(expr, [P(a, b)]) + assert expr == result + + +def test_dagger(): + i, j, n, m = symbols('i,j,n,m') + assert Dagger(1) == 1 + assert Dagger(1.0) == 1.0 + assert Dagger(2*I) == -2*I + assert Dagger(S.Half*I/3.0) == I*Rational(-1, 2)/3.0 + assert Dagger(BKet([n])) == BBra([n]) + assert Dagger(B(0)) == Bd(0) + assert Dagger(Bd(0)) == B(0) + assert Dagger(B(n)) == Bd(n) + assert Dagger(Bd(n)) == B(n) + assert Dagger(B(0) + B(1)) == Bd(0) + Bd(1) + assert Dagger(n*m) == Dagger(n)*Dagger(m) # n, m commute + assert Dagger(B(n)*B(m)) == Bd(m)*Bd(n) + assert Dagger(B(n)**10) == Dagger(B(n))**10 + assert Dagger('a') == Dagger(Symbol('a')) + assert Dagger(Dagger('a')) == Symbol('a') + assert Dagger(exp(2 * I)) == exp(-2 * I) + assert Dagger(i) == conjugate(i) + + +def test_operator(): + i, j = symbols('i,j') + o = BosonicOperator(i) + assert o.state == i + assert o.is_symbolic + o = BosonicOperator(1) + assert o.state == 1 + assert not o.is_symbolic + + +def test_create(): + i, j, n, m, p1 = symbols('i,j,n,m,p1') + o = Bd(i) + assert latex(o) == "{b^\\dagger_{i}}" + assert latex(Bd(p1)) == "{b^\\dagger_{p_{1}}}" + assert isinstance(o, CreateBoson) + o = o.subs(i, j) + assert o.atoms(Symbol) == {j} + o = Bd(0) + assert o.apply_operator(BKet([n])) == sqrt(n + 1)*BKet([n + 1]) + o = Bd(n) + assert o.apply_operator(BKet([n])) == o*BKet([n]) + + +def test_annihilate(): + i, j, n, m, p1 = symbols('i,j,n,m,p1') + o = B(i) + assert latex(o) == "b_{i}" + assert latex(B(p1)) == "b_{p_{1}}" + assert isinstance(o, AnnihilateBoson) + o = o.subs(i, j) + assert o.atoms(Symbol) == {j} + o = B(0) + assert o.apply_operator(BKet([n])) == sqrt(n)*BKet([n - 1]) + o = B(n) + assert o.apply_operator(BKet([n])) == o*BKet([n]) + + +def test_basic_state(): + i, j, n, m = symbols('i,j,n,m') + s = BosonState([0, 1, 2, 3, 4]) + assert len(s) == 5 + assert s.args[0] == tuple(range(5)) + assert s.up(0) == BosonState([1, 1, 2, 3, 4]) + assert s.down(4) == BosonState([0, 1, 2, 3, 3]) + for i in range(5): + assert s.up(i).down(i) == s + assert s.down(0) == 0 + for i in range(5): + assert s[i] == i + s = BosonState([n, m]) + assert s.down(0) == BosonState([n - 1, m]) + assert s.up(0) == BosonState([n + 1, m]) + + +def test_basic_apply(): + n = symbols("n") + e = B(0)*BKet([n]) + assert apply_operators(e) == sqrt(n)*BKet([n - 1]) + e = Bd(0)*BKet([n]) + assert apply_operators(e) == sqrt(n + 1)*BKet([n + 1]) + + +def test_complex_apply(): + n, m = symbols("n,m") + o = Bd(0)*B(0)*Bd(1)*B(0) + e = apply_operators(o*BKet([n, m])) + answer = sqrt(n)*sqrt(m + 1)*(-1 + n)*BKet([-1 + n, 1 + m]) + assert expand(e) == expand(answer) + + +def test_number_operator(): + n = symbols("n") + o = Bd(0)*B(0) + e = apply_operators(o*BKet([n])) + assert e == n*BKet([n]) + + +def test_inner_product(): + i, j, k, l = symbols('i,j,k,l') + s1 = BBra([0]) + s2 = BKet([1]) + assert InnerProduct(s1, Dagger(s1)) == 1 + assert InnerProduct(s1, s2) == 0 + s1 = BBra([i, j]) + s2 = BKet([k, l]) + r = InnerProduct(s1, s2) + assert r == KroneckerDelta(i, k)*KroneckerDelta(j, l) + + +def test_symbolic_matrix_elements(): + n, m = symbols('n,m') + s1 = BBra([n]) + s2 = BKet([m]) + o = B(0) + e = apply_operators(s1*o*s2) + assert e == sqrt(m)*KroneckerDelta(n, m - 1) + + +def test_matrix_elements(): + b = VarBosonicBasis(5) + o = B(0) + m = matrix_rep(o, b) + for i in range(4): + assert m[i, i + 1] == sqrt(i + 1) + o = Bd(0) + m = matrix_rep(o, b) + for i in range(4): + assert m[i + 1, i] == sqrt(i + 1) + + +def test_fixed_bosonic_basis(): + b = FixedBosonicBasis(2, 2) + # assert b == [FockState((2, 0)), FockState((1, 1)), FockState((0, 2))] + state = b.state(1) + assert state == FockStateBosonKet((1, 1)) + assert b.index(state) == 1 + assert b.state(1) == b[1] + assert len(b) == 3 + assert str(b) == '[FockState((2, 0)), FockState((1, 1)), FockState((0, 2))]' + assert repr(b) == '[FockState((2, 0)), FockState((1, 1)), FockState((0, 2))]' + assert srepr(b) == '[FockState((2, 0)), FockState((1, 1)), FockState((0, 2))]' + + +@slow +def test_sho(): + n, m = symbols('n,m') + h_n = Bd(n)*B(n)*(n + S.Half) + H = Sum(h_n, (n, 0, 5)) + o = H.doit(deep=False) + b = FixedBosonicBasis(2, 6) + m = matrix_rep(o, b) + # We need to double check these energy values to make sure that they + # are correct and have the proper degeneracies! + diag = [1, 2, 3, 3, 4, 5, 4, 5, 6, 7, 5, 6, 7, 8, 9, 6, 7, 8, 9, 10, 11] + for i in range(len(diag)): + assert diag[i] == m[i, i] + + +def test_commutation(): + n, m = symbols("n,m", above_fermi=True) + c = Commutator(B(0), Bd(0)) + assert c == 1 + c = Commutator(Bd(0), B(0)) + assert c == -1 + c = Commutator(B(n), Bd(0)) + assert c == KroneckerDelta(n, 0) + c = Commutator(B(0), B(0)) + assert c == 0 + c = Commutator(B(0), Bd(0)) + e = simplify(apply_operators(c*BKet([n]))) + assert e == BKet([n]) + c = Commutator(B(0), B(1)) + e = simplify(apply_operators(c*BKet([n, m]))) + assert e == 0 + + c = Commutator(F(m), Fd(m)) + assert c == +1 - 2*NO(Fd(m)*F(m)) + c = Commutator(Fd(m), F(m)) + assert c.expand() == -1 + 2*NO(Fd(m)*F(m)) + + C = Commutator + X, Y, Z = symbols('X,Y,Z', commutative=False) + assert C(C(X, Y), Z) != 0 + assert C(C(X, Z), Y) != 0 + assert C(Y, C(X, Z)) != 0 + + i, j, k, l = symbols('i,j,k,l', below_fermi=True) + a, b, c, d = symbols('a,b,c,d', above_fermi=True) + p, q, r, s = symbols('p,q,r,s') + D = KroneckerDelta + + assert C(Fd(a), F(i)) == -2*NO(F(i)*Fd(a)) + assert C(Fd(j), NO(Fd(a)*F(i))).doit(wicks=True) == -D(j, i)*Fd(a) + assert C(Fd(a)*F(i), Fd(b)*F(j)).doit(wicks=True) == 0 + + c1 = Commutator(F(a), Fd(a)) + assert Commutator.eval(c1, c1) == 0 + c = Commutator(Fd(a)*F(i),Fd(b)*F(j)) + assert latex(c) == r'\left[{a^\dagger_{a}} a_{i},{a^\dagger_{b}} a_{j}\right]' + assert repr(c) == 'Commutator(CreateFermion(a)*AnnihilateFermion(i),CreateFermion(b)*AnnihilateFermion(j))' + assert str(c) == '[CreateFermion(a)*AnnihilateFermion(i),CreateFermion(b)*AnnihilateFermion(j)]' + + +def test_create_f(): + i, j, n, m = symbols('i,j,n,m') + o = Fd(i) + assert isinstance(o, CreateFermion) + o = o.subs(i, j) + assert o.atoms(Symbol) == {j} + o = Fd(1) + assert o.apply_operator(FKet([n])) == FKet([1, n]) + assert o.apply_operator(FKet([n])) == -FKet([n, 1]) + o = Fd(n) + assert o.apply_operator(FKet([])) == FKet([n]) + + vacuum = FKet([], fermi_level=4) + assert vacuum == FKet([], fermi_level=4) + + i, j, k, l = symbols('i,j,k,l', below_fermi=True) + a, b, c, d = symbols('a,b,c,d', above_fermi=True) + p, q, r, s = symbols('p,q,r,s') + p1 = symbols("p1") + + assert Fd(i).apply_operator(FKet([i, j, k], 4)) == FKet([j, k], 4) + assert Fd(a).apply_operator(FKet([i, b, k], 4)) == FKet([a, i, b, k], 4) + + assert Dagger(B(p)).apply_operator(q) == q*CreateBoson(p) + assert repr(Fd(p)) == 'CreateFermion(p)' + assert srepr(Fd(p)) == "CreateFermion(Symbol('p'))" + assert latex(Fd(p)) == r'{a^\dagger_{p}}' + assert latex(Fd(p1)) == r'{a^\dagger_{p_{1}}}' + assert latex(FKet([a,i], 1)) == r"\left|\left( a, \ i\right)\right\rangle" + assert latex(FKet([j,i,b,a], 2)) == r"\left|\left( a, \ b, \ i, \ j\right)\right\rangle" + + +def test_annihilate_f(): + i, j, n, m = symbols('i,j,n,m') + o = F(i) + assert isinstance(o, AnnihilateFermion) + o = o.subs(i, j) + assert o.atoms(Symbol) == {j} + o = F(1) + assert o.apply_operator(FKet([1, n])) == FKet([n]) + assert o.apply_operator(FKet([n, 1])) == -FKet([n]) + o = F(n) + assert o.apply_operator(FKet([n])) == FKet([]) + + i, j, k, l = symbols('i,j,k,l', below_fermi=True) + a, b, c, d = symbols('a,b,c,d', above_fermi=True) + p, q, r, s = symbols('p,q,r,s') + p1 = symbols('p1') + + assert F(i).apply_operator(FKet([i, j, k], 4)) == 0 + assert F(a).apply_operator(FKet([i, b, k], 4)) == 0 + assert F(l).apply_operator(FKet([i, j, k], 3)) == 0 + assert F(l).apply_operator(FKet([i, j, k], 4)) == FKet([l, i, j, k], 4) + assert str(F(p)) == 'f(p)' + assert repr(F(p)) == 'AnnihilateFermion(p)' + assert srepr(F(p)) == "AnnihilateFermion(Symbol('p'))" + assert latex(F(p)) == 'a_{p}' + assert latex(F(p1)) == 'a_{p_{1}}' + + +def test_create_b(): + i, j, n, m = symbols('i,j,n,m') + o = Bd(i) + assert isinstance(o, CreateBoson) + o = o.subs(i, j) + assert o.atoms(Symbol) == {j} + o = Bd(0) + assert o.apply_operator(BKet([n])) == sqrt(n + 1)*BKet([n + 1]) + o = Bd(n) + assert o.apply_operator(BKet([n])) == o*BKet([n]) + + +def test_annihilate_b(): + i, j, n, m = symbols('i,j,n,m') + o = B(i) + assert isinstance(o, AnnihilateBoson) + o = o.subs(i, j) + assert o.atoms(Symbol) == {j} + o = B(0) + + +def test_wicks(): + p, q, r, s = symbols('p,q,r,s', above_fermi=True) + + # Testing for particles only + + str = F(p)*Fd(q) + assert wicks(str) == NO(F(p)*Fd(q)) + KroneckerDelta(p, q) + str = Fd(p)*F(q) + assert wicks(str) == NO(Fd(p)*F(q)) + + str = F(p)*Fd(q)*F(r)*Fd(s) + nstr = wicks(str) + fasit = NO( + KroneckerDelta(p, q)*KroneckerDelta(r, s) + + KroneckerDelta(p, q)*AnnihilateFermion(r)*CreateFermion(s) + + KroneckerDelta(r, s)*AnnihilateFermion(p)*CreateFermion(q) + - KroneckerDelta(p, s)*AnnihilateFermion(r)*CreateFermion(q) + - AnnihilateFermion(p)*AnnihilateFermion(r)*CreateFermion(q)*CreateFermion(s)) + assert nstr == fasit + + assert (p*q*nstr).expand() == wicks(p*q*str) + assert (nstr*p*q*2).expand() == wicks(str*p*q*2) + + # Testing CC equations particles and holes + i, j, k, l = symbols('i j k l', below_fermi=True, cls=Dummy) + a, b, c, d = symbols('a b c d', above_fermi=True, cls=Dummy) + p, q, r, s = symbols('p q r s', cls=Dummy) + + assert (wicks(F(a)*NO(F(i)*F(j))*Fd(b)) == + NO(F(a)*F(i)*F(j)*Fd(b)) + + KroneckerDelta(a, b)*NO(F(i)*F(j))) + assert (wicks(F(a)*NO(F(i)*F(j)*F(k))*Fd(b)) == + NO(F(a)*F(i)*F(j)*F(k)*Fd(b)) - + KroneckerDelta(a, b)*NO(F(i)*F(j)*F(k))) + + expr = wicks(Fd(i)*NO(Fd(j)*F(k))*F(l)) + assert (expr == + -KroneckerDelta(i, k)*NO(Fd(j)*F(l)) - + KroneckerDelta(j, l)*NO(Fd(i)*F(k)) - + KroneckerDelta(i, k)*KroneckerDelta(j, l) + + KroneckerDelta(i, l)*NO(Fd(j)*F(k)) + + NO(Fd(i)*Fd(j)*F(k)*F(l))) + expr = wicks(F(a)*NO(F(b)*Fd(c))*Fd(d)) + assert (expr == + -KroneckerDelta(a, c)*NO(F(b)*Fd(d)) - + KroneckerDelta(b, d)*NO(F(a)*Fd(c)) - + KroneckerDelta(a, c)*KroneckerDelta(b, d) + + KroneckerDelta(a, d)*NO(F(b)*Fd(c)) + + NO(F(a)*F(b)*Fd(c)*Fd(d))) + + +def test_NO(): + i, j, k, l = symbols('i j k l', below_fermi=True) + a, b, c, d = symbols('a b c d', above_fermi=True) + p, q, r, s = symbols('p q r s', cls=Dummy) + + assert (NO(Fd(p)*F(q) + Fd(a)*F(b)) == + NO(Fd(p)*F(q)) + NO(Fd(a)*F(b))) + assert (NO(Fd(i)*NO(F(j)*Fd(a))) == + NO(Fd(i)*F(j)*Fd(a))) + assert NO(1) == 1 + assert NO(i) == i + assert (NO(Fd(a)*Fd(b)*(F(c) + F(d))) == + NO(Fd(a)*Fd(b)*F(c)) + + NO(Fd(a)*Fd(b)*F(d))) + + assert NO(Fd(a)*F(b))._remove_brackets() == Fd(a)*F(b) + assert NO(F(j)*Fd(i))._remove_brackets() == F(j)*Fd(i) + + assert (NO(Fd(p)*F(q)).subs(Fd(p), Fd(a) + Fd(i)) == + NO(Fd(a)*F(q)) + NO(Fd(i)*F(q))) + assert (NO(Fd(p)*F(q)).subs(F(q), F(a) + F(i)) == + NO(Fd(p)*F(a)) + NO(Fd(p)*F(i))) + + expr = NO(Fd(p)*F(q))._remove_brackets() + assert wicks(expr) == NO(expr) + + assert NO(Fd(a)*F(b)) == - NO(F(b)*Fd(a)) + + no = NO(Fd(a)*F(i)*F(b)*Fd(j)) + l1 = list(no.iter_q_creators()) + assert l1 == [0, 1] + l2 = list(no.iter_q_annihilators()) + assert l2 == [3, 2] + no = NO(Fd(a)*Fd(i)) + assert no.has_q_creators == 1 + assert no.has_q_annihilators == -1 + assert str(no) == ':CreateFermion(a)*CreateFermion(i):' + assert repr(no) == 'NO(CreateFermion(a)*CreateFermion(i))' + assert latex(no) == r'\left\{{a^\dagger_{a}} {a^\dagger_{i}}\right\}' + raises(NotImplementedError, lambda: NO(Bd(p)*F(q))) + + +def test_sorting(): + i, j = symbols('i,j', below_fermi=True) + a, b = symbols('a,b', above_fermi=True) + p, q = symbols('p,q') + + # p, q + assert _sort_anticommuting_fermions([Fd(p), F(q)]) == ([Fd(p), F(q)], 0) + assert _sort_anticommuting_fermions([F(p), Fd(q)]) == ([Fd(q), F(p)], 1) + + # i, p + assert _sort_anticommuting_fermions([F(p), Fd(i)]) == ([F(p), Fd(i)], 0) + assert _sort_anticommuting_fermions([Fd(i), F(p)]) == ([F(p), Fd(i)], 1) + assert _sort_anticommuting_fermions([Fd(p), Fd(i)]) == ([Fd(p), Fd(i)], 0) + assert _sort_anticommuting_fermions([Fd(i), Fd(p)]) == ([Fd(p), Fd(i)], 1) + assert _sort_anticommuting_fermions([F(p), F(i)]) == ([F(i), F(p)], 1) + assert _sort_anticommuting_fermions([F(i), F(p)]) == ([F(i), F(p)], 0) + assert _sort_anticommuting_fermions([Fd(p), F(i)]) == ([F(i), Fd(p)], 1) + assert _sort_anticommuting_fermions([F(i), Fd(p)]) == ([F(i), Fd(p)], 0) + + # a, p + assert _sort_anticommuting_fermions([F(p), Fd(a)]) == ([Fd(a), F(p)], 1) + assert _sort_anticommuting_fermions([Fd(a), F(p)]) == ([Fd(a), F(p)], 0) + assert _sort_anticommuting_fermions([Fd(p), Fd(a)]) == ([Fd(a), Fd(p)], 1) + assert _sort_anticommuting_fermions([Fd(a), Fd(p)]) == ([Fd(a), Fd(p)], 0) + assert _sort_anticommuting_fermions([F(p), F(a)]) == ([F(p), F(a)], 0) + assert _sort_anticommuting_fermions([F(a), F(p)]) == ([F(p), F(a)], 1) + assert _sort_anticommuting_fermions([Fd(p), F(a)]) == ([Fd(p), F(a)], 0) + assert _sort_anticommuting_fermions([F(a), Fd(p)]) == ([Fd(p), F(a)], 1) + + # i, a + assert _sort_anticommuting_fermions([F(i), Fd(j)]) == ([F(i), Fd(j)], 0) + assert _sort_anticommuting_fermions([Fd(j), F(i)]) == ([F(i), Fd(j)], 1) + assert _sort_anticommuting_fermions([Fd(a), Fd(i)]) == ([Fd(a), Fd(i)], 0) + assert _sort_anticommuting_fermions([Fd(i), Fd(a)]) == ([Fd(a), Fd(i)], 1) + assert _sort_anticommuting_fermions([F(a), F(i)]) == ([F(i), F(a)], 1) + assert _sort_anticommuting_fermions([F(i), F(a)]) == ([F(i), F(a)], 0) + + +def test_contraction(): + i, j, k, l = symbols('i,j,k,l', below_fermi=True) + a, b, c, d = symbols('a,b,c,d', above_fermi=True) + p, q, r, s = symbols('p,q,r,s') + assert contraction(Fd(i), F(j)) == KroneckerDelta(i, j) + assert contraction(F(a), Fd(b)) == KroneckerDelta(a, b) + assert contraction(F(a), Fd(i)) == 0 + assert contraction(Fd(a), F(i)) == 0 + assert contraction(F(i), Fd(a)) == 0 + assert contraction(Fd(i), F(a)) == 0 + assert contraction(Fd(i), F(p)) == KroneckerDelta(i, p) + restr = evaluate_deltas(contraction(Fd(p), F(q))) + assert restr.is_only_below_fermi + restr = evaluate_deltas(contraction(F(p), Fd(q))) + assert restr.is_only_above_fermi + raises(ContractionAppliesOnlyToFermions, lambda: contraction(B(a), Fd(b))) + + +def test_evaluate_deltas(): + i, j, k = symbols('i,j,k') + + r = KroneckerDelta(i, j) * KroneckerDelta(j, k) + assert evaluate_deltas(r) == KroneckerDelta(i, k) + + r = KroneckerDelta(i, 0) * KroneckerDelta(j, k) + assert evaluate_deltas(r) == KroneckerDelta(i, 0) * KroneckerDelta(j, k) + + r = KroneckerDelta(1, j) * KroneckerDelta(j, k) + assert evaluate_deltas(r) == KroneckerDelta(1, k) + + r = KroneckerDelta(j, 2) * KroneckerDelta(k, j) + assert evaluate_deltas(r) == KroneckerDelta(2, k) + + r = KroneckerDelta(i, 0) * KroneckerDelta(i, j) * KroneckerDelta(j, 1) + assert evaluate_deltas(r) == 0 + + r = (KroneckerDelta(0, i) * KroneckerDelta(0, j) + * KroneckerDelta(1, j) * KroneckerDelta(1, j)) + assert evaluate_deltas(r) == 0 + + +def test_Tensors(): + i, j, k, l = symbols('i j k l', below_fermi=True, cls=Dummy) + a, b, c, d = symbols('a b c d', above_fermi=True, cls=Dummy) + p, q, r, s = symbols('p q r s') + + AT = AntiSymmetricTensor + assert AT('t', (a, b), (i, j)) == -AT('t', (b, a), (i, j)) + assert AT('t', (a, b), (i, j)) == AT('t', (b, a), (j, i)) + assert AT('t', (a, b), (i, j)) == -AT('t', (a, b), (j, i)) + assert AT('t', (a, a), (i, j)) == 0 + assert AT('t', (a, b), (i, i)) == 0 + assert AT('t', (a, b, c), (i, j)) == -AT('t', (b, a, c), (i, j)) + assert AT('t', (a, b, c), (i, j, k)) == AT('t', (b, a, c), (i, k, j)) + + tabij = AT('t', (a, b), (i, j)) + assert tabij.has(a) + assert tabij.has(b) + assert tabij.has(i) + assert tabij.has(j) + assert tabij.subs(b, c) == AT('t', (a, c), (i, j)) + assert (2*tabij).subs(i, c) == 2*AT('t', (a, b), (c, j)) + assert tabij.symbol == Symbol('t') + assert latex(tabij) == '{t^{ab}_{ij}}' + assert str(tabij) == 't((_a, _b),(_i, _j))' + + assert AT('t', (a, a), (i, j)).subs(a, b) == AT('t', (b, b), (i, j)) + assert AT('t', (a, i), (a, j)).subs(a, b) == AT('t', (b, i), (b, j)) + + a1, a2, a3, a4 = symbols('alpha1:5') + u_alpha1234 = AntiSymmetricTensor("u", (a1, a2), (a3, a4)) + + assert latex(u_alpha1234) == r'{u^{\alpha_{1}\alpha_{2}}_{\alpha_{3}\alpha_{4}}}' + assert str(u_alpha1234) == 'u((alpha1, alpha2),(alpha3, alpha4))' + + +def test_fully_contracted(): + i, j, k, l = symbols('i j k l', below_fermi=True) + a, b, c, d = symbols('a b c d', above_fermi=True) + p, q, r, s = symbols('p q r s', cls=Dummy) + + Fock = (AntiSymmetricTensor('f', (p,), (q,))* + NO(Fd(p)*F(q))) + V = (AntiSymmetricTensor('v', (p, q), (r, s))* + NO(Fd(p)*Fd(q)*F(s)*F(r)))/4 + + Fai = wicks(NO(Fd(i)*F(a))*Fock, + keep_only_fully_contracted=True, + simplify_kronecker_deltas=True) + assert Fai == AntiSymmetricTensor('f', (a,), (i,)) + Vabij = wicks(NO(Fd(i)*Fd(j)*F(b)*F(a))*V, + keep_only_fully_contracted=True, + simplify_kronecker_deltas=True) + assert Vabij == AntiSymmetricTensor('v', (a, b), (i, j)) + + +def test_substitute_dummies_without_dummies(): + i, j = symbols('i,j') + assert substitute_dummies(att(i, j) + 2) == att(i, j) + 2 + assert substitute_dummies(att(i, j) + 1) == att(i, j) + 1 + + +def test_substitute_dummies_NO_operator(): + i, j = symbols('i j', cls=Dummy) + assert substitute_dummies(att(i, j)*NO(Fd(i)*F(j)) + - att(j, i)*NO(Fd(j)*F(i))) == 0 + + +def test_substitute_dummies_SQ_operator(): + i, j = symbols('i j', cls=Dummy) + assert substitute_dummies(att(i, j)*Fd(i)*F(j) + - att(j, i)*Fd(j)*F(i)) == 0 + + +def test_substitute_dummies_new_indices(): + i, j = symbols('i j', below_fermi=True, cls=Dummy) + a, b = symbols('a b', above_fermi=True, cls=Dummy) + p, q = symbols('p q', cls=Dummy) + f = Function('f') + assert substitute_dummies(f(i, a, p) - f(j, b, q), new_indices=True) == 0 + + +def test_substitute_dummies_substitution_order(): + i, j, k, l = symbols('i j k l', below_fermi=True, cls=Dummy) + f = Function('f') + from sympy.utilities.iterables import variations + for permut in variations([i, j, k, l], 4): + assert substitute_dummies(f(*permut) - f(i, j, k, l)) == 0 + + +def test_dummy_order_inner_outer_lines_VT1T1T1(): + ii = symbols('i', below_fermi=True) + aa = symbols('a', above_fermi=True) + k, l = symbols('k l', below_fermi=True, cls=Dummy) + c, d = symbols('c d', above_fermi=True, cls=Dummy) + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + + # Coupled-Cluster T1 terms with V*T1*T1*T1 + # t^{a}_{k} t^{c}_{i} t^{d}_{l} v^{lk}_{dc} + exprs = [ + # permut v and t <=> swapping internal lines, equivalent + # irrespective of symmetries in v + v(k, l, c, d)*t(c, ii)*t(d, l)*t(aa, k), + v(l, k, c, d)*t(c, ii)*t(d, k)*t(aa, l), + v(k, l, d, c)*t(d, ii)*t(c, l)*t(aa, k), + v(l, k, d, c)*t(d, ii)*t(c, k)*t(aa, l), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_dummy_order_inner_outer_lines_VT1T1T1T1(): + ii, jj = symbols('i j', below_fermi=True) + aa, bb = symbols('a b', above_fermi=True) + k, l = symbols('k l', below_fermi=True, cls=Dummy) + c, d = symbols('c d', above_fermi=True, cls=Dummy) + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + + # Coupled-Cluster T2 terms with V*T1*T1*T1*T1 + exprs = [ + # permut t <=> swapping external lines, not equivalent + # except if v has certain symmetries. + v(k, l, c, d)*t(c, ii)*t(d, jj)*t(aa, k)*t(bb, l), + v(k, l, c, d)*t(c, jj)*t(d, ii)*t(aa, k)*t(bb, l), + v(k, l, c, d)*t(c, ii)*t(d, jj)*t(bb, k)*t(aa, l), + v(k, l, c, d)*t(c, jj)*t(d, ii)*t(bb, k)*t(aa, l), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + exprs = [ + # permut v <=> swapping external lines, not equivalent + # except if v has certain symmetries. + # + # Note that in contrast to above, these permutations have identical + # dummy order. That is because the proximity to external indices + # has higher influence on the canonical dummy ordering than the + # position of a dummy on the factors. In fact, the terms here are + # similar in structure as the result of the dummy substitutions above. + v(k, l, c, d)*t(c, ii)*t(d, jj)*t(aa, k)*t(bb, l), + v(l, k, c, d)*t(c, ii)*t(d, jj)*t(aa, k)*t(bb, l), + v(k, l, d, c)*t(c, ii)*t(d, jj)*t(aa, k)*t(bb, l), + v(l, k, d, c)*t(c, ii)*t(d, jj)*t(aa, k)*t(bb, l), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) == dums(permut) + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + exprs = [ + # permut t and v <=> swapping internal lines, equivalent. + # Canonical dummy order is different, and a consistent + # substitution reveals the equivalence. + v(k, l, c, d)*t(c, ii)*t(d, jj)*t(aa, k)*t(bb, l), + v(k, l, d, c)*t(c, jj)*t(d, ii)*t(aa, k)*t(bb, l), + v(l, k, c, d)*t(c, ii)*t(d, jj)*t(bb, k)*t(aa, l), + v(l, k, d, c)*t(c, jj)*t(d, ii)*t(bb, k)*t(aa, l), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_get_subNO(): + p, q, r = symbols('p,q,r') + assert NO(F(p)*F(q)*F(r)).get_subNO(1) == NO(F(p)*F(r)) + assert NO(F(p)*F(q)*F(r)).get_subNO(0) == NO(F(q)*F(r)) + assert NO(F(p)*F(q)*F(r)).get_subNO(2) == NO(F(p)*F(q)) + + +def test_equivalent_internal_lines_VT1T1(): + i, j, k, l = symbols('i j k l', below_fermi=True, cls=Dummy) + a, b, c, d = symbols('a b c d', above_fermi=True, cls=Dummy) + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + + exprs = [ # permute v. Different dummy order. Not equivalent. + v(i, j, a, b)*t(a, i)*t(b, j), + v(j, i, a, b)*t(a, i)*t(b, j), + v(i, j, b, a)*t(a, i)*t(b, j), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ # permute v. Different dummy order. Equivalent + v(i, j, a, b)*t(a, i)*t(b, j), + v(j, i, b, a)*t(a, i)*t(b, j), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + exprs = [ # permute t. Same dummy order, not equivalent. + v(i, j, a, b)*t(a, i)*t(b, j), + v(i, j, a, b)*t(b, i)*t(a, j), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) == dums(permut) + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ # permute v and t. Different dummy order, equivalent + v(i, j, a, b)*t(a, i)*t(b, j), + v(j, i, a, b)*t(a, j)*t(b, i), + v(i, j, b, a)*t(b, i)*t(a, j), + v(j, i, b, a)*t(b, j)*t(a, i), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_equivalent_internal_lines_VT2conjT2(): + # this diagram requires special handling in TCE + i, j, k, l, m, n = symbols('i j k l m n', below_fermi=True, cls=Dummy) + a, b, c, d, e, f = symbols('a b c d e f', above_fermi=True, cls=Dummy) + p1, p2, p3, p4 = symbols('p1 p2 p3 p4', above_fermi=True, cls=Dummy) + h1, h2, h3, h4 = symbols('h1 h2 h3 h4', below_fermi=True, cls=Dummy) + + from sympy.utilities.iterables import variations + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + + # v(abcd)t(abij)t(ijcd) + template = v(p1, p2, p3, p4)*t(p1, p2, i, j)*t(i, j, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert dums(base) != dums(expr) + assert substitute_dummies(expr) == substitute_dummies(base) + template = v(p1, p2, p3, p4)*t(p1, p2, j, i)*t(j, i, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert dums(base) != dums(expr) + assert substitute_dummies(expr) == substitute_dummies(base) + + # v(abcd)t(abij)t(jicd) + template = v(p1, p2, p3, p4)*t(p1, p2, i, j)*t(j, i, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert dums(base) != dums(expr) + assert substitute_dummies(expr) == substitute_dummies(base) + template = v(p1, p2, p3, p4)*t(p1, p2, j, i)*t(i, j, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert dums(base) != dums(expr) + assert substitute_dummies(expr) == substitute_dummies(base) + + +def test_equivalent_internal_lines_VT2conjT2_ambiguous_order(): + # These diagrams invokes _determine_ambiguous() because the + # dummies can not be ordered unambiguously by the key alone + i, j, k, l, m, n = symbols('i j k l m n', below_fermi=True, cls=Dummy) + a, b, c, d, e, f = symbols('a b c d e f', above_fermi=True, cls=Dummy) + p1, p2, p3, p4 = symbols('p1 p2 p3 p4', above_fermi=True, cls=Dummy) + h1, h2, h3, h4 = symbols('h1 h2 h3 h4', below_fermi=True, cls=Dummy) + + from sympy.utilities.iterables import variations + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + + # v(abcd)t(abij)t(cdij) + template = v(p1, p2, p3, p4)*t(p1, p2, i, j)*t(p3, p4, i, j) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert dums(base) != dums(expr) + assert substitute_dummies(expr) == substitute_dummies(base) + template = v(p1, p2, p3, p4)*t(p1, p2, j, i)*t(p3, p4, i, j) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert dums(base) != dums(expr) + assert substitute_dummies(expr) == substitute_dummies(base) + + +def test_equivalent_internal_lines_VT2(): + i, j, k, l = symbols('i j k l', below_fermi=True, cls=Dummy) + a, b, c, d = symbols('a b c d', above_fermi=True, cls=Dummy) + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + exprs = [ + # permute v. Same dummy order, not equivalent. + # + # This test show that the dummy order may not be sensitive to all + # index permutations. The following expressions have identical + # structure as the resulting terms from of the dummy substitutions + # in the test above. Here, all expressions have the same dummy + # order, so they cannot be simplified by means of dummy + # substitution. In order to simplify further, it is necessary to + # exploit symmetries in the objects, for instance if t or v is + # antisymmetric. + v(i, j, a, b)*t(a, b, i, j), + v(j, i, a, b)*t(a, b, i, j), + v(i, j, b, a)*t(a, b, i, j), + v(j, i, b, a)*t(a, b, i, j), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) == dums(permut) + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ + # permute t. + v(i, j, a, b)*t(a, b, i, j), + v(i, j, a, b)*t(b, a, i, j), + v(i, j, a, b)*t(a, b, j, i), + v(i, j, a, b)*t(b, a, j, i), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ # permute v and t. Relabelling of dummies should be equivalent. + v(i, j, a, b)*t(a, b, i, j), + v(j, i, a, b)*t(a, b, j, i), + v(i, j, b, a)*t(b, a, i, j), + v(j, i, b, a)*t(b, a, j, i), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_internal_external_VT2T2(): + ii, jj = symbols('i j', below_fermi=True) + aa, bb = symbols('a b', above_fermi=True) + k, l = symbols('k l', below_fermi=True, cls=Dummy) + c, d = symbols('c d', above_fermi=True, cls=Dummy) + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + + exprs = [ + v(k, l, c, d)*t(aa, c, ii, k)*t(bb, d, jj, l), + v(l, k, c, d)*t(aa, c, ii, l)*t(bb, d, jj, k), + v(k, l, d, c)*t(aa, d, ii, k)*t(bb, c, jj, l), + v(l, k, d, c)*t(aa, d, ii, l)*t(bb, c, jj, k), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + exprs = [ + v(k, l, c, d)*t(aa, c, ii, k)*t(d, bb, jj, l), + v(l, k, c, d)*t(aa, c, ii, l)*t(d, bb, jj, k), + v(k, l, d, c)*t(aa, d, ii, k)*t(c, bb, jj, l), + v(l, k, d, c)*t(aa, d, ii, l)*t(c, bb, jj, k), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + exprs = [ + v(k, l, c, d)*t(c, aa, ii, k)*t(bb, d, jj, l), + v(l, k, c, d)*t(c, aa, ii, l)*t(bb, d, jj, k), + v(k, l, d, c)*t(d, aa, ii, k)*t(bb, c, jj, l), + v(l, k, d, c)*t(d, aa, ii, l)*t(bb, c, jj, k), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_internal_external_pqrs(): + ii, jj = symbols('i j') + aa, bb = symbols('a b') + k, l = symbols('k l', cls=Dummy) + c, d = symbols('c d', cls=Dummy) + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + + exprs = [ + v(k, l, c, d)*t(aa, c, ii, k)*t(bb, d, jj, l), + v(l, k, c, d)*t(aa, c, ii, l)*t(bb, d, jj, k), + v(k, l, d, c)*t(aa, d, ii, k)*t(bb, c, jj, l), + v(l, k, d, c)*t(aa, d, ii, l)*t(bb, c, jj, k), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_dummy_order_well_defined(): + aa, bb = symbols('a b', above_fermi=True) + k, l, m = symbols('k l m', below_fermi=True, cls=Dummy) + c, d = symbols('c d', above_fermi=True, cls=Dummy) + p, q = symbols('p q', cls=Dummy) + + A = Function('A') + B = Function('B') + C = Function('C') + dums = _get_ordered_dummies + + # We go through all key components in the order of increasing priority, + # and consider only fully orderable expressions. Non-orderable expressions + # are tested elsewhere. + + # pos in first factor determines sort order + assert dums(A(k, l)*B(l, k)) == [k, l] + assert dums(A(l, k)*B(l, k)) == [l, k] + assert dums(A(k, l)*B(k, l)) == [k, l] + assert dums(A(l, k)*B(k, l)) == [l, k] + + # factors involving the index + assert dums(A(k, l)*B(l, m)*C(k, m)) == [l, k, m] + assert dums(A(k, l)*B(l, m)*C(m, k)) == [l, k, m] + assert dums(A(l, k)*B(l, m)*C(k, m)) == [l, k, m] + assert dums(A(l, k)*B(l, m)*C(m, k)) == [l, k, m] + assert dums(A(k, l)*B(m, l)*C(k, m)) == [l, k, m] + assert dums(A(k, l)*B(m, l)*C(m, k)) == [l, k, m] + assert dums(A(l, k)*B(m, l)*C(k, m)) == [l, k, m] + assert dums(A(l, k)*B(m, l)*C(m, k)) == [l, k, m] + + # same, but with factor order determined by non-dummies + assert dums(A(k, aa, l)*A(l, bb, m)*A(bb, k, m)) == [l, k, m] + assert dums(A(k, aa, l)*A(l, bb, m)*A(bb, m, k)) == [l, k, m] + assert dums(A(k, aa, l)*A(m, bb, l)*A(bb, k, m)) == [l, k, m] + assert dums(A(k, aa, l)*A(m, bb, l)*A(bb, m, k)) == [l, k, m] + assert dums(A(l, aa, k)*A(l, bb, m)*A(bb, k, m)) == [l, k, m] + assert dums(A(l, aa, k)*A(l, bb, m)*A(bb, m, k)) == [l, k, m] + assert dums(A(l, aa, k)*A(m, bb, l)*A(bb, k, m)) == [l, k, m] + assert dums(A(l, aa, k)*A(m, bb, l)*A(bb, m, k)) == [l, k, m] + + # index range + assert dums(A(p, c, k)*B(p, c, k)) == [k, c, p] + assert dums(A(p, k, c)*B(p, c, k)) == [k, c, p] + assert dums(A(c, k, p)*B(p, c, k)) == [k, c, p] + assert dums(A(c, p, k)*B(p, c, k)) == [k, c, p] + assert dums(A(k, c, p)*B(p, c, k)) == [k, c, p] + assert dums(A(k, p, c)*B(p, c, k)) == [k, c, p] + assert dums(B(p, c, k)*A(p, c, k)) == [k, c, p] + assert dums(B(p, k, c)*A(p, c, k)) == [k, c, p] + assert dums(B(c, k, p)*A(p, c, k)) == [k, c, p] + assert dums(B(c, p, k)*A(p, c, k)) == [k, c, p] + assert dums(B(k, c, p)*A(p, c, k)) == [k, c, p] + assert dums(B(k, p, c)*A(p, c, k)) == [k, c, p] + + +def test_dummy_order_ambiguous(): + aa, bb = symbols('a b', above_fermi=True) + i, j, k, l, m = symbols('i j k l m', below_fermi=True, cls=Dummy) + a, b, c, d, e = symbols('a b c d e', above_fermi=True, cls=Dummy) + p, q = symbols('p q', cls=Dummy) + p1, p2, p3, p4 = symbols('p1 p2 p3 p4', above_fermi=True, cls=Dummy) + p5, p6, p7, p8 = symbols('p5 p6 p7 p8', above_fermi=True, cls=Dummy) + h1, h2, h3, h4 = symbols('h1 h2 h3 h4', below_fermi=True, cls=Dummy) + h5, h6, h7, h8 = symbols('h5 h6 h7 h8', below_fermi=True, cls=Dummy) + + A = Function('A') + B = Function('B') + + from sympy.utilities.iterables import variations + + # A*A*A*A*B -- ordering of p5 and p4 is used to figure out the rest + template = A(p1, p2)*A(p4, p1)*A(p2, p3)*A(p3, p5)*B(p5, p4) + permutator = variations([a, b, c, d, e], 5) + base = template.subs(zip([p1, p2, p3, p4, p5], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4, p5], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + + # A*A*A*A*A -- an arbitrary index is assigned and the rest are figured out + template = A(p1, p2)*A(p4, p1)*A(p2, p3)*A(p3, p5)*A(p5, p4) + permutator = variations([a, b, c, d, e], 5) + base = template.subs(zip([p1, p2, p3, p4, p5], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4, p5], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + + # A*A*A -- ordering of p5 and p4 is used to figure out the rest + template = A(p1, p2, p4, p1)*A(p2, p3, p3, p5)*A(p5, p4) + permutator = variations([a, b, c, d, e], 5) + base = template.subs(zip([p1, p2, p3, p4, p5], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4, p5], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + + +def atv(*args): + return AntiSymmetricTensor('v', args[:2], args[2:] ) + + +def att(*args): + if len(args) == 4: + return AntiSymmetricTensor('t', args[:2], args[2:] ) + elif len(args) == 2: + return AntiSymmetricTensor('t', (args[0],), (args[1],)) + + +def test_dummy_order_inner_outer_lines_VT1T1T1_AT(): + ii = symbols('i', below_fermi=True) + aa = symbols('a', above_fermi=True) + k, l = symbols('k l', below_fermi=True, cls=Dummy) + c, d = symbols('c d', above_fermi=True, cls=Dummy) + + # Coupled-Cluster T1 terms with V*T1*T1*T1 + # t^{a}_{k} t^{c}_{i} t^{d}_{l} v^{lk}_{dc} + exprs = [ + # permut v and t <=> swapping internal lines, equivalent + # irrespective of symmetries in v + atv(k, l, c, d)*att(c, ii)*att(d, l)*att(aa, k), + atv(l, k, c, d)*att(c, ii)*att(d, k)*att(aa, l), + atv(k, l, d, c)*att(d, ii)*att(c, l)*att(aa, k), + atv(l, k, d, c)*att(d, ii)*att(c, k)*att(aa, l), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_dummy_order_inner_outer_lines_VT1T1T1T1_AT(): + ii, jj = symbols('i j', below_fermi=True) + aa, bb = symbols('a b', above_fermi=True) + k, l = symbols('k l', below_fermi=True, cls=Dummy) + c, d = symbols('c d', above_fermi=True, cls=Dummy) + + # Coupled-Cluster T2 terms with V*T1*T1*T1*T1 + # non-equivalent substitutions (change of sign) + exprs = [ + # permut t <=> swapping external lines + atv(k, l, c, d)*att(c, ii)*att(d, jj)*att(aa, k)*att(bb, l), + atv(k, l, c, d)*att(c, jj)*att(d, ii)*att(aa, k)*att(bb, l), + atv(k, l, c, d)*att(c, ii)*att(d, jj)*att(bb, k)*att(aa, l), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == -substitute_dummies(permut) + + # equivalent substitutions + exprs = [ + atv(k, l, c, d)*att(c, ii)*att(d, jj)*att(aa, k)*att(bb, l), + # permut t <=> swapping external lines + atv(k, l, c, d)*att(c, jj)*att(d, ii)*att(bb, k)*att(aa, l), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_equivalent_internal_lines_VT1T1_AT(): + i, j, k, l = symbols('i j k l', below_fermi=True, cls=Dummy) + a, b, c, d = symbols('a b c d', above_fermi=True, cls=Dummy) + + exprs = [ # permute v. Different dummy order. Not equivalent. + atv(i, j, a, b)*att(a, i)*att(b, j), + atv(j, i, a, b)*att(a, i)*att(b, j), + atv(i, j, b, a)*att(a, i)*att(b, j), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ # permute v. Different dummy order. Equivalent + atv(i, j, a, b)*att(a, i)*att(b, j), + atv(j, i, b, a)*att(a, i)*att(b, j), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + exprs = [ # permute t. Same dummy order, not equivalent. + atv(i, j, a, b)*att(a, i)*att(b, j), + atv(i, j, a, b)*att(b, i)*att(a, j), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ # permute v and t. Different dummy order, equivalent + atv(i, j, a, b)*att(a, i)*att(b, j), + atv(j, i, a, b)*att(a, j)*att(b, i), + atv(i, j, b, a)*att(b, i)*att(a, j), + atv(j, i, b, a)*att(b, j)*att(a, i), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_equivalent_internal_lines_VT2conjT2_AT(): + # this diagram requires special handling in TCE + i, j, k, l, m, n = symbols('i j k l m n', below_fermi=True, cls=Dummy) + a, b, c, d, e, f = symbols('a b c d e f', above_fermi=True, cls=Dummy) + p1, p2, p3, p4 = symbols('p1 p2 p3 p4', above_fermi=True, cls=Dummy) + h1, h2, h3, h4 = symbols('h1 h2 h3 h4', below_fermi=True, cls=Dummy) + + from sympy.utilities.iterables import variations + + # atv(abcd)att(abij)att(ijcd) + template = atv(p1, p2, p3, p4)*att(p1, p2, i, j)*att(i, j, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + template = atv(p1, p2, p3, p4)*att(p1, p2, j, i)*att(j, i, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + + # atv(abcd)att(abij)att(jicd) + template = atv(p1, p2, p3, p4)*att(p1, p2, i, j)*att(j, i, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + template = atv(p1, p2, p3, p4)*att(p1, p2, j, i)*att(i, j, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + + +def test_equivalent_internal_lines_VT2conjT2_ambiguous_order_AT(): + # These diagrams invokes _determine_ambiguous() because the + # dummies can not be ordered unambiguously by the key alone + i, j, k, l, m, n = symbols('i j k l m n', below_fermi=True, cls=Dummy) + a, b, c, d, e, f = symbols('a b c d e f', above_fermi=True, cls=Dummy) + p1, p2, p3, p4 = symbols('p1 p2 p3 p4', above_fermi=True, cls=Dummy) + h1, h2, h3, h4 = symbols('h1 h2 h3 h4', below_fermi=True, cls=Dummy) + + from sympy.utilities.iterables import variations + + # atv(abcd)att(abij)att(cdij) + template = atv(p1, p2, p3, p4)*att(p1, p2, i, j)*att(p3, p4, i, j) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + template = atv(p1, p2, p3, p4)*att(p1, p2, j, i)*att(p3, p4, i, j) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + + +def test_equivalent_internal_lines_VT2_AT(): + i, j, k, l = symbols('i j k l', below_fermi=True, cls=Dummy) + a, b, c, d = symbols('a b c d', above_fermi=True, cls=Dummy) + + exprs = [ + # permute v. Same dummy order, not equivalent. + atv(i, j, a, b)*att(a, b, i, j), + atv(j, i, a, b)*att(a, b, i, j), + atv(i, j, b, a)*att(a, b, i, j), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ + # permute t. + atv(i, j, a, b)*att(a, b, i, j), + atv(i, j, a, b)*att(b, a, i, j), + atv(i, j, a, b)*att(a, b, j, i), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ # permute v and t. Relabelling of dummies should be equivalent. + atv(i, j, a, b)*att(a, b, i, j), + atv(j, i, a, b)*att(a, b, j, i), + atv(i, j, b, a)*att(b, a, i, j), + atv(j, i, b, a)*att(b, a, j, i), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_internal_external_VT2T2_AT(): + ii, jj = symbols('i j', below_fermi=True) + aa, bb = symbols('a b', above_fermi=True) + k, l = symbols('k l', below_fermi=True, cls=Dummy) + c, d = symbols('c d', above_fermi=True, cls=Dummy) + + exprs = [ + atv(k, l, c, d)*att(aa, c, ii, k)*att(bb, d, jj, l), + atv(l, k, c, d)*att(aa, c, ii, l)*att(bb, d, jj, k), + atv(k, l, d, c)*att(aa, d, ii, k)*att(bb, c, jj, l), + atv(l, k, d, c)*att(aa, d, ii, l)*att(bb, c, jj, k), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + exprs = [ + atv(k, l, c, d)*att(aa, c, ii, k)*att(d, bb, jj, l), + atv(l, k, c, d)*att(aa, c, ii, l)*att(d, bb, jj, k), + atv(k, l, d, c)*att(aa, d, ii, k)*att(c, bb, jj, l), + atv(l, k, d, c)*att(aa, d, ii, l)*att(c, bb, jj, k), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + exprs = [ + atv(k, l, c, d)*att(c, aa, ii, k)*att(bb, d, jj, l), + atv(l, k, c, d)*att(c, aa, ii, l)*att(bb, d, jj, k), + atv(k, l, d, c)*att(d, aa, ii, k)*att(bb, c, jj, l), + atv(l, k, d, c)*att(d, aa, ii, l)*att(bb, c, jj, k), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_internal_external_pqrs_AT(): + ii, jj = symbols('i j') + aa, bb = symbols('a b') + k, l = symbols('k l', cls=Dummy) + c, d = symbols('c d', cls=Dummy) + + exprs = [ + atv(k, l, c, d)*att(aa, c, ii, k)*att(bb, d, jj, l), + atv(l, k, c, d)*att(aa, c, ii, l)*att(bb, d, jj, k), + atv(k, l, d, c)*att(aa, d, ii, k)*att(bb, c, jj, l), + atv(l, k, d, c)*att(aa, d, ii, l)*att(bb, c, jj, k), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_issue_19661(): + a = Symbol('0') + assert latex(Commutator(Bd(a)**2, B(a)) + ) == '- \\left[b_{0},{b^\\dagger_{0}}^{2}\\right]' + + +def test_canonical_ordering_AntiSymmetricTensor(): + v = symbols("v") + + c, d = symbols(('c','d'), above_fermi=True, + cls=Dummy) + k, l = symbols(('k','l'), below_fermi=True, + cls=Dummy) + + # formerly, the left gave either the left or the right + assert AntiSymmetricTensor(v, (k, l), (d, c) + ) == -AntiSymmetricTensor(v, (l, k), (d, c)) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_sho.py b/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_sho.py new file mode 100644 index 0000000000000000000000000000000000000000..7248838b4bb9ad280fd4211bbe208063b65adcf5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/tests/test_sho.py @@ -0,0 +1,21 @@ +from sympy.core import symbols, Rational, Function, diff +from sympy.physics.sho import R_nl, E_nl +from sympy.simplify.simplify import simplify + + +def test_sho_R_nl(): + omega, r = symbols('omega r') + l = symbols('l', integer=True) + u = Function('u') + + # check that it obeys the Schrodinger equation + for n in range(5): + schreq = ( -diff(u(r), r, 2)/2 + ((l*(l + 1))/(2*r**2) + + omega**2*r**2/2 - E_nl(n, l, omega))*u(r) ) + result = schreq.subs(u(r), r*R_nl(n, l, omega/2, r)) + assert simplify(result.doit()) == 0 + + +def test_energy(): + n, l, hw = symbols('n l hw') + assert simplify(E_nl(n, l, hw) - (2*n + l + Rational(3, 2))*hw) == 0 diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/__init__.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bf17c7f3051b03d9c0fc794d9d79885c94cc878e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/units/__init__.py @@ -0,0 +1,453 @@ +# isort:skip_file +""" +Dimensional analysis and unit systems. + +This module defines dimension/unit systems and physical quantities. It is +based on a group-theoretical construction where dimensions are represented as +vectors (coefficients being the exponents), and units are defined as a dimension +to which we added a scale. + +Quantities are built from a factor and a unit, and are the basic objects that +one will use when doing computations. + +All objects except systems and prefixes can be used in SymPy expressions. +Note that as part of a CAS, various objects do not combine automatically +under operations. + +Details about the implementation can be found in the documentation, and we +will not repeat all the explanations we gave there concerning our approach. +Ideas about future developments can be found on the `Github wiki +`_, and you should consult +this page if you are willing to help. + +Useful functions: + +- ``find_unit``: easily lookup pre-defined units. +- ``convert_to(expr, newunit)``: converts an expression into the same + expression expressed in another unit. + +""" + +from .dimensions import Dimension, DimensionSystem +from .unitsystem import UnitSystem +from .util import convert_to +from .quantities import Quantity + +from .definitions.dimension_definitions import ( + amount_of_substance, acceleration, action, area, + capacitance, charge, conductance, current, energy, + force, frequency, impedance, inductance, length, + luminous_intensity, magnetic_density, + magnetic_flux, mass, momentum, power, pressure, temperature, time, + velocity, voltage, volume +) + +Unit = Quantity + +speed = velocity +luminosity = luminous_intensity +magnetic_flux_density = magnetic_density +amount = amount_of_substance + +from .prefixes import ( + # 10-power based: + yotta, + zetta, + exa, + peta, + tera, + giga, + mega, + kilo, + hecto, + deca, + deci, + centi, + milli, + micro, + nano, + pico, + femto, + atto, + zepto, + yocto, + # 2-power based: + kibi, + mebi, + gibi, + tebi, + pebi, + exbi, +) + +from .definitions import ( + percent, percents, + permille, + rad, radian, radians, + deg, degree, degrees, + sr, steradian, steradians, + mil, angular_mil, angular_mils, + m, meter, meters, + kg, kilogram, kilograms, + s, second, seconds, + A, ampere, amperes, + K, kelvin, kelvins, + mol, mole, moles, + cd, candela, candelas, + g, gram, grams, + mg, milligram, milligrams, + ug, microgram, micrograms, + t, tonne, metric_ton, + newton, newtons, N, + joule, joules, J, + watt, watts, W, + pascal, pascals, Pa, pa, + hertz, hz, Hz, + coulomb, coulombs, C, + volt, volts, v, V, + ohm, ohms, + siemens, S, mho, mhos, + farad, farads, F, + henry, henrys, H, + tesla, teslas, T, + weber, webers, Wb, wb, + optical_power, dioptre, D, + lux, lx, + katal, kat, + gray, Gy, + becquerel, Bq, + km, kilometer, kilometers, + dm, decimeter, decimeters, + cm, centimeter, centimeters, + mm, millimeter, millimeters, + um, micrometer, micrometers, micron, microns, + nm, nanometer, nanometers, + pm, picometer, picometers, + ft, foot, feet, + inch, inches, + yd, yard, yards, + mi, mile, miles, + nmi, nautical_mile, nautical_miles, + angstrom, angstroms, + ha, hectare, + l, L, liter, liters, + dl, dL, deciliter, deciliters, + cl, cL, centiliter, centiliters, + ml, mL, milliliter, milliliters, + ms, millisecond, milliseconds, + us, microsecond, microseconds, + ns, nanosecond, nanoseconds, + ps, picosecond, picoseconds, + minute, minutes, + h, hour, hours, + day, days, + anomalistic_year, anomalistic_years, + sidereal_year, sidereal_years, + tropical_year, tropical_years, + common_year, common_years, + julian_year, julian_years, + draconic_year, draconic_years, + gaussian_year, gaussian_years, + full_moon_cycle, full_moon_cycles, + year, years, + G, gravitational_constant, + c, speed_of_light, + elementary_charge, + hbar, + planck, + eV, electronvolt, electronvolts, + avogadro_number, + avogadro, avogadro_constant, + boltzmann, boltzmann_constant, + stefan, stefan_boltzmann_constant, + R, molar_gas_constant, + faraday_constant, + josephson_constant, + von_klitzing_constant, + Da, dalton, amu, amus, atomic_mass_unit, atomic_mass_constant, + me, electron_rest_mass, + gee, gees, acceleration_due_to_gravity, + u0, magnetic_constant, vacuum_permeability, + e0, electric_constant, vacuum_permittivity, + Z0, vacuum_impedance, + coulomb_constant, electric_force_constant, + atmosphere, atmospheres, atm, + kPa, + bar, bars, + pound, pounds, + psi, + dHg0, + mmHg, torr, + mmu, mmus, milli_mass_unit, + quart, quarts, + ly, lightyear, lightyears, + au, astronomical_unit, astronomical_units, + planck_mass, + planck_time, + planck_temperature, + planck_length, + planck_charge, + planck_area, + planck_volume, + planck_momentum, + planck_energy, + planck_force, + planck_power, + planck_density, + planck_energy_density, + planck_intensity, + planck_angular_frequency, + planck_pressure, + planck_current, + planck_voltage, + planck_impedance, + planck_acceleration, + bit, bits, + byte, + kibibyte, kibibytes, + mebibyte, mebibytes, + gibibyte, gibibytes, + tebibyte, tebibytes, + pebibyte, pebibytes, + exbibyte, exbibytes, +) + +from .systems import ( + mks, mksa, si +) + + +def find_unit(quantity, unit_system="SI"): + """ + Return a list of matching units or dimension names. + + - If ``quantity`` is a string -- units/dimensions containing the string + `quantity`. + - If ``quantity`` is a unit or dimension -- units having matching base + units or dimensions. + + Examples + ======== + + >>> from sympy.physics import units as u + >>> u.find_unit('charge') + ['C', 'coulomb', 'coulombs', 'planck_charge', 'elementary_charge'] + >>> u.find_unit(u.charge) + ['C', 'coulomb', 'coulombs', 'planck_charge', 'elementary_charge'] + >>> u.find_unit("ampere") + ['ampere', 'amperes'] + >>> u.find_unit('angstrom') + ['angstrom', 'angstroms'] + >>> u.find_unit('volt') + ['volt', 'volts', 'electronvolt', 'electronvolts', 'planck_voltage'] + >>> u.find_unit(u.inch**3)[:9] + ['L', 'l', 'cL', 'cl', 'dL', 'dl', 'mL', 'ml', 'liter'] + """ + unit_system = UnitSystem.get_unit_system(unit_system) + + import sympy.physics.units as u + rv = [] + if isinstance(quantity, str): + rv = [i for i in dir(u) if quantity in i and isinstance(getattr(u, i), Quantity)] + dim = getattr(u, quantity) + if isinstance(dim, Dimension): + rv.extend(find_unit(dim)) + else: + for i in sorted(dir(u)): + other = getattr(u, i) + if not isinstance(other, Quantity): + continue + if isinstance(quantity, Quantity): + if quantity.dimension == other.dimension: + rv.append(str(i)) + elif isinstance(quantity, Dimension): + if other.dimension == quantity: + rv.append(str(i)) + elif other.dimension == Dimension(unit_system.get_dimensional_expr(quantity)): + rv.append(str(i)) + return sorted(set(rv), key=lambda x: (len(x), x)) + +# NOTE: the old units module had additional variables: +# 'density', 'illuminance', 'resistance'. +# They were not dimensions, but units (old Unit class). + +__all__ = [ + 'Dimension', 'DimensionSystem', + 'UnitSystem', + 'convert_to', + 'Quantity', + + 'amount_of_substance', 'acceleration', 'action', 'area', + 'capacitance', 'charge', 'conductance', 'current', 'energy', + 'force', 'frequency', 'impedance', 'inductance', 'length', + 'luminous_intensity', 'magnetic_density', + 'magnetic_flux', 'mass', 'momentum', 'power', 'pressure', 'temperature', 'time', + 'velocity', 'voltage', 'volume', + + 'Unit', + + 'speed', + 'luminosity', + 'magnetic_flux_density', + 'amount', + + 'yotta', + 'zetta', + 'exa', + 'peta', + 'tera', + 'giga', + 'mega', + 'kilo', + 'hecto', + 'deca', + 'deci', + 'centi', + 'milli', + 'micro', + 'nano', + 'pico', + 'femto', + 'atto', + 'zepto', + 'yocto', + + 'kibi', + 'mebi', + 'gibi', + 'tebi', + 'pebi', + 'exbi', + + 'percent', 'percents', + 'permille', + 'rad', 'radian', 'radians', + 'deg', 'degree', 'degrees', + 'sr', 'steradian', 'steradians', + 'mil', 'angular_mil', 'angular_mils', + 'm', 'meter', 'meters', + 'kg', 'kilogram', 'kilograms', + 's', 'second', 'seconds', + 'A', 'ampere', 'amperes', + 'K', 'kelvin', 'kelvins', + 'mol', 'mole', 'moles', + 'cd', 'candela', 'candelas', + 'g', 'gram', 'grams', + 'mg', 'milligram', 'milligrams', + 'ug', 'microgram', 'micrograms', + 't', 'tonne', 'metric_ton', + 'newton', 'newtons', 'N', + 'joule', 'joules', 'J', + 'watt', 'watts', 'W', + 'pascal', 'pascals', 'Pa', 'pa', + 'hertz', 'hz', 'Hz', + 'coulomb', 'coulombs', 'C', + 'volt', 'volts', 'v', 'V', + 'ohm', 'ohms', + 'siemens', 'S', 'mho', 'mhos', + 'farad', 'farads', 'F', + 'henry', 'henrys', 'H', + 'tesla', 'teslas', 'T', + 'weber', 'webers', 'Wb', 'wb', + 'optical_power', 'dioptre', 'D', + 'lux', 'lx', + 'katal', 'kat', + 'gray', 'Gy', + 'becquerel', 'Bq', + 'km', 'kilometer', 'kilometers', + 'dm', 'decimeter', 'decimeters', + 'cm', 'centimeter', 'centimeters', + 'mm', 'millimeter', 'millimeters', + 'um', 'micrometer', 'micrometers', 'micron', 'microns', + 'nm', 'nanometer', 'nanometers', + 'pm', 'picometer', 'picometers', + 'ft', 'foot', 'feet', + 'inch', 'inches', + 'yd', 'yard', 'yards', + 'mi', 'mile', 'miles', + 'nmi', 'nautical_mile', 'nautical_miles', + 'angstrom', 'angstroms', + 'ha', 'hectare', + 'l', 'L', 'liter', 'liters', + 'dl', 'dL', 'deciliter', 'deciliters', + 'cl', 'cL', 'centiliter', 'centiliters', + 'ml', 'mL', 'milliliter', 'milliliters', + 'ms', 'millisecond', 'milliseconds', + 'us', 'microsecond', 'microseconds', + 'ns', 'nanosecond', 'nanoseconds', + 'ps', 'picosecond', 'picoseconds', + 'minute', 'minutes', + 'h', 'hour', 'hours', + 'day', 'days', + 'anomalistic_year', 'anomalistic_years', + 'sidereal_year', 'sidereal_years', + 'tropical_year', 'tropical_years', + 'common_year', 'common_years', + 'julian_year', 'julian_years', + 'draconic_year', 'draconic_years', + 'gaussian_year', 'gaussian_years', + 'full_moon_cycle', 'full_moon_cycles', + 'year', 'years', + 'G', 'gravitational_constant', + 'c', 'speed_of_light', + 'elementary_charge', + 'hbar', + 'planck', + 'eV', 'electronvolt', 'electronvolts', + 'avogadro_number', + 'avogadro', 'avogadro_constant', + 'boltzmann', 'boltzmann_constant', + 'stefan', 'stefan_boltzmann_constant', + 'R', 'molar_gas_constant', + 'faraday_constant', + 'josephson_constant', + 'von_klitzing_constant', + 'Da', 'dalton', 'amu', 'amus', 'atomic_mass_unit', 'atomic_mass_constant', + 'me', 'electron_rest_mass', + 'gee', 'gees', 'acceleration_due_to_gravity', + 'u0', 'magnetic_constant', 'vacuum_permeability', + 'e0', 'electric_constant', 'vacuum_permittivity', + 'Z0', 'vacuum_impedance', + 'coulomb_constant', 'electric_force_constant', + 'atmosphere', 'atmospheres', 'atm', + 'kPa', + 'bar', 'bars', + 'pound', 'pounds', + 'psi', + 'dHg0', + 'mmHg', 'torr', + 'mmu', 'mmus', 'milli_mass_unit', + 'quart', 'quarts', + 'ly', 'lightyear', 'lightyears', + 'au', 'astronomical_unit', 'astronomical_units', + 'planck_mass', + 'planck_time', + 'planck_temperature', + 'planck_length', + 'planck_charge', + 'planck_area', + 'planck_volume', + 'planck_momentum', + 'planck_energy', + 'planck_force', + 'planck_power', + 'planck_density', + 'planck_energy_density', + 'planck_intensity', + 'planck_angular_frequency', + 'planck_pressure', + 'planck_current', + 'planck_voltage', + 'planck_impedance', + 'planck_acceleration', + 'bit', 'bits', + 'byte', + 'kibibyte', 'kibibytes', + 'mebibyte', 'mebibytes', + 'gibibyte', 'gibibytes', + 'tebibyte', 'tebibytes', + 'pebibyte', 'pebibytes', + 'exbibyte', 'exbibytes', + + 'mks', 'mksa', 'si', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/dimensions.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/dimensions.py new file mode 100644 index 0000000000000000000000000000000000000000..de42912edca025a6cb53d457fd3e03d8fa30931e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/units/dimensions.py @@ -0,0 +1,590 @@ +""" +Definition of physical dimensions. + +Unit systems will be constructed on top of these dimensions. + +Most of the examples in the doc use MKS system and are presented from the +computer point of view: from a human point, adding length to time is not legal +in MKS but it is in natural system; for a computer in natural system there is +no time dimension (but a velocity dimension instead) - in the basis - so the +question of adding time to length has no meaning. +""" + +from __future__ import annotations + +import collections +from functools import reduce + +from sympy.core.basic import Basic +from sympy.core.containers import (Dict, Tuple) +from sympy.core.singleton import S +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.matrices.dense import Matrix +from sympy.functions.elementary.trigonometric import TrigonometricFunction +from sympy.core.expr import Expr +from sympy.core.power import Pow + + +class _QuantityMapper: + + _quantity_scale_factors_global: dict[Expr, Expr] = {} + _quantity_dimensional_equivalence_map_global: dict[Expr, Expr] = {} + _quantity_dimension_global: dict[Expr, Expr] = {} + + def __init__(self, *args, **kwargs): + self._quantity_dimension_map = {} + self._quantity_scale_factors = {} + + def set_quantity_dimension(self, quantity, dimension): + """ + Set the dimension for the quantity in a unit system. + + If this relation is valid in every unit system, use + ``quantity.set_global_dimension(dimension)`` instead. + """ + from sympy.physics.units import Quantity + dimension = sympify(dimension) + if not isinstance(dimension, Dimension): + if dimension == 1: + dimension = Dimension(1) + else: + raise ValueError("expected dimension or 1") + elif isinstance(dimension, Quantity): + dimension = self.get_quantity_dimension(dimension) + self._quantity_dimension_map[quantity] = dimension + + def set_quantity_scale_factor(self, quantity, scale_factor): + """ + Set the scale factor of a quantity relative to another quantity. + + It should be used only once per quantity to just one other quantity, + the algorithm will then be able to compute the scale factors to all + other quantities. + + In case the scale factor is valid in every unit system, please use + ``quantity.set_global_relative_scale_factor(scale_factor)`` instead. + """ + from sympy.physics.units import Quantity + from sympy.physics.units.prefixes import Prefix + scale_factor = sympify(scale_factor) + # replace all prefixes by their ratio to canonical units: + scale_factor = scale_factor.replace( + lambda x: isinstance(x, Prefix), + lambda x: x.scale_factor + ) + # replace all quantities by their ratio to canonical units: + scale_factor = scale_factor.replace( + lambda x: isinstance(x, Quantity), + lambda x: self.get_quantity_scale_factor(x) + ) + self._quantity_scale_factors[quantity] = scale_factor + + def get_quantity_dimension(self, unit): + from sympy.physics.units import Quantity + # First look-up the local dimension map, then the global one: + if unit in self._quantity_dimension_map: + return self._quantity_dimension_map[unit] + if unit in self._quantity_dimension_global: + return self._quantity_dimension_global[unit] + if unit in self._quantity_dimensional_equivalence_map_global: + dep_unit = self._quantity_dimensional_equivalence_map_global[unit] + if isinstance(dep_unit, Quantity): + return self.get_quantity_dimension(dep_unit) + else: + return Dimension(self.get_dimensional_expr(dep_unit)) + if isinstance(unit, Quantity): + return Dimension(unit.name) + else: + return Dimension(1) + + def get_quantity_scale_factor(self, unit): + if unit in self._quantity_scale_factors: + return self._quantity_scale_factors[unit] + if unit in self._quantity_scale_factors_global: + mul_factor, other_unit = self._quantity_scale_factors_global[unit] + return mul_factor*self.get_quantity_scale_factor(other_unit) + return S.One + + +class Dimension(Expr): + """ + This class represent the dimension of a physical quantities. + + The ``Dimension`` constructor takes as parameters a name and an optional + symbol. + + For example, in classical mechanics we know that time is different from + temperature and dimensions make this difference (but they do not provide + any measure of these quantities. + + >>> from sympy.physics.units import Dimension + >>> length = Dimension('length') + >>> length + Dimension(length) + >>> time = Dimension('time') + >>> time + Dimension(time) + + Dimensions can be composed using multiplication, division and + exponentiation (by a number) to give new dimensions. Addition and + subtraction is defined only when the two objects are the same dimension. + + >>> velocity = length / time + >>> velocity + Dimension(length/time) + + It is possible to use a dimension system object to get the dimensionsal + dependencies of a dimension, for example the dimension system used by the + SI units convention can be used: + + >>> from sympy.physics.units.systems.si import dimsys_SI + >>> dimsys_SI.get_dimensional_dependencies(velocity) + {Dimension(length, L): 1, Dimension(time, T): -1} + >>> length + length + Dimension(length) + >>> l2 = length**2 + >>> l2 + Dimension(length**2) + >>> dimsys_SI.get_dimensional_dependencies(l2) + {Dimension(length, L): 2} + + """ + + _op_priority = 13.0 + + # XXX: This doesn't seem to be used anywhere... + _dimensional_dependencies = {} # type: ignore + + is_commutative = True + is_number = False + # make sqrt(M**2) --> M + is_positive = True + is_real = True + + def __new__(cls, name, symbol=None): + + if isinstance(name, str): + name = Symbol(name) + else: + name = sympify(name) + + if not isinstance(name, Expr): + raise TypeError("Dimension name needs to be a valid math expression") + + if isinstance(symbol, str): + symbol = Symbol(symbol) + elif symbol is not None: + assert isinstance(symbol, Symbol) + + obj = Expr.__new__(cls, name) + + obj._name = name + obj._symbol = symbol + return obj + + @property + def name(self): + return self._name + + @property + def symbol(self): + return self._symbol + + def __str__(self): + """ + Display the string representation of the dimension. + """ + if self.symbol is None: + return "Dimension(%s)" % (self.name) + else: + return "Dimension(%s, %s)" % (self.name, self.symbol) + + def __repr__(self): + return self.__str__() + + def __neg__(self): + return self + + def __add__(self, other): + from sympy.physics.units.quantities import Quantity + other = sympify(other) + if isinstance(other, Basic): + if other.has(Quantity): + raise TypeError("cannot sum dimension and quantity") + if isinstance(other, Dimension) and self == other: + return self + return super().__add__(other) + return self + + def __radd__(self, other): + return self.__add__(other) + + def __sub__(self, other): + # there is no notion of ordering (or magnitude) among dimension, + # subtraction is equivalent to addition when the operation is legal + return self + other + + def __rsub__(self, other): + # there is no notion of ordering (or magnitude) among dimension, + # subtraction is equivalent to addition when the operation is legal + return self + other + + def __pow__(self, other): + return self._eval_power(other) + + def _eval_power(self, other): + other = sympify(other) + return Dimension(self.name**other) + + def __mul__(self, other): + from sympy.physics.units.quantities import Quantity + if isinstance(other, Basic): + if other.has(Quantity): + raise TypeError("cannot sum dimension and quantity") + if isinstance(other, Dimension): + return Dimension(self.name*other.name) + if not other.free_symbols: # other.is_number cannot be used + return self + return super().__mul__(other) + return self + + def __rmul__(self, other): + return self.__mul__(other) + + def __truediv__(self, other): + return self*Pow(other, -1) + + def __rtruediv__(self, other): + return other * pow(self, -1) + + @classmethod + def _from_dimensional_dependencies(cls, dependencies): + return reduce(lambda x, y: x * y, ( + d**e for d, e in dependencies.items() + ), 1) + + def has_integer_powers(self, dim_sys): + """ + Check if the dimension object has only integer powers. + + All the dimension powers should be integers, but rational powers may + appear in intermediate steps. This method may be used to check that the + final result is well-defined. + """ + + return all(dpow.is_Integer for dpow in dim_sys.get_dimensional_dependencies(self).values()) + + +# Create dimensions according to the base units in MKSA. +# For other unit systems, they can be derived by transforming the base +# dimensional dependency dictionary. + + +class DimensionSystem(Basic, _QuantityMapper): + r""" + DimensionSystem represents a coherent set of dimensions. + + The constructor takes three parameters: + + - base dimensions; + - derived dimensions: these are defined in terms of the base dimensions + (for example velocity is defined from the division of length by time); + - dependency of dimensions: how the derived dimensions depend + on the base dimensions. + + Optionally either the ``derived_dims`` or the ``dimensional_dependencies`` + may be omitted. + """ + + def __new__(cls, base_dims, derived_dims=(), dimensional_dependencies={}): + dimensional_dependencies = dict(dimensional_dependencies) + + def parse_dim(dim): + if isinstance(dim, str): + dim = Dimension(Symbol(dim)) + elif isinstance(dim, Dimension): + pass + elif isinstance(dim, Symbol): + dim = Dimension(dim) + else: + raise TypeError("%s wrong type" % dim) + return dim + + base_dims = [parse_dim(i) for i in base_dims] + derived_dims = [parse_dim(i) for i in derived_dims] + + for dim in base_dims: + if (dim in dimensional_dependencies + and (len(dimensional_dependencies[dim]) != 1 or + dimensional_dependencies[dim].get(dim, None) != 1)): + raise IndexError("Repeated value in base dimensions") + dimensional_dependencies[dim] = Dict({dim: 1}) + + def parse_dim_name(dim): + if isinstance(dim, Dimension): + return dim + elif isinstance(dim, str): + return Dimension(Symbol(dim)) + elif isinstance(dim, Symbol): + return Dimension(dim) + else: + raise TypeError("unrecognized type %s for %s" % (type(dim), dim)) + + for dim in dimensional_dependencies.keys(): + dim = parse_dim(dim) + if (dim not in derived_dims) and (dim not in base_dims): + derived_dims.append(dim) + + def parse_dict(d): + return Dict({parse_dim_name(i): j for i, j in d.items()}) + + # Make sure everything is a SymPy type: + dimensional_dependencies = {parse_dim_name(i): parse_dict(j) for i, j in + dimensional_dependencies.items()} + + for dim in derived_dims: + if dim in base_dims: + raise ValueError("Dimension %s both in base and derived" % dim) + if dim not in dimensional_dependencies: + # TODO: should this raise a warning? + dimensional_dependencies[dim] = Dict({dim: 1}) + + base_dims.sort(key=default_sort_key) + derived_dims.sort(key=default_sort_key) + + base_dims = Tuple(*base_dims) + derived_dims = Tuple(*derived_dims) + dimensional_dependencies = Dict({i: Dict(j) for i, j in dimensional_dependencies.items()}) + obj = Basic.__new__(cls, base_dims, derived_dims, dimensional_dependencies) + return obj + + @property + def base_dims(self): + return self.args[0] + + @property + def derived_dims(self): + return self.args[1] + + @property + def dimensional_dependencies(self): + return self.args[2] + + def _get_dimensional_dependencies_for_name(self, dimension): + if isinstance(dimension, str): + dimension = Dimension(Symbol(dimension)) + elif not isinstance(dimension, Dimension): + dimension = Dimension(dimension) + + if dimension.name.is_Symbol: + # Dimensions not included in the dependencies are considered + # as base dimensions: + return dict(self.dimensional_dependencies.get(dimension, {dimension: 1})) + + if dimension.name.is_number or dimension.name.is_NumberSymbol: + return {} + + get_for_name = self._get_dimensional_dependencies_for_name + + if dimension.name.is_Mul: + ret = collections.defaultdict(int) + dicts = [get_for_name(i) for i in dimension.name.args] + for d in dicts: + for k, v in d.items(): + ret[k] += v + return {k: v for (k, v) in ret.items() if v != 0} + + if dimension.name.is_Add: + dicts = [get_for_name(i) for i in dimension.name.args] + if all(d == dicts[0] for d in dicts[1:]): + return dicts[0] + raise TypeError("Only equivalent dimensions can be added or subtracted.") + + if dimension.name.is_Pow: + dim_base = get_for_name(dimension.name.base) + dim_exp = get_for_name(dimension.name.exp) + if dim_exp == {} or dimension.name.exp.is_Symbol: + return {k: v * dimension.name.exp for (k, v) in dim_base.items()} + else: + raise TypeError("The exponent for the power operator must be a Symbol or dimensionless.") + + if dimension.name.is_Function: + args = (Dimension._from_dimensional_dependencies( + get_for_name(arg)) for arg in dimension.name.args) + result = dimension.name.func(*args) + + dicts = [get_for_name(i) for i in dimension.name.args] + + if isinstance(result, Dimension): + return self.get_dimensional_dependencies(result) + elif result.func == dimension.name.func: + if isinstance(dimension.name, TrigonometricFunction): + if dicts[0] in ({}, {Dimension('angle'): 1}): + return {} + else: + raise TypeError("The input argument for the function {} must be dimensionless or have dimensions of angle.".format(dimension.func)) + else: + if all(item == {} for item in dicts): + return {} + else: + raise TypeError("The input arguments for the function {} must be dimensionless.".format(dimension.func)) + else: + return get_for_name(result) + + raise TypeError("Type {} not implemented for get_dimensional_dependencies".format(type(dimension.name))) + + def get_dimensional_dependencies(self, name, mark_dimensionless=False): + dimdep = self._get_dimensional_dependencies_for_name(name) + if mark_dimensionless and dimdep == {}: + return {Dimension(1): 1} + return dict(dimdep.items()) + + def equivalent_dims(self, dim1, dim2): + deps1 = self.get_dimensional_dependencies(dim1) + deps2 = self.get_dimensional_dependencies(dim2) + return deps1 == deps2 + + def extend(self, new_base_dims, new_derived_dims=(), new_dim_deps=None): + deps = dict(self.dimensional_dependencies) + if new_dim_deps: + deps.update(new_dim_deps) + + new_dim_sys = DimensionSystem( + tuple(self.base_dims) + tuple(new_base_dims), + tuple(self.derived_dims) + tuple(new_derived_dims), + deps + ) + new_dim_sys._quantity_dimension_map.update(self._quantity_dimension_map) + new_dim_sys._quantity_scale_factors.update(self._quantity_scale_factors) + return new_dim_sys + + def is_dimensionless(self, dimension): + """ + Check if the dimension object really has a dimension. + + A dimension should have at least one component with non-zero power. + """ + if dimension.name == 1: + return True + return self.get_dimensional_dependencies(dimension) == {} + + @property + def list_can_dims(self): + """ + Useless method, kept for compatibility with previous versions. + + DO NOT USE. + + List all canonical dimension names. + """ + dimset = set() + for i in self.base_dims: + dimset.update(set(self.get_dimensional_dependencies(i).keys())) + return tuple(sorted(dimset, key=str)) + + @property + def inv_can_transf_matrix(self): + """ + Useless method, kept for compatibility with previous versions. + + DO NOT USE. + + Compute the inverse transformation matrix from the base to the + canonical dimension basis. + + It corresponds to the matrix where columns are the vector of base + dimensions in canonical basis. + + This matrix will almost never be used because dimensions are always + defined with respect to the canonical basis, so no work has to be done + to get them in this basis. Nonetheless if this matrix is not square + (or not invertible) it means that we have chosen a bad basis. + """ + matrix = reduce(lambda x, y: x.row_join(y), + [self.dim_can_vector(d) for d in self.base_dims]) + return matrix + + @property + def can_transf_matrix(self): + """ + Useless method, kept for compatibility with previous versions. + + DO NOT USE. + + Return the canonical transformation matrix from the canonical to the + base dimension basis. + + It is the inverse of the matrix computed with inv_can_transf_matrix(). + """ + + #TODO: the inversion will fail if the system is inconsistent, for + # example if the matrix is not a square + return reduce(lambda x, y: x.row_join(y), + [self.dim_can_vector(d) for d in sorted(self.base_dims, key=str)] + ).inv() + + def dim_can_vector(self, dim): + """ + Useless method, kept for compatibility with previous versions. + + DO NOT USE. + + Dimensional representation in terms of the canonical base dimensions. + """ + + vec = [] + for d in self.list_can_dims: + vec.append(self.get_dimensional_dependencies(dim).get(d, 0)) + return Matrix(vec) + + def dim_vector(self, dim): + """ + Useless method, kept for compatibility with previous versions. + + DO NOT USE. + + + Vector representation in terms of the base dimensions. + """ + return self.can_transf_matrix * Matrix(self.dim_can_vector(dim)) + + def print_dim_base(self, dim): + """ + Give the string expression of a dimension in term of the basis symbols. + """ + dims = self.dim_vector(dim) + symbols = [i.symbol if i.symbol is not None else i.name for i in self.base_dims] + res = S.One + for (s, p) in zip(symbols, dims): + res *= s**p + return res + + @property + def dim(self): + """ + Useless method, kept for compatibility with previous versions. + + DO NOT USE. + + Give the dimension of the system. + + That is return the number of dimensions forming the basis. + """ + return len(self.base_dims) + + @property + def is_consistent(self): + """ + Useless method, kept for compatibility with previous versions. + + DO NOT USE. + + Check if the system is well defined. + """ + + # not enough or too many base dimensions compared to independent + # dimensions + # in vector language: the set of vectors do not form a basis + return self.inv_can_transf_matrix.is_square diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/prefixes.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/prefixes.py new file mode 100644 index 0000000000000000000000000000000000000000..44fd7cb9efe4b1d6307810af6b9cd140817126f9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/units/prefixes.py @@ -0,0 +1,219 @@ +""" +Module defining unit prefixe class and some constants. + +Constant dict for SI and binary prefixes are defined as PREFIXES and +BIN_PREFIXES. +""" +from sympy.core.expr import Expr +from sympy.core.sympify import sympify +from sympy.core.singleton import S + +class Prefix(Expr): + """ + This class represent prefixes, with their name, symbol and factor. + + Prefixes are used to create derived units from a given unit. They should + always be encapsulated into units. + + The factor is constructed from a base (default is 10) to some power, and + it gives the total multiple or fraction. For example the kilometer km + is constructed from the meter (factor 1) and the kilo (10 to the power 3, + i.e. 1000). The base can be changed to allow e.g. binary prefixes. + + A prefix multiplied by something will always return the product of this + other object times the factor, except if the other object: + + - is a prefix and they can be combined into a new prefix; + - defines multiplication with prefixes (which is the case for the Unit + class). + """ + _op_priority = 13.0 + is_commutative = True + + def __new__(cls, name, abbrev, exponent, base=sympify(10), latex_repr=None): + + name = sympify(name) + abbrev = sympify(abbrev) + exponent = sympify(exponent) + base = sympify(base) + + obj = Expr.__new__(cls, name, abbrev, exponent, base) + obj._name = name + obj._abbrev = abbrev + obj._scale_factor = base**exponent + obj._exponent = exponent + obj._base = base + obj._latex_repr = latex_repr + return obj + + @property + def name(self): + return self._name + + @property + def abbrev(self): + return self._abbrev + + @property + def scale_factor(self): + return self._scale_factor + + def _latex(self, printer): + if self._latex_repr is None: + return r'\text{%s}' % self._abbrev + return self._latex_repr + + @property + def base(self): + return self._base + + def __str__(self): + return str(self._abbrev) + + def __repr__(self): + if self.base == 10: + return "Prefix(%r, %r, %r)" % ( + str(self.name), str(self.abbrev), self._exponent) + else: + return "Prefix(%r, %r, %r, %r)" % ( + str(self.name), str(self.abbrev), self._exponent, self.base) + + def __mul__(self, other): + from sympy.physics.units import Quantity + if not isinstance(other, (Quantity, Prefix)): + return super().__mul__(other) + + fact = self.scale_factor * other.scale_factor + + if isinstance(other, Prefix): + if fact == 1: + return S.One + # simplify prefix + for p in PREFIXES: + if PREFIXES[p].scale_factor == fact: + return PREFIXES[p] + return fact + + return self.scale_factor * other + + def __truediv__(self, other): + if not hasattr(other, "scale_factor"): + return super().__truediv__(other) + + fact = self.scale_factor / other.scale_factor + + if fact == 1: + return S.One + elif isinstance(other, Prefix): + for p in PREFIXES: + if PREFIXES[p].scale_factor == fact: + return PREFIXES[p] + return fact + + return self.scale_factor / other + + def __rtruediv__(self, other): + if other == 1: + for p in PREFIXES: + if PREFIXES[p].scale_factor == 1 / self.scale_factor: + return PREFIXES[p] + return other / self.scale_factor + + +def prefix_unit(unit, prefixes): + """ + Return a list of all units formed by unit and the given prefixes. + + You can use the predefined PREFIXES or BIN_PREFIXES, but you can also + pass as argument a subdict of them if you do not want all prefixed units. + + >>> from sympy.physics.units.prefixes import (PREFIXES, + ... prefix_unit) + >>> from sympy.physics.units import m + >>> pref = {"m": PREFIXES["m"], "c": PREFIXES["c"], "d": PREFIXES["d"]} + >>> prefix_unit(m, pref) # doctest: +SKIP + [millimeter, centimeter, decimeter] + """ + + from sympy.physics.units.quantities import Quantity + from sympy.physics.units import UnitSystem + + prefixed_units = [] + + for prefix in prefixes.values(): + quantity = Quantity( + "%s%s" % (prefix.name, unit.name), + abbrev=("%s%s" % (prefix.abbrev, unit.abbrev)), + is_prefixed=True, + ) + UnitSystem._quantity_dimensional_equivalence_map_global[quantity] = unit + UnitSystem._quantity_scale_factors_global[quantity] = (prefix.scale_factor, unit) + prefixed_units.append(quantity) + + return prefixed_units + + +yotta = Prefix('yotta', 'Y', 24) +zetta = Prefix('zetta', 'Z', 21) +exa = Prefix('exa', 'E', 18) +peta = Prefix('peta', 'P', 15) +tera = Prefix('tera', 'T', 12) +giga = Prefix('giga', 'G', 9) +mega = Prefix('mega', 'M', 6) +kilo = Prefix('kilo', 'k', 3) +hecto = Prefix('hecto', 'h', 2) +deca = Prefix('deca', 'da', 1) +deci = Prefix('deci', 'd', -1) +centi = Prefix('centi', 'c', -2) +milli = Prefix('milli', 'm', -3) +micro = Prefix('micro', 'mu', -6, latex_repr=r"\mu") +nano = Prefix('nano', 'n', -9) +pico = Prefix('pico', 'p', -12) +femto = Prefix('femto', 'f', -15) +atto = Prefix('atto', 'a', -18) +zepto = Prefix('zepto', 'z', -21) +yocto = Prefix('yocto', 'y', -24) + + +# https://physics.nist.gov/cuu/Units/prefixes.html +PREFIXES = { + 'Y': yotta, + 'Z': zetta, + 'E': exa, + 'P': peta, + 'T': tera, + 'G': giga, + 'M': mega, + 'k': kilo, + 'h': hecto, + 'da': deca, + 'd': deci, + 'c': centi, + 'm': milli, + 'mu': micro, + 'n': nano, + 'p': pico, + 'f': femto, + 'a': atto, + 'z': zepto, + 'y': yocto, +} + + +kibi = Prefix('kibi', 'Y', 10, 2) +mebi = Prefix('mebi', 'Y', 20, 2) +gibi = Prefix('gibi', 'Y', 30, 2) +tebi = Prefix('tebi', 'Y', 40, 2) +pebi = Prefix('pebi', 'Y', 50, 2) +exbi = Prefix('exbi', 'Y', 60, 2) + + +# https://physics.nist.gov/cuu/Units/binary.html +BIN_PREFIXES = { + 'Ki': kibi, + 'Mi': mebi, + 'Gi': gibi, + 'Ti': tebi, + 'Pi': pebi, + 'Ei': exbi, +} diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/quantities.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/quantities.py new file mode 100644 index 0000000000000000000000000000000000000000..cc19e72aea83b5bd8ae7cf2f63dd49388a3815ee --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/units/quantities.py @@ -0,0 +1,152 @@ +""" +Physical quantities. +""" + +from sympy.core.expr import AtomicExpr +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.physics.units.dimensions import _QuantityMapper +from sympy.physics.units.prefixes import Prefix + + +class Quantity(AtomicExpr): + """ + Physical quantity: can be a unit of measure, a constant or a generic quantity. + """ + + is_commutative = True + is_real = True + is_number = False + is_nonzero = True + is_physical_constant = False + _diff_wrt = True + + def __new__(cls, name, abbrev=None, + latex_repr=None, pretty_unicode_repr=None, + pretty_ascii_repr=None, mathml_presentation_repr=None, + is_prefixed=False, + **assumptions): + + if not isinstance(name, Symbol): + name = Symbol(name) + + if abbrev is None: + abbrev = name + elif isinstance(abbrev, str): + abbrev = Symbol(abbrev) + + # HACK: These are here purely for type checking. They actually get assigned below. + cls._is_prefixed = is_prefixed + + obj = AtomicExpr.__new__(cls, name, abbrev) + obj._name = name + obj._abbrev = abbrev + obj._latex_repr = latex_repr + obj._unicode_repr = pretty_unicode_repr + obj._ascii_repr = pretty_ascii_repr + obj._mathml_repr = mathml_presentation_repr + obj._is_prefixed = is_prefixed + return obj + + def set_global_dimension(self, dimension): + _QuantityMapper._quantity_dimension_global[self] = dimension + + def set_global_relative_scale_factor(self, scale_factor, reference_quantity): + """ + Setting a scale factor that is valid across all unit system. + """ + from sympy.physics.units import UnitSystem + scale_factor = sympify(scale_factor) + if isinstance(scale_factor, Prefix): + self._is_prefixed = True + # replace all prefixes by their ratio to canonical units: + scale_factor = scale_factor.replace( + lambda x: isinstance(x, Prefix), + lambda x: x.scale_factor + ) + scale_factor = sympify(scale_factor) + UnitSystem._quantity_scale_factors_global[self] = (scale_factor, reference_quantity) + UnitSystem._quantity_dimensional_equivalence_map_global[self] = reference_quantity + + @property + def name(self): + return self._name + + @property + def dimension(self): + from sympy.physics.units import UnitSystem + unit_system = UnitSystem.get_default_unit_system() + return unit_system.get_quantity_dimension(self) + + @property + def abbrev(self): + """ + Symbol representing the unit name. + + Prepend the abbreviation with the prefix symbol if it is defines. + """ + return self._abbrev + + @property + def scale_factor(self): + """ + Overall magnitude of the quantity as compared to the canonical units. + """ + from sympy.physics.units import UnitSystem + unit_system = UnitSystem.get_default_unit_system() + return unit_system.get_quantity_scale_factor(self) + + def _eval_is_positive(self): + return True + + def _eval_is_constant(self): + return True + + def _eval_Abs(self): + return self + + def _eval_subs(self, old, new): + if isinstance(new, Quantity) and self != old: + return self + + def _latex(self, printer): + if self._latex_repr: + return self._latex_repr + else: + return r'\text{{{}}}'.format(self.args[1] \ + if len(self.args) >= 2 else self.args[0]) + + def convert_to(self, other, unit_system="SI"): + """ + Convert the quantity to another quantity of same dimensions. + + Examples + ======== + + >>> from sympy.physics.units import speed_of_light, meter, second + >>> speed_of_light + speed_of_light + >>> speed_of_light.convert_to(meter/second) + 299792458*meter/second + + >>> from sympy.physics.units import liter + >>> liter.convert_to(meter**3) + meter**3/1000 + """ + from .util import convert_to + return convert_to(self, other, unit_system) + + @property + def free_symbols(self): + """Return free symbols from quantity.""" + return set() + + @property + def is_prefixed(self): + """Whether or not the quantity is prefixed. Eg. `kilogram` is prefixed, but `gram` is not.""" + return self._is_prefixed + +class PhysicalConstant(Quantity): + """Represents a physical constant, eg. `speed_of_light` or `avogadro_constant`.""" + + is_physical_constant = True diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/unitsystem.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/unitsystem.py new file mode 100644 index 0000000000000000000000000000000000000000..795f8026e9df7236fdb2abf882043a843797219d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/units/unitsystem.py @@ -0,0 +1,204 @@ +""" +Unit system for physical quantities; include definition of constants. +""" +from __future__ import annotations + +from sympy.core.add import Add +from sympy.core.function import (Derivative, Function) +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.physics.units.dimensions import _QuantityMapper +from sympy.physics.units.quantities import Quantity + +from .dimensions import Dimension + + +class UnitSystem(_QuantityMapper): + """ + UnitSystem represents a coherent set of units. + + A unit system is basically a dimension system with notions of scales. Many + of the methods are defined in the same way. + + It is much better if all base units have a symbol. + """ + + _unit_systems: dict[str, UnitSystem] = {} + + def __init__(self, base_units, units=(), name="", descr="", dimension_system=None, derived_units: dict[Dimension, Quantity]={}): + + UnitSystem._unit_systems[name] = self + + self.name = name + self.descr = descr + + self._base_units = base_units + self._dimension_system = dimension_system + self._units = tuple(set(base_units) | set(units)) + self._base_units = tuple(base_units) + self._derived_units = derived_units + + super().__init__() + + def __str__(self): + """ + Return the name of the system. + + If it does not exist, then it makes a list of symbols (or names) of + the base dimensions. + """ + + if self.name != "": + return self.name + else: + return "UnitSystem((%s))" % ", ".join( + str(d) for d in self._base_units) + + def __repr__(self): + return '' % repr(self._base_units) + + def extend(self, base, units=(), name="", description="", dimension_system=None, derived_units: dict[Dimension, Quantity]={}): + """Extend the current system into a new one. + + Take the base and normal units of the current system to merge + them to the base and normal units given in argument. + If not provided, name and description are overridden by empty strings. + """ + + base = self._base_units + tuple(base) + units = self._units + tuple(units) + + return UnitSystem(base, units, name, description, dimension_system, {**self._derived_units, **derived_units}) + + def get_dimension_system(self): + return self._dimension_system + + def get_quantity_dimension(self, unit): + qdm = self.get_dimension_system()._quantity_dimension_map + if unit in qdm: + return qdm[unit] + return super().get_quantity_dimension(unit) + + def get_quantity_scale_factor(self, unit): + qsfm = self.get_dimension_system()._quantity_scale_factors + if unit in qsfm: + return qsfm[unit] + return super().get_quantity_scale_factor(unit) + + @staticmethod + def get_unit_system(unit_system): + if isinstance(unit_system, UnitSystem): + return unit_system + + if unit_system not in UnitSystem._unit_systems: + raise ValueError( + "Unit system is not supported. Currently" + "supported unit systems are {}".format( + ", ".join(sorted(UnitSystem._unit_systems)) + ) + ) + + return UnitSystem._unit_systems[unit_system] + + @staticmethod + def get_default_unit_system(): + return UnitSystem._unit_systems["SI"] + + @property + def dim(self): + """ + Give the dimension of the system. + + That is return the number of units forming the basis. + """ + return len(self._base_units) + + @property + def is_consistent(self): + """ + Check if the underlying dimension system is consistent. + """ + # test is performed in DimensionSystem + return self.get_dimension_system().is_consistent + + @property + def derived_units(self) -> dict[Dimension, Quantity]: + return self._derived_units + + def get_dimensional_expr(self, expr): + from sympy.physics.units import Quantity + if isinstance(expr, Mul): + return Mul(*[self.get_dimensional_expr(i) for i in expr.args]) + elif isinstance(expr, Pow): + return self.get_dimensional_expr(expr.base) ** expr.exp + elif isinstance(expr, Add): + return self.get_dimensional_expr(expr.args[0]) + elif isinstance(expr, Derivative): + dim = self.get_dimensional_expr(expr.expr) + for independent, count in expr.variable_count: + dim /= self.get_dimensional_expr(independent)**count + return dim + elif isinstance(expr, Function): + args = [self.get_dimensional_expr(arg) for arg in expr.args] + if all(i == 1 for i in args): + return S.One + return expr.func(*args) + elif isinstance(expr, Quantity): + return self.get_quantity_dimension(expr).name + return S.One + + def _collect_factor_and_dimension(self, expr): + """ + Return tuple with scale factor expression and dimension expression. + """ + from sympy.physics.units import Quantity + if isinstance(expr, Quantity): + return expr.scale_factor, expr.dimension + elif isinstance(expr, Mul): + factor = 1 + dimension = Dimension(1) + for arg in expr.args: + arg_factor, arg_dim = self._collect_factor_and_dimension(arg) + factor *= arg_factor + dimension *= arg_dim + return factor, dimension + elif isinstance(expr, Pow): + factor, dim = self._collect_factor_and_dimension(expr.base) + exp_factor, exp_dim = self._collect_factor_and_dimension(expr.exp) + if self.get_dimension_system().is_dimensionless(exp_dim): + exp_dim = 1 + return factor ** exp_factor, dim ** (exp_factor * exp_dim) + elif isinstance(expr, Add): + factor, dim = self._collect_factor_and_dimension(expr.args[0]) + for addend in expr.args[1:]: + addend_factor, addend_dim = \ + self._collect_factor_and_dimension(addend) + if not self.get_dimension_system().equivalent_dims(dim, addend_dim): + raise ValueError( + 'Dimension of "{}" is {}, ' + 'but it should be {}'.format( + addend, addend_dim, dim)) + factor += addend_factor + return factor, dim + elif isinstance(expr, Derivative): + factor, dim = self._collect_factor_and_dimension(expr.args[0]) + for independent, count in expr.variable_count: + ifactor, idim = self._collect_factor_and_dimension(independent) + factor /= ifactor**count + dim /= idim**count + return factor, dim + elif isinstance(expr, Function): + fds = [self._collect_factor_and_dimension(arg) for arg in expr.args] + dims = [Dimension(1) if self.get_dimension_system().is_dimensionless(d[1]) else d[1] for d in fds] + return (expr.func(*(f[0] for f in fds)), *dims) + elif isinstance(expr, Dimension): + return S.One, expr + else: + return expr, Dimension(1) + + def get_units_non_prefixed(self) -> set[Quantity]: + """ + Return the units of the system that do not have a prefix. + """ + return set(filter(lambda u: not u.is_prefixed and not u.is_physical_constant, self._units)) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/util.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/util.py new file mode 100644 index 0000000000000000000000000000000000000000..ccd6300acdb1a3c60b74076d4700e7f699ca46f5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/units/util.py @@ -0,0 +1,265 @@ +""" +Several methods to simplify expressions involving unit objects. +""" +from functools import reduce +from collections.abc import Iterable +from typing import Optional + +from sympy import default_sort_key +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.sorting import ordered +from sympy.core.sympify import sympify +from sympy.core.function import Function +from sympy.matrices.exceptions import NonInvertibleMatrixError +from sympy.physics.units.dimensions import Dimension, DimensionSystem +from sympy.physics.units.prefixes import Prefix +from sympy.physics.units.quantities import Quantity +from sympy.physics.units.unitsystem import UnitSystem +from sympy.utilities.iterables import sift + + +def _get_conversion_matrix_for_expr(expr, target_units, unit_system): + from sympy.matrices.dense import Matrix + + dimension_system = unit_system.get_dimension_system() + + expr_dim = Dimension(unit_system.get_dimensional_expr(expr)) + dim_dependencies = dimension_system.get_dimensional_dependencies(expr_dim, mark_dimensionless=True) + target_dims = [Dimension(unit_system.get_dimensional_expr(x)) for x in target_units] + canon_dim_units = [i for x in target_dims for i in dimension_system.get_dimensional_dependencies(x, mark_dimensionless=True)] + canon_expr_units = set(dim_dependencies) + + if not canon_expr_units.issubset(set(canon_dim_units)): + return None + + seen = set() + canon_dim_units = [i for i in canon_dim_units if not (i in seen or seen.add(i))] + + camat = Matrix([[dimension_system.get_dimensional_dependencies(i, mark_dimensionless=True).get(j, 0) for i in target_dims] for j in canon_dim_units]) + exprmat = Matrix([dim_dependencies.get(k, 0) for k in canon_dim_units]) + + try: + res_exponents = camat.solve(exprmat) + except NonInvertibleMatrixError: + return None + + return res_exponents + + +def convert_to(expr, target_units, unit_system="SI"): + """ + Convert ``expr`` to the same expression with all of its units and quantities + represented as factors of ``target_units``, whenever the dimension is compatible. + + ``target_units`` may be a single unit/quantity, or a collection of + units/quantities. + + Examples + ======== + + >>> from sympy.physics.units import speed_of_light, meter, gram, second, day + >>> from sympy.physics.units import mile, newton, kilogram, atomic_mass_constant + >>> from sympy.physics.units import kilometer, centimeter + >>> from sympy.physics.units import gravitational_constant, hbar + >>> from sympy.physics.units import convert_to + >>> convert_to(mile, kilometer) + 25146*kilometer/15625 + >>> convert_to(mile, kilometer).n() + 1.609344*kilometer + >>> convert_to(speed_of_light, meter/second) + 299792458*meter/second + >>> convert_to(day, second) + 86400*second + >>> 3*newton + 3*newton + >>> convert_to(3*newton, kilogram*meter/second**2) + 3*kilogram*meter/second**2 + >>> convert_to(atomic_mass_constant, gram) + 1.660539060e-24*gram + + Conversion to multiple units: + + >>> convert_to(speed_of_light, [meter, second]) + 299792458*meter/second + >>> convert_to(3*newton, [centimeter, gram, second]) + 300000*centimeter*gram/second**2 + + Conversion to Planck units: + + >>> convert_to(atomic_mass_constant, [gravitational_constant, speed_of_light, hbar]).n() + 7.62963087839509e-20*hbar**0.5*speed_of_light**0.5/gravitational_constant**0.5 + + """ + from sympy.physics.units import UnitSystem + unit_system = UnitSystem.get_unit_system(unit_system) + + if not isinstance(target_units, (Iterable, Tuple)): + target_units = [target_units] + + def handle_Adds(expr): + return Add.fromiter(convert_to(i, target_units, unit_system) + for i in expr.args) + + if isinstance(expr, Add): + return handle_Adds(expr) + elif isinstance(expr, Pow) and isinstance(expr.base, Add): + return handle_Adds(expr.base) ** expr.exp + + expr = sympify(expr) + target_units = sympify(target_units) + + if isinstance(expr, Function): + expr = expr.together() + + if not isinstance(expr, Quantity) and expr.has(Quantity): + expr = expr.replace(lambda x: isinstance(x, Quantity), + lambda x: x.convert_to(target_units, unit_system)) + + def get_total_scale_factor(expr): + if isinstance(expr, Mul): + return reduce(lambda x, y: x * y, + [get_total_scale_factor(i) for i in expr.args]) + elif isinstance(expr, Pow): + return get_total_scale_factor(expr.base) ** expr.exp + elif isinstance(expr, Quantity): + return unit_system.get_quantity_scale_factor(expr) + return expr + + depmat = _get_conversion_matrix_for_expr(expr, target_units, unit_system) + if depmat is None: + return expr + + expr_scale_factor = get_total_scale_factor(expr) + return expr_scale_factor * Mul.fromiter( + (1/get_total_scale_factor(u)*u)**p for u, p in + zip(target_units, depmat)) + + +def quantity_simplify(expr, across_dimensions: bool=False, unit_system=None): + """Return an equivalent expression in which prefixes are replaced + with numerical values and all units of a given dimension are the + unified in a canonical manner by default. `across_dimensions` allows + for units of different dimensions to be simplified together. + + `unit_system` must be specified if `across_dimensions` is True. + + Examples + ======== + + >>> from sympy.physics.units.util import quantity_simplify + >>> from sympy.physics.units.prefixes import kilo + >>> from sympy.physics.units import foot, inch, joule, coulomb + >>> quantity_simplify(kilo*foot*inch) + 250*foot**2/3 + >>> quantity_simplify(foot - 6*inch) + foot/2 + >>> quantity_simplify(5*joule/coulomb, across_dimensions=True, unit_system="SI") + 5*volt + """ + + if expr.is_Atom or not expr.has(Prefix, Quantity): + return expr + + # replace all prefixes with numerical values + p = expr.atoms(Prefix) + expr = expr.xreplace({p: p.scale_factor for p in p}) + + # replace all quantities of given dimension with a canonical + # quantity, chosen from those in the expression + d = sift(expr.atoms(Quantity), lambda i: i.dimension) + for k in d: + if len(d[k]) == 1: + continue + v = list(ordered(d[k])) + ref = v[0]/v[0].scale_factor + expr = expr.xreplace({vi: ref*vi.scale_factor for vi in v[1:]}) + + if across_dimensions: + # combine quantities of different dimensions into a single + # quantity that is equivalent to the original expression + + if unit_system is None: + raise ValueError("unit_system must be specified if across_dimensions is True") + + unit_system = UnitSystem.get_unit_system(unit_system) + dimension_system: DimensionSystem = unit_system.get_dimension_system() + dim_expr = unit_system.get_dimensional_expr(expr) + dim_deps = dimension_system.get_dimensional_dependencies(dim_expr, mark_dimensionless=True) + + target_dimension: Optional[Dimension] = None + for ds_dim, ds_dim_deps in dimension_system.dimensional_dependencies.items(): + if ds_dim_deps == dim_deps: + target_dimension = ds_dim + break + + if target_dimension is None: + # if we can't find a target dimension, we can't do anything. unsure how to handle this case. + return expr + + target_unit = unit_system.derived_units.get(target_dimension) + if target_unit: + expr = convert_to(expr, target_unit, unit_system) + + return expr + + +def check_dimensions(expr, unit_system="SI"): + """Return expr if units in addends have the same + base dimensions, else raise a ValueError.""" + # the case of adding a number to a dimensional quantity + # is ignored for the sake of SymPy core routines, so this + # function will raise an error now if such an addend is + # found. + # Also, when doing substitutions, multiplicative constants + # might be introduced, so remove those now + + from sympy.physics.units import UnitSystem + unit_system = UnitSystem.get_unit_system(unit_system) + + def addDict(dict1, dict2): + """Merge dictionaries by adding values of common keys and + removing keys with value of 0.""" + dict3 = {**dict1, **dict2} + for key, value in dict3.items(): + if key in dict1 and key in dict2: + dict3[key] = value + dict1[key] + return {key:val for key, val in dict3.items() if val != 0} + + adds = expr.atoms(Add) + DIM_OF = unit_system.get_dimension_system().get_dimensional_dependencies + for a in adds: + deset = set() + for ai in a.args: + if ai.is_number: + deset.add(()) + continue + dims = [] + skip = False + dimdict = {} + for i in Mul.make_args(ai): + if i.has(Quantity): + i = Dimension(unit_system.get_dimensional_expr(i)) + if i.has(Dimension): + dimdict = addDict(dimdict, DIM_OF(i)) + elif i.free_symbols: + skip = True + break + dims.extend(dimdict.items()) + if not skip: + deset.add(tuple(sorted(dims, key=default_sort_key))) + if len(deset) > 1: + raise ValueError( + "addends have incompatible dimensions: {}".format(deset)) + + # clear multiplicative constants on Dimensions which may be + # left after substitution + reps = {} + for m in expr.atoms(Mul): + if any(isinstance(i, Dimension) for i in m.args): + reps[m] = m.func(*[ + i for i in m.args if not i.is_number]) + + return expr.xreplace(reps) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/wigner.py b/.venv/lib/python3.13/site-packages/sympy/physics/wigner.py new file mode 100644 index 0000000000000000000000000000000000000000..e08f3fb4a480439fd2bb1f8ff8c305bf69d7abae --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/wigner.py @@ -0,0 +1,1213 @@ +# -*- coding: utf-8 -*- +r""" +Wigner, Clebsch-Gordan, Racah, and Gaunt coefficients + +Collection of functions for calculating Wigner 3j, 6j, 9j, +Clebsch-Gordan, Racah as well as Gaunt coefficients exactly, all +evaluating to a rational number times the square root of a rational +number [Rasch03]_. + +Please see the description of the individual functions for further +details and examples. + +References +========== + +.. [Regge58] 'Symmetry Properties of Clebsch-Gordan Coefficients', + T. Regge, Nuovo Cimento, Volume 10, pp. 544 (1958) +.. [Regge59] 'Symmetry Properties of Racah Coefficients', + T. Regge, Nuovo Cimento, Volume 11, pp. 116 (1959) +.. [Edmonds74] A. R. Edmonds. Angular momentum in quantum mechanics. + Investigations in physics, 4.; Investigations in physics, no. 4. + Princeton, N.J., Princeton University Press, 1957. +.. [Rasch03] J. Rasch and A. C. H. Yu, 'Efficient Storage Scheme for + Pre-calculated Wigner 3j, 6j and Gaunt Coefficients', SIAM + J. Sci. Comput. Volume 25, Issue 4, pp. 1416-1428 (2003) +.. [Liberatodebrito82] 'FORTRAN program for the integral of three + spherical harmonics', A. Liberato de Brito, + Comput. Phys. Commun., Volume 25, pp. 81-85 (1982) +.. [Homeier96] 'Some Properties of the Coupling Coefficients of Real + Spherical Harmonics and Their Relation to Gaunt Coefficients', + H. H. H. Homeier and E. O. Steinborn J. Mol. Struct., Volume 368, + pp. 31-37 (1996) + +Credits and Copyright +===================== + +This code was taken from Sage with the permission of all authors: + +https://groups.google.com/forum/#!topic/sage-devel/M4NZdu-7O38 + +Authors +======= + +- Jens Rasch (2009-03-24): initial version for Sage + +- Jens Rasch (2009-05-31): updated to sage-4.0 + +- Oscar Gerardo Lazo Arjona (2017-06-18): added Wigner D matrices + +- Phil Adam LeMaitre (2022-09-19): added real Gaunt coefficient + +Copyright (C) 2008 Jens Rasch + +""" +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.numbers import int_valued +from sympy.core.function import Function +from sympy.core.numbers import (Float, I, Integer, pi, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import (binomial, factorial) +from sympy.functions.elementary.complexes import re +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.spherical_harmonics import Ynm +from sympy.matrices.dense import zeros +from sympy.matrices.immutable import ImmutableMatrix +from sympy.utilities.misc import as_int + +# This list of precomputed factorials is needed to massively +# accelerate future calculations of the various coefficients +_Factlist = [1] + + +def _calc_factlist(nn): + r""" + Function calculates a list of precomputed factorials in order to + massively accelerate future calculations of the various + coefficients. + + Parameters + ========== + + nn : integer + Highest factorial to be computed. + + Returns + ======= + + list of integers : + The list of precomputed factorials. + + Examples + ======== + + Calculate list of factorials:: + + sage: from sage.functions.wigner import _calc_factlist + sage: _calc_factlist(10) + [1, 1, 2, 6, 24, 120, 720, 5040, 40320, 362880, 3628800] + """ + if nn >= len(_Factlist): + for ii in range(len(_Factlist), int(nn + 1)): + _Factlist.append(_Factlist[ii - 1] * ii) + return _Factlist[:int(nn) + 1] + + +def _int_or_halfint(value): + """return Python int unless value is half-int (then return float)""" + if isinstance(value, int): + return value + elif type(value) is float: + if value.is_integer(): + return int(value) # an int + if (2*value).is_integer(): + return value # a float + elif isinstance(value, Rational): + if value.q == 2: + return value.p/value.q # a float + elif value.q == 1: + return value.p # an int + elif isinstance(value, Float): + return _int_or_halfint(float(value)) + raise ValueError("expecting integer or half-integer, got %s" % value) + + +def wigner_3j(j_1, j_2, j_3, m_1, m_2, m_3): + r""" + Calculate the Wigner 3j symbol `\operatorname{Wigner3j}(j_1,j_2,j_3,m_1,m_2,m_3)`. + + Parameters + ========== + + j_1, j_2, j_3, m_1, m_2, m_3 : + Integer or half integer. + + Returns + ======= + + Rational number times the square root of a rational number. + + Examples + ======== + + >>> from sympy.physics.wigner import wigner_3j + >>> wigner_3j(2, 6, 4, 0, 0, 0) + sqrt(715)/143 + >>> wigner_3j(2, 6, 4, 0, 0, 1) + 0 + + It is an error to have arguments that are not integer or half + integer values:: + + sage: wigner_3j(2.1, 6, 4, 0, 0, 0) + Traceback (most recent call last): + ... + ValueError: j values must be integer or half integer + sage: wigner_3j(2, 6, 4, 1, 0, -1.1) + Traceback (most recent call last): + ... + ValueError: m values must be integer or half integer + + Notes + ===== + + The Wigner 3j symbol obeys the following symmetry rules: + + - invariant under any permutation of the columns (with the + exception of a sign change where `J:=j_1+j_2+j_3`): + + .. math:: + + \begin{aligned} + \operatorname{Wigner3j}(j_1,j_2,j_3,m_1,m_2,m_3) + &=\operatorname{Wigner3j}(j_3,j_1,j_2,m_3,m_1,m_2) \\ + &=\operatorname{Wigner3j}(j_2,j_3,j_1,m_2,m_3,m_1) \\ + &=(-1)^J \operatorname{Wigner3j}(j_3,j_2,j_1,m_3,m_2,m_1) \\ + &=(-1)^J \operatorname{Wigner3j}(j_1,j_3,j_2,m_1,m_3,m_2) \\ + &=(-1)^J \operatorname{Wigner3j}(j_2,j_1,j_3,m_2,m_1,m_3) + \end{aligned} + + - invariant under space inflection, i.e. + + .. math:: + + \operatorname{Wigner3j}(j_1,j_2,j_3,m_1,m_2,m_3) + =(-1)^J \operatorname{Wigner3j}(j_1,j_2,j_3,-m_1,-m_2,-m_3) + + - symmetric with respect to the 72 additional symmetries based on + the work by [Regge58]_ + + - zero for `j_1`, `j_2`, `j_3` not fulfilling triangle relation + + - zero for `m_1 + m_2 + m_3 \neq 0` + + - zero for violating any one of the conditions + `m_1 \in \{-|j_1|, \ldots, |j_1|\}`, + `m_2 \in \{-|j_2|, \ldots, |j_2|\}`, + `m_3 \in \{-|j_3|, \ldots, |j_3|\}` + + Algorithm + ========= + + This function uses the algorithm of [Edmonds74]_ to calculate the + value of the 3j symbol exactly. Note that the formula contains + alternating sums over large factorials and is therefore unsuitable + for finite precision arithmetic and only useful for a computer + algebra system [Rasch03]_. + + Authors + ======= + + - Jens Rasch (2009-03-24): initial version + """ + + j_1, j_2, j_3, m_1, m_2, m_3 = \ + map(_int_or_halfint, map(sympify, + [j_1, j_2, j_3, m_1, m_2, m_3])) + + if m_1 + m_2 + m_3 != 0: + return S.Zero + a1 = j_1 + j_2 - j_3 + if a1 < 0: + return S.Zero + a2 = j_1 - j_2 + j_3 + if a2 < 0: + return S.Zero + a3 = -j_1 + j_2 + j_3 + if a3 < 0: + return S.Zero + if (abs(m_1) > j_1) or (abs(m_2) > j_2) or (abs(m_3) > j_3): + return S.Zero + if not (int_valued(j_1 - m_1) and \ + int_valued(j_2 - m_2) and \ + int_valued(j_3 - m_3)): + return S.Zero + + maxfact = max(j_1 + j_2 + j_3 + 1, j_1 + abs(m_1), j_2 + abs(m_2), + j_3 + abs(m_3)) + _calc_factlist(int(maxfact)) + + argsqrt = Integer(_Factlist[int(j_1 + j_2 - j_3)] * + _Factlist[int(j_1 - j_2 + j_3)] * + _Factlist[int(-j_1 + j_2 + j_3)] * + _Factlist[int(j_1 - m_1)] * + _Factlist[int(j_1 + m_1)] * + _Factlist[int(j_2 - m_2)] * + _Factlist[int(j_2 + m_2)] * + _Factlist[int(j_3 - m_3)] * + _Factlist[int(j_3 + m_3)]) / \ + _Factlist[int(j_1 + j_2 + j_3 + 1)] + + ressqrt = sqrt(argsqrt) + if ressqrt.is_complex or ressqrt.is_infinite: + ressqrt = ressqrt.as_real_imag()[0] + + imin = max(-j_3 + j_1 + m_2, -j_3 + j_2 - m_1, 0) + imax = min(j_2 + m_2, j_1 - m_1, j_1 + j_2 - j_3) + sumres = 0 + for ii in range(int(imin), int(imax) + 1): + den = _Factlist[ii] * \ + _Factlist[int(ii + j_3 - j_1 - m_2)] * \ + _Factlist[int(j_2 + m_2 - ii)] * \ + _Factlist[int(j_1 - ii - m_1)] * \ + _Factlist[int(ii + j_3 - j_2 + m_1)] * \ + _Factlist[int(j_1 + j_2 - j_3 - ii)] + sumres = sumres + Integer((-1) ** ii) / den + + prefid = Integer((-1) ** int(j_1 - j_2 - m_3)) + res = ressqrt * sumres * prefid + return res + + +def clebsch_gordan(j_1, j_2, j_3, m_1, m_2, m_3): + r""" + Calculates the Clebsch-Gordan coefficient. + `\left\langle j_1 m_1 \; j_2 m_2 | j_3 m_3 \right\rangle`. + + The reference for this function is [Edmonds74]_. + + Parameters + ========== + + j_1, j_2, j_3, m_1, m_2, m_3 : + Integer or half integer. + + Returns + ======= + + Rational number times the square root of a rational number. + + Examples + ======== + + >>> from sympy import S + >>> from sympy.physics.wigner import clebsch_gordan + >>> clebsch_gordan(S(3)/2, S(1)/2, 2, S(3)/2, S(1)/2, 2) + 1 + >>> clebsch_gordan(S(3)/2, S(1)/2, 1, S(3)/2, -S(1)/2, 1) + sqrt(3)/2 + >>> clebsch_gordan(S(3)/2, S(1)/2, 1, -S(1)/2, S(1)/2, 0) + -sqrt(2)/2 + + Notes + ===== + + The Clebsch-Gordan coefficient will be evaluated via its relation + to Wigner 3j symbols: + + .. math:: + + \left\langle j_1 m_1 \; j_2 m_2 | j_3 m_3 \right\rangle + =(-1)^{j_1-j_2+m_3} \sqrt{2j_3+1} + \operatorname{Wigner3j}(j_1,j_2,j_3,m_1,m_2,-m_3) + + See also the documentation on Wigner 3j symbols which exhibit much + higher symmetry relations than the Clebsch-Gordan coefficient. + + Authors + ======= + + - Jens Rasch (2009-03-24): initial version + """ + j_1 = sympify(j_1) + j_2 = sympify(j_2) + j_3 = sympify(j_3) + m_1 = sympify(m_1) + m_2 = sympify(m_2) + m_3 = sympify(m_3) + + w = wigner_3j(j_1, j_2, j_3, m_1, m_2, -m_3) + + return (-1) ** (j_1 - j_2 + m_3) * sqrt(2 * j_3 + 1) * w + + +def _big_delta_coeff(aa, bb, cc, prec=None): + r""" + Calculates the Delta coefficient of the 3 angular momenta for + Racah symbols. Also checks that the differences are of integer + value. + + Parameters + ========== + + aa : + First angular momentum, integer or half integer. + bb : + Second angular momentum, integer or half integer. + cc : + Third angular momentum, integer or half integer. + prec : + Precision of the ``sqrt()`` calculation. + + Returns + ======= + + double : Value of the Delta coefficient. + + Examples + ======== + + sage: from sage.functions.wigner import _big_delta_coeff + sage: _big_delta_coeff(1,1,1) + 1/2*sqrt(1/6) + """ + + # the triangle test will only pass if a) all 3 values are ints or + # b) 1 is an int and the other two are half-ints + if not int_valued(aa + bb - cc): + raise ValueError("j values must be integer or half integer and fulfill the triangle relation") + if not int_valued(aa + cc - bb): + raise ValueError("j values must be integer or half integer and fulfill the triangle relation") + if not int_valued(bb + cc - aa): + raise ValueError("j values must be integer or half integer and fulfill the triangle relation") + if (aa + bb - cc) < 0: + return S.Zero + if (aa + cc - bb) < 0: + return S.Zero + if (bb + cc - aa) < 0: + return S.Zero + + maxfact = max(aa + bb - cc, aa + cc - bb, bb + cc - aa, aa + bb + cc + 1) + _calc_factlist(maxfact) + + argsqrt = Integer(_Factlist[int(aa + bb - cc)] * + _Factlist[int(aa + cc - bb)] * + _Factlist[int(bb + cc - aa)]) / \ + Integer(_Factlist[int(aa + bb + cc + 1)]) + + ressqrt = sqrt(argsqrt) + if prec: + ressqrt = ressqrt.evalf(prec).as_real_imag()[0] + return ressqrt + + +def racah(aa, bb, cc, dd, ee, ff, prec=None): + r""" + Calculate the Racah symbol `W(a,b,c,d;e,f)`. + + Parameters + ========== + + a, ..., f : + Integer or half integer. + prec : + Precision, default: ``None``. Providing a precision can + drastically speed up the calculation. + + Returns + ======= + + Rational number times the square root of a rational number + (if ``prec=None``), or real number if a precision is given. + + Examples + ======== + + >>> from sympy.physics.wigner import racah + >>> racah(3,3,3,3,3,3) + -1/14 + + Notes + ===== + + The Racah symbol is related to the Wigner 6j symbol: + + .. math:: + + \operatorname{Wigner6j}(j_1,j_2,j_3,j_4,j_5,j_6) + =(-1)^{j_1+j_2+j_4+j_5} W(j_1,j_2,j_5,j_4,j_3,j_6) + + Please see the 6j symbol for its much richer symmetries and for + additional properties. + + Algorithm + ========= + + This function uses the algorithm of [Edmonds74]_ to calculate the + value of the 6j symbol exactly. Note that the formula contains + alternating sums over large factorials and is therefore unsuitable + for finite precision arithmetic and only useful for a computer + algebra system [Rasch03]_. + + Authors + ======= + + - Jens Rasch (2009-03-24): initial version + """ + prefac = _big_delta_coeff(aa, bb, ee, prec) * \ + _big_delta_coeff(cc, dd, ee, prec) * \ + _big_delta_coeff(aa, cc, ff, prec) * \ + _big_delta_coeff(bb, dd, ff, prec) + if prefac == 0: + return S.Zero + imin = max(aa + bb + ee, cc + dd + ee, aa + cc + ff, bb + dd + ff) + imax = min(aa + bb + cc + dd, aa + dd + ee + ff, bb + cc + ee + ff) + + maxfact = max(imax + 1, aa + bb + cc + dd, aa + dd + ee + ff, + bb + cc + ee + ff) + _calc_factlist(maxfact) + + sumres = 0 + for kk in range(int(imin), int(imax) + 1): + den = _Factlist[int(kk - aa - bb - ee)] * \ + _Factlist[int(kk - cc - dd - ee)] * \ + _Factlist[int(kk - aa - cc - ff)] * \ + _Factlist[int(kk - bb - dd - ff)] * \ + _Factlist[int(aa + bb + cc + dd - kk)] * \ + _Factlist[int(aa + dd + ee + ff - kk)] * \ + _Factlist[int(bb + cc + ee + ff - kk)] + sumres = sumres + Integer((-1) ** kk * _Factlist[kk + 1]) / den + + res = prefac * sumres * (-1) ** int(aa + bb + cc + dd) + return res + + +def wigner_6j(j_1, j_2, j_3, j_4, j_5, j_6, prec=None): + r""" + Calculate the Wigner 6j symbol `\operatorname{Wigner6j}(j_1,j_2,j_3,j_4,j_5,j_6)`. + + Parameters + ========== + + j_1, ..., j_6 : + Integer or half integer. + prec : + Precision, default: ``None``. Providing a precision can + drastically speed up the calculation. + + Returns + ======= + + Rational number times the square root of a rational number + (if ``prec=None``), or real number if a precision is given. + + Examples + ======== + + >>> from sympy.physics.wigner import wigner_6j + >>> wigner_6j(3,3,3,3,3,3) + -1/14 + >>> wigner_6j(5,5,5,5,5,5) + 1/52 + + It is an error to have arguments that are not integer or half + integer values or do not fulfill the triangle relation:: + + sage: wigner_6j(2.5,2.5,2.5,2.5,2.5,2.5) + Traceback (most recent call last): + ... + ValueError: j values must be integer or half integer and fulfill the triangle relation + sage: wigner_6j(0.5,0.5,1.1,0.5,0.5,1.1) + Traceback (most recent call last): + ... + ValueError: j values must be integer or half integer and fulfill the triangle relation + + Notes + ===== + + The Wigner 6j symbol is related to the Racah symbol but exhibits + more symmetries as detailed below. + + .. math:: + + \operatorname{Wigner6j}(j_1,j_2,j_3,j_4,j_5,j_6) + =(-1)^{j_1+j_2+j_4+j_5} W(j_1,j_2,j_5,j_4,j_3,j_6) + + The Wigner 6j symbol obeys the following symmetry rules: + + - Wigner 6j symbols are left invariant under any permutation of + the columns: + + .. math:: + + \begin{aligned} + \operatorname{Wigner6j}(j_1,j_2,j_3,j_4,j_5,j_6) + &=\operatorname{Wigner6j}(j_3,j_1,j_2,j_6,j_4,j_5) \\ + &=\operatorname{Wigner6j}(j_2,j_3,j_1,j_5,j_6,j_4) \\ + &=\operatorname{Wigner6j}(j_3,j_2,j_1,j_6,j_5,j_4) \\ + &=\operatorname{Wigner6j}(j_1,j_3,j_2,j_4,j_6,j_5) \\ + &=\operatorname{Wigner6j}(j_2,j_1,j_3,j_5,j_4,j_6) + \end{aligned} + + - They are invariant under the exchange of the upper and lower + arguments in each of any two columns, i.e. + + .. math:: + + \begin{aligned} + \operatorname{Wigner6j}(j_1,j_2,j_3,j_4,j_5,j_6) + &=\operatorname{Wigner6j}(j_1,j_5,j_6,j_4,j_2,j_3)\\ + &=\operatorname{Wigner6j}(j_4,j_2,j_6,j_1,j_5,j_3)\\ + &=\operatorname{Wigner6j}(j_4,j_5,j_3,j_1,j_2,j_6) + \end{aligned} + + - additional 6 symmetries [Regge59]_ giving rise to 144 symmetries + in total + + - only non-zero if any triple of `j`'s fulfill a triangle relation + + Algorithm + ========= + + This function uses the algorithm of [Edmonds74]_ to calculate the + value of the 6j symbol exactly. Note that the formula contains + alternating sums over large factorials and is therefore unsuitable + for finite precision arithmetic and only useful for a computer + algebra system [Rasch03]_. + + """ + j_1, j_2, j_3, j_4, j_5, j_6 = map(sympify, \ + [j_1, j_2, j_3, j_4, j_5, j_6]) + res = (-1) ** int(j_1 + j_2 + j_4 + j_5) * \ + racah(j_1, j_2, j_5, j_4, j_3, j_6, prec) + return res + + +def wigner_9j(j_1, j_2, j_3, j_4, j_5, j_6, j_7, j_8, j_9, prec=None): + r""" + Calculate the Wigner 9j symbol + `\operatorname{Wigner9j}(j_1,j_2,j_3,j_4,j_5,j_6,j_7,j_8,j_9)`. + + Parameters + ========== + + j_1, ..., j_9 : + Integer or half integer. + prec : precision, default + ``None``. Providing a precision can + drastically speed up the calculation. + + Returns + ======= + + Rational number times the square root of a rational number + (if ``prec=None``), or real number if a precision is given. + + Examples + ======== + + >>> from sympy.physics.wigner import wigner_9j + >>> wigner_9j(1,1,1, 1,1,1, 1,1,0, prec=64) + 0.05555555555555555555555555555555555555555555555555555555555555555 + + >>> wigner_9j(1/2,1/2,0, 1/2,3/2,1, 0,1,1, prec=64) + 0.1666666666666666666666666666666666666666666666666666666666666667 + + It is an error to have arguments that are not integer or half + integer values or do not fulfill the triangle relation:: + + sage: wigner_9j(0.5,0.5,0.5, 0.5,0.5,0.5, 0.5,0.5,0.5,prec=64) + Traceback (most recent call last): + ... + ValueError: j values must be integer or half integer and fulfill the triangle relation + sage: wigner_9j(1,1,1, 0.5,1,1.5, 0.5,1,2.5,prec=64) + Traceback (most recent call last): + ... + ValueError: j values must be integer or half integer and fulfill the triangle relation + + Algorithm + ========= + + This function uses the algorithm of [Edmonds74]_ to calculate the + value of the 3j symbol exactly. Note that the formula contains + alternating sums over large factorials and is therefore unsuitable + for finite precision arithmetic and only useful for a computer + algebra system [Rasch03]_. + """ + j_1, j_2, j_3, j_4, j_5, j_6, j_7, j_8, j_9 = map(sympify, \ + [j_1, j_2, j_3, j_4, j_5, j_6, j_7, j_8, j_9]) + imax = int(min(j_1 + j_9, j_2 + j_6, j_4 + j_8) * 2) + imin = imax % 2 + sumres = 0 + for kk in range(imin, int(imax) + 1, 2): + sumres = sumres + (kk + 1) * \ + racah(j_1, j_2, j_9, j_6, j_3, kk / 2, prec) * \ + racah(j_4, j_6, j_8, j_2, j_5, kk / 2, prec) * \ + racah(j_1, j_4, j_9, j_8, j_7, kk / 2, prec) + return sumres + + +def gaunt(l_1, l_2, l_3, m_1, m_2, m_3, prec=None): + r""" + Calculate the Gaunt coefficient. + + Explanation + =========== + + The Gaunt coefficient is defined as the integral over three + spherical harmonics: + + .. math:: + + \begin{aligned} + \operatorname{Gaunt}(l_1,l_2,l_3,m_1,m_2,m_3) + &=\int Y_{l_1,m_1}(\Omega) + Y_{l_2,m_2}(\Omega) Y_{l_3,m_3}(\Omega) \,d\Omega \\ + &=\sqrt{\frac{(2l_1+1)(2l_2+1)(2l_3+1)}{4\pi}} + \operatorname{Wigner3j}(l_1,l_2,l_3,0,0,0) + \operatorname{Wigner3j}(l_1,l_2,l_3,m_1,m_2,m_3) + \end{aligned} + + Parameters + ========== + + l_1, l_2, l_3, m_1, m_2, m_3 : + Integer. + prec - precision, default: ``None``. + Providing a precision can + drastically speed up the calculation. + + Returns + ======= + + Rational number times the square root of a rational number + (if ``prec=None``), or real number if a precision is given. + + Examples + ======== + + >>> from sympy.physics.wigner import gaunt + >>> gaunt(1,0,1,1,0,-1) + -1/(2*sqrt(pi)) + >>> gaunt(1000,1000,1200,9,3,-12).n(64) + 0.006895004219221134484332976156744208248842039317638217822322799675 + + It is an error to use non-integer values for `l` and `m`:: + + sage: gaunt(1.2,0,1.2,0,0,0) + Traceback (most recent call last): + ... + ValueError: l values must be integer + sage: gaunt(1,0,1,1.1,0,-1.1) + Traceback (most recent call last): + ... + ValueError: m values must be integer + + Notes + ===== + + The Gaunt coefficient obeys the following symmetry rules: + + - invariant under any permutation of the columns + + .. math:: + \begin{aligned} + Y(l_1,l_2,l_3,m_1,m_2,m_3) + &=Y(l_3,l_1,l_2,m_3,m_1,m_2) \\ + &=Y(l_2,l_3,l_1,m_2,m_3,m_1) \\ + &=Y(l_3,l_2,l_1,m_3,m_2,m_1) \\ + &=Y(l_1,l_3,l_2,m_1,m_3,m_2) \\ + &=Y(l_2,l_1,l_3,m_2,m_1,m_3) + \end{aligned} + + - invariant under space inflection, i.e. + + .. math:: + Y(l_1,l_2,l_3,m_1,m_2,m_3) + =Y(l_1,l_2,l_3,-m_1,-m_2,-m_3) + + - symmetric with respect to the 72 Regge symmetries as inherited + for the `3j` symbols [Regge58]_ + + - zero for `l_1`, `l_2`, `l_3` not fulfilling triangle relation + + - zero for violating any one of the conditions: `l_1 \ge |m_1|`, + `l_2 \ge |m_2|`, `l_3 \ge |m_3|` + + - non-zero only for an even sum of the `l_i`, i.e. + `L = l_1 + l_2 + l_3 = 2n` for `n` in `\mathbb{N}` + + Algorithms + ========== + + This function uses the algorithm of [Liberatodebrito82]_ to + calculate the value of the Gaunt coefficient exactly. Note that + the formula contains alternating sums over large factorials and is + therefore unsuitable for finite precision arithmetic and only + useful for a computer algebra system [Rasch03]_. + + Authors + ======= + + Jens Rasch (2009-03-24): initial version for Sage. + """ + l_1, l_2, l_3, m_1, m_2, m_3 = [ + as_int(i) for i in (l_1, l_2, l_3, m_1, m_2, m_3)] + + if l_1 + l_2 - l_3 < 0: + return S.Zero + if l_1 - l_2 + l_3 < 0: + return S.Zero + if -l_1 + l_2 + l_3 < 0: + return S.Zero + if (m_1 + m_2 + m_3) != 0: + return S.Zero + if (abs(m_1) > l_1) or (abs(m_2) > l_2) or (abs(m_3) > l_3): + return S.Zero + bigL, remL = divmod(l_1 + l_2 + l_3, 2) + if remL % 2: + return S.Zero + + imin = max(-l_3 + l_1 + m_2, -l_3 + l_2 - m_1, 0) + imax = min(l_2 + m_2, l_1 - m_1, l_1 + l_2 - l_3) + + _calc_factlist(max(l_1 + l_2 + l_3 + 1, imax + 1)) + + ressqrt = sqrt((2 * l_1 + 1) * (2 * l_2 + 1) * (2 * l_3 + 1) * \ + _Factlist[l_1 - m_1] * _Factlist[l_1 + m_1] * _Factlist[l_2 - m_2] * \ + _Factlist[l_2 + m_2] * _Factlist[l_3 - m_3] * _Factlist[l_3 + m_3] / \ + (4*pi)) + + prefac = Integer(_Factlist[bigL] * _Factlist[l_2 - l_1 + l_3] * + _Factlist[l_1 - l_2 + l_3] * _Factlist[l_1 + l_2 - l_3])/ \ + _Factlist[2 * bigL + 1]/ \ + (_Factlist[bigL - l_1] * + _Factlist[bigL - l_2] * _Factlist[bigL - l_3]) + + sumres = 0 + for ii in range(int(imin), int(imax) + 1): + den = _Factlist[ii] * _Factlist[ii + l_3 - l_1 - m_2] * \ + _Factlist[l_2 + m_2 - ii] * _Factlist[l_1 - ii - m_1] * \ + _Factlist[ii + l_3 - l_2 + m_1] * _Factlist[l_1 + l_2 - l_3 - ii] + sumres = sumres + Integer((-1) ** ii) / den + + res = ressqrt * prefac * sumres * Integer((-1) ** (bigL + l_3 + m_1 - m_2)) + if prec is not None: + res = res.n(prec) + return res + + +def real_gaunt(l_1, l_2, l_3, mu_1, mu_2, mu_3, prec=None): + r""" + Calculate the real Gaunt coefficient. + + Explanation + =========== + + The real Gaunt coefficient is defined as the integral over three + real spherical harmonics: + + .. math:: + \begin{aligned} + \operatorname{RealGaunt}(l_1,l_2,l_3,\mu_1,\mu_2,\mu_3) + &=\int Z^{\mu_1}_{l_1}(\Omega) + Z^{\mu_2}_{l_2}(\Omega) Z^{\mu_3}_{l_3}(\Omega) \,d\Omega \\ + \end{aligned} + + Alternatively, it can be defined in terms of the standard Gaunt + coefficient by relating the real spherical harmonics to the standard + spherical harmonics via a unitary transformation `U`, i.e. + `Z^{\mu}_{l}(\Omega)=\sum_{m'}U^{\mu}_{m'}Y^{m'}_{l}(\Omega)` [Homeier96]_. + The real Gaunt coefficient is then defined as + + .. math:: + \begin{aligned} + \operatorname{RealGaunt}(l_1,l_2,l_3,\mu_1,\mu_2,\mu_3) + &=\int Z^{\mu_1}_{l_1}(\Omega) + Z^{\mu_2}_{l_2}(\Omega) Z^{\mu_3}_{l_3}(\Omega) \,d\Omega \\ + &=\sum_{m'_1 m'_2 m'_3} U^{\mu_1}_{m'_1}U^{\mu_2}_{m'_2}U^{\mu_3}_{m'_3} + \operatorname{Gaunt}(l_1,l_2,l_3,m'_1,m'_2,m'_3) + \end{aligned} + + The unitary matrix `U` has components + + .. math:: + \begin{aligned} + U^\mu_{m} = \delta_{|\mu||m|}*(\delta_{m0}\delta_{\mu 0} + \frac{1}{\sqrt{2}}\big[\Theta(\mu)\big(\delta_{m\mu}+(-1)^{m}\delta_{m-\mu}\big) + +i \Theta(-\mu)\big((-1)^{m}\delta_{m\mu}-\delta_{m-\mu}\big)\big]) + \end{aligned} + + + where `\delta_{ij}` is the Kronecker delta symbol and `\Theta` is a step + function defined as + + .. math:: + \begin{aligned} + \Theta(x) = \begin{cases} 1 \,\text{for}\, x > 0 \\ 0 \,\text{for}\, x \leq 0 \end{cases} + \end{aligned} + + Parameters + ========== + + l_1, l_2, l_3, mu_1, mu_2, mu_3 : + Integer degree and order + + prec - precision, default: ``None``. + Providing a precision can + drastically speed up the calculation. + + Returns + ======= + + Rational number times the square root of a rational number. + + Examples + ======== + >>> from sympy.physics.wigner import real_gaunt + >>> real_gaunt(1,1,2,-1,1,-2) + sqrt(15)/(10*sqrt(pi)) + >>> real_gaunt(10,10,20,-9,-9,0,prec=64) + -0.00002480019791932209313156167176797577821140084216297395518482071448 + + It is an error to use non-integer values for `l` and `\mu`:: + real_gaunt(2.8,0.5,1.3,0,0,0) + Traceback (most recent call last): + ... + ValueError: l values must be integer + + real_gaunt(2,2,4,0.7,1,-3.4) + Traceback (most recent call last): + ... + ValueError: mu values must be integer + + Notes + ===== + + The real Gaunt coefficient inherits from the standard Gaunt coefficient, + the invariance under any permutation of the pairs `(l_i, \mu_i)` and the + requirement that the sum of the `l_i` be even to yield a non-zero value. + It also obeys the following symmetry rules: + + - zero for `l_1`, `l_2`, `l_3` not fulfilling the condition + `l_1 \in \{l_{\text{max}}, l_{\text{max}}-2, \ldots, l_{\text{min}}\}`, + where `l_{\text{max}} = l_2+l_3`, + + .. math:: + \begin{aligned} + l_{\text{min}} = \begin{cases} \kappa(l_2, l_3, \mu_2, \mu_3) & \text{if}\, + \kappa(l_2, l_3, \mu_2, \mu_3) + l_{\text{max}}\, \text{is even} \\ + \kappa(l_2, l_3, \mu_2, \mu_3)+1 & \text{if}\, \kappa(l_2, l_3, \mu_2, \mu_3) + + l_{\text{max}}\, \text{is odd}\end{cases} + \end{aligned} + + and `\kappa(l_2, l_3, \mu_2, \mu_3) = \max{\big(|l_2-l_3|, \min{\big(|\mu_2+\mu_3|, + |\mu_2-\mu_3|\big)}\big)}` + + - zero for an odd number of negative `\mu_i` + + Algorithms + ========== + + This function uses the algorithms of [Homeier96]_ and [Rasch03]_ to + calculate the value of the real Gaunt coefficient exactly. Note that + the formula used in [Rasch03]_ contains alternating sums over large + factorials and is therefore unsuitable for finite precision arithmetic + and only useful for a computer algebra system [Rasch03]_. However, this + function can in principle use any algorithm that computes the Gaunt + coefficient, so it is suitable for finite precision arithmetic in so far + as the algorithm which computes the Gaunt coefficient is. + """ + l_1, l_2, l_3, mu_1, mu_2, mu_3 = [ + as_int(i) for i in (l_1, l_2, l_3, mu_1, mu_2, mu_3)] + + # check for quick exits + if sum(1 for i in (mu_1, mu_2, mu_3) if i < 0) % 2: + return S.Zero # odd number of negative m + if (l_1 + l_2 + l_3) % 2: + return S.Zero # sum of l is odd + lmax = l_2 + l_3 + lmin = max(abs(l_2 - l_3), min(abs(mu_2 + mu_3), abs(mu_2 - mu_3))) + if (lmin + lmax) % 2: + lmin += 1 + if lmin not in range(lmax, lmin - 2, -2): + return S.Zero + + kron_del = lambda i, j: 1 if i == j else 0 + s = lambda e: -1 if e % 2 else 1 # (-1)**e to give +/-1, avoiding float when e<0 + + t = lambda x: 1 if x > 0 else 0 + A = lambda mu, m: t(-mu) * (s(m) * kron_del(m, mu) - kron_del(m, -mu)) + B = lambda mu, m: t(mu) * (kron_del(m, mu) + s(m) * kron_del(m, -mu)) + U = lambda mu, m: kron_del(abs(mu), abs(m)) * (kron_del(mu, 0) * kron_del(m, 0) + (B(mu, m) + I * A(mu, m))/sqrt(2)) + + ugnt = 0 + for m1 in range(-l_1, l_1+1): + U1 = U(mu_1, m1) + for m2 in range(-l_2, l_2+1): + U2 = U(mu_2, m2) + U3 = U(mu_3,-m1-m2) + ugnt = ugnt + re(U1*U2*U3)*gaunt(l_1, l_2, l_3, m1, m2, -m1 - m2, prec=prec) + + return ugnt + + +class Wigner3j(Function): + + def doit(self, **hints): + if all(obj.is_number for obj in self.args): + return wigner_3j(*self.args) + else: + return self + +def dot_rot_grad_Ynm(j, p, l, m, theta, phi): + r""" + Returns dot product of rotational gradients of spherical harmonics. + + Explanation + =========== + + This function returns the right hand side of the following expression: + + .. math :: + \vec{R}Y{_j^{p}} \cdot \vec{R}Y{_l^{m}} = (-1)^{m+p} + \sum\limits_{k=|l-j|}^{l+j}Y{_k^{m+p}} * \alpha_{l,m,j,p,k} * + \frac{1}{2} (k^2-j^2-l^2+k-j-l) + + + Arguments + ========= + + j, p, l, m .... indices in spherical harmonics (expressions or integers) + theta, phi .... angle arguments in spherical harmonics + + Example + ======= + + >>> from sympy import symbols + >>> from sympy.physics.wigner import dot_rot_grad_Ynm + >>> theta, phi = symbols("theta phi") + >>> dot_rot_grad_Ynm(3, 2, 2, 0, theta, phi).doit() + 3*sqrt(55)*Ynm(5, 2, theta, phi)/(11*sqrt(pi)) + + """ + j = sympify(j) + p = sympify(p) + l = sympify(l) + m = sympify(m) + theta = sympify(theta) + phi = sympify(phi) + k = Dummy("k") + + def alpha(l,m,j,p,k): + return sqrt((2*l+1)*(2*j+1)*(2*k+1)/(4*pi)) * \ + Wigner3j(j, l, k, S.Zero, S.Zero, S.Zero) * \ + Wigner3j(j, l, k, p, m, -m-p) + + return (S.NegativeOne)**(m+p) * Sum(Ynm(k, m+p, theta, phi) * alpha(l,m,j,p,k) / 2 \ + *(k**2-j**2-l**2+k-j-l), (k, abs(l-j), l+j)) + + +def wigner_d_small(J, beta): + """Return the small Wigner d matrix for angular momentum J. + + Explanation + =========== + + J : An integer, half-integer, or SymPy symbol for the total angular + momentum of the angular momentum space being rotated. + beta : A real number representing the Euler angle of rotation about + the so-called line of nodes. See [Edmonds74]_. + + Returns + ======= + + A matrix representing the corresponding Euler angle rotation( in the basis + of eigenvectors of `J_z`). + + .. math :: + \\mathcal{d}_{\\beta} = \\exp\\big( \\frac{i\\beta}{\\hbar} J_y\\big) + + such that + + .. math :: + d^{(J)}_{m',m}(\\beta) = \\mathtt{wigner\\_d\\_small(J,beta)[J-mprime,J-m]} + + The components are calculated using the general form [Edmonds74]_, + equation 4.1.15. + + Examples + ======== + + >>> from sympy import Integer, symbols, pi, pprint + >>> from sympy.physics.wigner import wigner_d_small + >>> half = 1/Integer(2) + >>> beta = symbols("beta", real=True) + >>> pprint(wigner_d_small(half, beta), use_unicode=True) + ⎡ ⎛β⎞ ⎛β⎞⎤ + ⎢cos⎜─⎟ sin⎜─⎟⎥ + ⎢ ⎝2⎠ ⎝2⎠⎥ + ⎢ ⎥ + ⎢ ⎛β⎞ ⎛β⎞⎥ + ⎢-sin⎜─⎟ cos⎜─⎟⎥ + ⎣ ⎝2⎠ ⎝2⎠⎦ + + >>> pprint(wigner_d_small(2*half, beta), use_unicode=True) + ⎡ 2⎛β⎞ ⎛β⎞ ⎛β⎞ 2⎛β⎞ ⎤ + ⎢ cos ⎜─⎟ √2⋅sin⎜─⎟⋅cos⎜─⎟ sin ⎜─⎟ ⎥ + ⎢ ⎝2⎠ ⎝2⎠ ⎝2⎠ ⎝2⎠ ⎥ + ⎢ ⎥ + ⎢ ⎛β⎞ ⎛β⎞ 2⎛β⎞ 2⎛β⎞ ⎛β⎞ ⎛β⎞⎥ + ⎢-√2⋅sin⎜─⎟⋅cos⎜─⎟ - sin ⎜─⎟ + cos ⎜─⎟ √2⋅sin⎜─⎟⋅cos⎜─⎟⎥ + ⎢ ⎝2⎠ ⎝2⎠ ⎝2⎠ ⎝2⎠ ⎝2⎠ ⎝2⎠⎥ + ⎢ ⎥ + ⎢ 2⎛β⎞ ⎛β⎞ ⎛β⎞ 2⎛β⎞ ⎥ + ⎢ sin ⎜─⎟ -√2⋅sin⎜─⎟⋅cos⎜─⎟ cos ⎜─⎟ ⎥ + ⎣ ⎝2⎠ ⎝2⎠ ⎝2⎠ ⎝2⎠ ⎦ + + From table 4 in [Edmonds74]_ + + >>> pprint(wigner_d_small(half, beta).subs({beta:pi/2}), use_unicode=True) + ⎡ √2 √2⎤ + ⎢ ── ──⎥ + ⎢ 2 2 ⎥ + ⎢ ⎥ + ⎢-√2 √2⎥ + ⎢──── ──⎥ + ⎣ 2 2 ⎦ + + >>> pprint(wigner_d_small(2*half, beta).subs({beta:pi/2}), + ... use_unicode=True) + ⎡ √2 ⎤ + ⎢1/2 ── 1/2⎥ + ⎢ 2 ⎥ + ⎢ ⎥ + ⎢-√2 √2 ⎥ + ⎢──── 0 ── ⎥ + ⎢ 2 2 ⎥ + ⎢ ⎥ + ⎢ -√2 ⎥ + ⎢1/2 ──── 1/2⎥ + ⎣ 2 ⎦ + + >>> pprint(wigner_d_small(3*half, beta).subs({beta:pi/2}), + ... use_unicode=True) + ⎡ √2 √6 √6 √2⎤ + ⎢ ── ── ── ──⎥ + ⎢ 4 4 4 4 ⎥ + ⎢ ⎥ + ⎢-√6 -√2 √2 √6⎥ + ⎢──── ──── ── ──⎥ + ⎢ 4 4 4 4 ⎥ + ⎢ ⎥ + ⎢ √6 -√2 -√2 √6⎥ + ⎢ ── ──── ──── ──⎥ + ⎢ 4 4 4 4 ⎥ + ⎢ ⎥ + ⎢-√2 √6 -√6 √2⎥ + ⎢──── ── ──── ──⎥ + ⎣ 4 4 4 4 ⎦ + + >>> pprint(wigner_d_small(4*half, beta).subs({beta:pi/2}), + ... use_unicode=True) + ⎡ √6 ⎤ + ⎢1/4 1/2 ── 1/2 1/4⎥ + ⎢ 4 ⎥ + ⎢ ⎥ + ⎢-1/2 -1/2 0 1/2 1/2⎥ + ⎢ ⎥ + ⎢ √6 √6 ⎥ + ⎢ ── 0 -1/2 0 ── ⎥ + ⎢ 4 4 ⎥ + ⎢ ⎥ + ⎢-1/2 1/2 0 -1/2 1/2⎥ + ⎢ ⎥ + ⎢ √6 ⎥ + ⎢1/4 -1/2 ── -1/2 1/4⎥ + ⎣ 4 ⎦ + + """ + M = [J-i for i in range(2*J+1)] + d = zeros(2*J+1) + + # Mi corresponds to Edmonds' $m'$, and Mj to $m$. + for i, Mi in enumerate(M): + for j, Mj in enumerate(M): + + # We get the maximum and minimum value of sigma. + sigmamax = min([J-Mi, J-Mj]) + sigmamin = max([0, -Mi-Mj]) + + dij = sqrt(factorial(J+Mi)*factorial(J-Mi) / + factorial(J+Mj)/factorial(J-Mj)) + terms = [(-1)**(J-Mi-s) * + binomial(J+Mj, J-Mi-s) * + binomial(J-Mj, s) * + cos(beta/2)**(2*s+Mi+Mj) * + sin(beta/2)**(2*J-2*s-Mj-Mi) + for s in range(sigmamin, sigmamax+1)] + + d[i, j] = dij*Add(*terms) + + return ImmutableMatrix(d) + + +def wigner_d(J, alpha, beta, gamma): + """Return the Wigner D matrix for angular momentum J. + + Explanation + =========== + + J : + An integer, half-integer, or SymPy symbol for the total angular + momentum of the angular momentum space being rotated. + alpha, beta, gamma - Real numbers representing the Euler. + Angles of rotation about the so-called figure axis, line of nodes, + and vertical. See [Edmonds74]_, however note that the symbols alpha + and gamma are swapped in this implementation. + + Returns + ======= + + A matrix representing the corresponding Euler angle rotation (in the basis + of eigenvectors of `J_z`). + + .. math :: + \\mathcal{D}_{\\alpha \\beta \\gamma} = + \\exp\\big( \\frac{i\\alpha}{\\hbar} J_z\\big) + \\exp\\big( \\frac{i\\beta}{\\hbar} J_y\\big) + \\exp\\big( \\frac{i\\gamma}{\\hbar} J_z\\big) + + such that + + .. math :: + \\mathcal{D}^{(J)}_{m',m}(\\alpha, \\beta, \\gamma) = + \\mathtt{wigner_d(J, alpha, beta, gamma)[J-mprime,J-m]} + + The components are calculated using the general form [Edmonds74]_, + equation 4.1.12, however note that the angles alpha and gamma are swapped + in this implementation. + + Examples + ======== + + The simplest possible example: + + >>> from sympy.physics.wigner import wigner_d + >>> from sympy import Integer, symbols, pprint + >>> half = 1/Integer(2) + >>> alpha, beta, gamma = symbols("alpha, beta, gamma", real=True) + >>> pprint(wigner_d(half, alpha, beta, gamma), use_unicode=True) + ⎡ ⅈ⋅α ⅈ⋅γ ⅈ⋅α -ⅈ⋅γ ⎤ + ⎢ ─── ─── ─── ───── ⎥ + ⎢ 2 2 ⎛β⎞ 2 2 ⎛β⎞ ⎥ + ⎢ ℯ ⋅ℯ ⋅cos⎜─⎟ ℯ ⋅ℯ ⋅sin⎜─⎟ ⎥ + ⎢ ⎝2⎠ ⎝2⎠ ⎥ + ⎢ ⎥ + ⎢ -ⅈ⋅α ⅈ⋅γ -ⅈ⋅α -ⅈ⋅γ ⎥ + ⎢ ───── ─── ───── ───── ⎥ + ⎢ 2 2 ⎛β⎞ 2 2 ⎛β⎞⎥ + ⎢-ℯ ⋅ℯ ⋅sin⎜─⎟ ℯ ⋅ℯ ⋅cos⎜─⎟⎥ + ⎣ ⎝2⎠ ⎝2⎠⎦ + + """ + d = wigner_d_small(J, beta) + M = [J-i for i in range(2*J+1)] + # Mi corresponds to Edmonds' $m'$, and Mj to $m$. + D = [[exp(I*Mi*alpha)*d[i, j]*exp(I*Mj*gamma) + for j, Mj in enumerate(M)] for i, Mi in enumerate(M)] + return ImmutableMatrix(D) diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/__init__.py b/.venv/lib/python3.13/site-packages/sympy/simplify/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0619d1c3ebbd6c6a7d663093c7ed2202114148af --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/__init__.py @@ -0,0 +1,60 @@ +"""The module helps converting SymPy expressions into shorter forms of them. + +for example: +the expression E**(pi*I) will be converted into -1 +the expression (x+x)**2 will be converted into 4*x**2 +""" +from .simplify import (simplify, hypersimp, hypersimilar, + logcombine, separatevars, posify, besselsimp, kroneckersimp, + signsimp, nsimplify) + +from .fu import FU, fu + +from .sqrtdenest import sqrtdenest + +from .cse_main import cse + +from .epathtools import epath, EPath + +from .hyperexpand import hyperexpand + +from .radsimp import collect, rcollect, radsimp, collect_const, fraction, numer, denom + +from .trigsimp import trigsimp, exptrigsimp + +from .powsimp import powsimp, powdenest + +from .combsimp import combsimp + +from .gammasimp import gammasimp + +from .ratsimp import ratsimp, ratsimpmodprime + +__all__ = [ + 'simplify', 'hypersimp', 'hypersimilar', 'logcombine', 'separatevars', + 'posify', 'besselsimp', 'kroneckersimp', 'signsimp', + 'nsimplify', + + 'FU', 'fu', + + 'sqrtdenest', + + 'cse', + + 'epath', 'EPath', + + 'hyperexpand', + + 'collect', 'rcollect', 'radsimp', 'collect_const', 'fraction', 'numer', + 'denom', + + 'trigsimp', 'exptrigsimp', + + 'powsimp', 'powdenest', + + 'combsimp', + + 'gammasimp', + + 'ratsimp', 'ratsimpmodprime', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/_cse_diff.py b/.venv/lib/python3.13/site-packages/sympy/simplify/_cse_diff.py new file mode 100644 index 0000000000000000000000000000000000000000..3496ad3b31a4f45312cac002429be40aa9aa0868 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/_cse_diff.py @@ -0,0 +1,291 @@ +"""Module for differentiation using CSE.""" + +from sympy import cse, Matrix, Derivative, MatrixBase +from sympy.utilities.iterables import iterable + + +def _remove_cse_from_derivative(replacements, reduced_expressions): + """ + This function is designed to postprocess the output of a common subexpression + elimination (CSE) operation. Specifically, it removes any CSE replacement + symbols from the arguments of ``Derivative`` terms in the expression. This + is necessary to ensure that the forward Jacobian function correctly handles + derivative terms. + + Parameters + ========== + + replacements : list of (Symbol, expression) pairs + Replacement symbols and relative common subexpressions that have been + replaced during a CSE operation. + + reduced_expressions : list of SymPy expressions + The reduced expressions with all the replacements from the + replacements list above. + + Returns + ======= + + processed_replacements : list of (Symbol, expression) pairs + Processed replacement list, in the same format of the + ``replacements`` input list. + + processed_reduced : list of SymPy expressions + Processed reduced list, in the same format of the + ``reduced_expressions`` input list. + """ + + def traverse(node, repl_dict): + if isinstance(node, Derivative): + return replace_all(node, repl_dict) + if not node.args: + return node + new_args = [traverse(arg, repl_dict) for arg in node.args] + return node.func(*new_args) + + def replace_all(node, repl_dict): + result = node + while True: + free_symbols = result.free_symbols + symbols_dict = {k: repl_dict[k] for k in free_symbols if k in repl_dict} + if not symbols_dict: + break + result = result.xreplace(symbols_dict) + return result + + repl_dict = dict(replacements) + processed_replacements = [ + (rep_sym, traverse(sub_exp, repl_dict)) + for rep_sym, sub_exp in replacements + ] + processed_reduced = [ + red_exp.__class__([traverse(exp, repl_dict) for exp in red_exp]) + for red_exp in reduced_expressions + ] + + return processed_replacements, processed_reduced + + +def _forward_jacobian_cse(replacements, reduced_expr, wrt): + """ + Core function to compute the Jacobian of an input Matrix of expressions + through forward accumulation. Takes directly the output of a CSE operation + (replacements and reduced_expr), and an iterable of variables (wrt) with + respect to which to differentiate the reduced expression and returns the + reduced Jacobian matrix and the ``replacements`` list. + + The function also returns a list of precomputed free symbols for each + subexpression, which are useful in the substitution process. + + Parameters + ========== + + replacements : list of (Symbol, expression) pairs + Replacement symbols and relative common subexpressions that have been + replaced during a CSE operation. + + reduced_expr : list of SymPy expressions + The reduced expressions with all the replacements from the + replacements list above. + + wrt : iterable + Iterable of expressions with respect to which to compute the + Jacobian matrix. + + Returns + ======= + + replacements : list of (Symbol, expression) pairs + Replacement symbols and relative common subexpressions that have been + replaced during a CSE operation. Compared to the input replacement list, + the output one doesn't contain replacement symbols inside + ``Derivative``'s arguments. + + jacobian : list of SymPy expressions + The list only contains one element, which is the Jacobian matrix with + elements in reduced form (replacement symbols are present). + + precomputed_fs: list + List of sets, which store the free symbols present in each sub-expression. + Useful in the substitution process. + """ + + if not isinstance(reduced_expr[0], MatrixBase): + raise TypeError("``expr`` must be of matrix type") + + if not (reduced_expr[0].shape[0] == 1 or reduced_expr[0].shape[1] == 1): + raise TypeError("``expr`` must be a row or a column matrix") + + if not iterable(wrt): + raise TypeError("``wrt`` must be an iterable of variables") + + elif not isinstance(wrt, MatrixBase): + wrt = Matrix(wrt) + + if not (wrt.shape[0] == 1 or wrt.shape[1] == 1): + raise TypeError("``wrt`` must be a row or a column matrix") + + replacements, reduced_expr = _remove_cse_from_derivative(replacements, reduced_expr) + + if replacements: + rep_sym, sub_expr = map(Matrix, zip(*replacements)) + else: + rep_sym, sub_expr = Matrix([]), Matrix([]) + + l_sub, l_wrt, l_red = len(sub_expr), len(wrt), len(reduced_expr[0]) + + f1 = reduced_expr[0].__class__.from_dok(l_red, l_wrt, + { + (i, j): diff_value + for i, r in enumerate(reduced_expr[0]) + for j, w in enumerate(wrt) + if (diff_value := r.diff(w)) != 0 + }, + ) + + if not replacements: + return [], [f1], [] + + f2 = Matrix.from_dok(l_red, l_sub, + { + (i, j): diff_value + for i, (r, fs) in enumerate([(r, r.free_symbols) for r in reduced_expr[0]]) + for j, s in enumerate(rep_sym) + if s in fs and (diff_value := r.diff(s)) != 0 + }, + ) + + rep_sym_set = set(rep_sym) + precomputed_fs = [s.free_symbols & rep_sym_set for s in sub_expr ] + + c_matrix = Matrix.from_dok(1, l_wrt, + {(0, j): diff_value for j, w in enumerate(wrt) + if (diff_value := sub_expr[0].diff(w)) != 0}) + + for i in range(1, l_sub): + + bi_matrix = Matrix.from_dok(1, i, + {(0, j): diff_value for j in range(i + 1) + if rep_sym[j] in precomputed_fs[i] + and (diff_value := sub_expr[i].diff(rep_sym[j])) != 0}) + + ai_matrix = Matrix.from_dok(1, l_wrt, + {(0, j): diff_value for j, w in enumerate(wrt) + if (diff_value := sub_expr[i].diff(w)) != 0}) + + if bi_matrix._rep.nnz(): + ci_matrix = bi_matrix.multiply(c_matrix).add(ai_matrix) + c_matrix = Matrix.vstack(c_matrix, ci_matrix) + else: + c_matrix = Matrix.vstack(c_matrix, ai_matrix) + + jacobian = f2.multiply(c_matrix).add(f1) + jacobian = [reduced_expr[0].__class__(jacobian)] + + return replacements, jacobian, precomputed_fs + + +def _forward_jacobian_norm_in_cse_out(expr, wrt): + """ + Function to compute the Jacobian of an input Matrix of expressions through + forward accumulation. Takes a sympy Matrix of expressions (expr) as input + and an iterable of variables (wrt) with respect to which to compute the + Jacobian matrix. The matrix is returned in reduced form (containing + replacement symbols) along with the ``replacements`` list. + + The function also returns a list of precomputed free symbols for each + subexpression, which are useful in the substitution process. + + Parameters + ========== + + expr : Matrix + The vector to be differentiated. + + wrt : iterable + The vector with respect to which to perform the differentiation. + Can be a matrix or an iterable of variables. + + Returns + ======= + + replacements : list of (Symbol, expression) pairs + Replacement symbols and relative common subexpressions that have been + replaced during a CSE operation. The output replacement list doesn't + contain replacement symbols inside ``Derivative``'s arguments. + + jacobian : list of SymPy expressions + The list only contains one element, which is the Jacobian matrix with + elements in reduced form (replacement symbols are present). + + precomputed_fs: list + List of sets, which store the free symbols present in each + sub-expression. Useful in the substitution process. + """ + + replacements, reduced_expr = cse(expr) + replacements, jacobian, precomputed_fs = _forward_jacobian_cse(replacements, reduced_expr, wrt) + + return replacements, jacobian, precomputed_fs + + +def _forward_jacobian(expr, wrt): + """ + Function to compute the Jacobian of an input Matrix of expressions through + forward accumulation. Takes a sympy Matrix of expressions (expr) as input + and an iterable of variables (wrt) with respect to which to compute the + Jacobian matrix. + + Explanation + =========== + + Expressions often contain repeated subexpressions. Using a tree structure, + these subexpressions are duplicated and differentiated multiple times, + leading to inefficiency. + + Instead, if a data structure called a directed acyclic graph (DAG) is used + then each of these repeated subexpressions will only exist a single time. + This function uses a combination of representing the expression as a DAG and + a forward accumulation algorithm (repeated application of the chain rule + symbolically) to more efficiently calculate the Jacobian matrix of a target + expression ``expr`` with respect to an expression or set of expressions + ``wrt``. + + Note that this function is intended to improve performance when + differentiating large expressions that contain many common subexpressions. + For small and simple expressions it is likely less performant than using + SymPy's standard differentiation functions and methods. + + Parameters + ========== + + expr : Matrix + The vector to be differentiated. + + wrt : iterable + The vector with respect to which to do the differentiation. + Can be a matrix or an iterable of variables. + + See Also + ======== + + Direct Acyclic Graph : https://en.wikipedia.org/wiki/Directed_acyclic_graph + """ + + replacements, reduced_expr = cse(expr) + + if replacements: + rep_sym, _ = map(Matrix, zip(*replacements)) + else: + rep_sym = Matrix([]) + + replacements, jacobian, precomputed_fs = _forward_jacobian_cse(replacements, reduced_expr, wrt) + + if not replacements: return jacobian[0] + + sub_rep = dict(replacements) + for i, ik in enumerate(precomputed_fs): + sub_dict = {j: sub_rep[j] for j in ik} + sub_rep[rep_sym[i]] = sub_rep[rep_sym[i]].xreplace(sub_dict) + + return jacobian[0].xreplace(sub_rep) diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/combsimp.py b/.venv/lib/python3.13/site-packages/sympy/simplify/combsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..8b0b3cefcba11b4b7759b7d3ec3c2d4415cfd849 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/combsimp.py @@ -0,0 +1,114 @@ +from sympy.core import Mul +from sympy.core.function import count_ops +from sympy.core.traversal import preorder_traversal, bottom_up +from sympy.functions.combinatorial.factorials import binomial, factorial +from sympy.functions import gamma +from sympy.simplify.gammasimp import gammasimp, _gammasimp + +from sympy.utilities.timeutils import timethis + + +@timethis('combsimp') +def combsimp(expr): + r""" + Simplify combinatorial expressions. + + Explanation + =========== + + This function takes as input an expression containing factorials, + binomials, Pochhammer symbol and other "combinatorial" functions, + and tries to minimize the number of those functions and reduce + the size of their arguments. + + The algorithm works by rewriting all combinatorial functions as + gamma functions and applying gammasimp() except simplification + steps that may make an integer argument non-integer. See docstring + of gammasimp for more information. + + Then it rewrites expression in terms of factorials and binomials by + rewriting gammas as factorials and converting (a+b)!/a!b! into + binomials. + + If expression has gamma functions or combinatorial functions + with non-integer argument, it is automatically passed to gammasimp. + + Examples + ======== + + >>> from sympy.simplify import combsimp + >>> from sympy import factorial, binomial, symbols + >>> n, k = symbols('n k', integer = True) + + >>> combsimp(factorial(n)/factorial(n - 3)) + n*(n - 2)*(n - 1) + >>> combsimp(binomial(n+1, k+1)/binomial(n, k)) + (n + 1)/(k + 1) + + """ + + expr = expr.rewrite(gamma, piecewise=False) + if any(isinstance(node, gamma) and not node.args[0].is_integer + for node in preorder_traversal(expr)): + return gammasimp(expr) + + expr = _gammasimp(expr, as_comb = True) + expr = _gamma_as_comb(expr) + return expr + + +def _gamma_as_comb(expr): + """ + Helper function for combsimp. + + Rewrites expression in terms of factorials and binomials + """ + + expr = expr.rewrite(factorial) + + def f(rv): + if not rv.is_Mul: + return rv + rvd = rv.as_powers_dict() + nd_fact_args = [[], []] # numerator, denominator + + for k in rvd: + if isinstance(k, factorial) and rvd[k].is_Integer: + if rvd[k].is_positive: + nd_fact_args[0].extend([k.args[0]]*rvd[k]) + else: + nd_fact_args[1].extend([k.args[0]]*-rvd[k]) + rvd[k] = 0 + if not nd_fact_args[0] or not nd_fact_args[1]: + return rv + + hit = False + for m in range(2): + i = 0 + while i < len(nd_fact_args[m]): + ai = nd_fact_args[m][i] + for j in range(i + 1, len(nd_fact_args[m])): + aj = nd_fact_args[m][j] + + sum = ai + aj + if sum in nd_fact_args[1 - m]: + hit = True + + nd_fact_args[1 - m].remove(sum) + del nd_fact_args[m][j] + del nd_fact_args[m][i] + + rvd[binomial(sum, ai if count_ops(ai) < + count_ops(aj) else aj)] += ( + -1 if m == 0 else 1) + break + else: + i += 1 + + if hit: + return Mul(*([k**rvd[k] for k in rvd] + [factorial(k) + for k in nd_fact_args[0]]))/Mul(*[factorial(k) + for k in nd_fact_args[1]]) + return rv + + return bottom_up(expr, f) diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/cse_main.py b/.venv/lib/python3.13/site-packages/sympy/simplify/cse_main.py new file mode 100644 index 0000000000000000000000000000000000000000..bcd1b2e50adae8c3d3400d6c489e63a44ae1e59b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/cse_main.py @@ -0,0 +1,945 @@ +""" Tools for doing common subexpression elimination. +""" +from collections import defaultdict + +from sympy.core import Basic, Mul, Add, Pow, sympify +from sympy.core.containers import Tuple, OrderedSet +from sympy.core.exprtools import factor_terms +from sympy.core.singleton import S +from sympy.core.sorting import ordered +from sympy.core.symbol import symbols, Symbol +from sympy.matrices import (MatrixBase, Matrix, ImmutableMatrix, + SparseMatrix, ImmutableSparseMatrix) +from sympy.matrices.expressions import (MatrixExpr, MatrixSymbol, MatMul, + MatAdd, MatPow, Inverse) +from sympy.matrices.expressions.matexpr import MatrixElement +from sympy.polys.rootoftools import RootOf +from sympy.utilities.iterables import numbered_symbols, sift, \ + topological_sort, iterable + +from . import cse_opts + +# (preprocessor, postprocessor) pairs which are commonly useful. They should +# each take a SymPy expression and return a possibly transformed expression. +# When used in the function ``cse()``, the target expressions will be transformed +# by each of the preprocessor functions in order. After the common +# subexpressions are eliminated, each resulting expression will have the +# postprocessor functions transform them in *reverse* order in order to undo the +# transformation if necessary. This allows the algorithm to operate on +# a representation of the expressions that allows for more optimization +# opportunities. +# ``None`` can be used to specify no transformation for either the preprocessor or +# postprocessor. + + +basic_optimizations = [(cse_opts.sub_pre, cse_opts.sub_post), + (factor_terms, None)] + +# sometimes we want the output in a different format; non-trivial +# transformations can be put here for users +# =============================================================== + + +def reps_toposort(r): + """Sort replacements ``r`` so (k1, v1) appears before (k2, v2) + if k2 is in v1's free symbols. This orders items in the + way that cse returns its results (hence, in order to use the + replacements in a substitution option it would make sense + to reverse the order). + + Examples + ======== + + >>> from sympy.simplify.cse_main import reps_toposort + >>> from sympy.abc import x, y + >>> from sympy import Eq + >>> for l, r in reps_toposort([(x, y + 1), (y, 2)]): + ... print(Eq(l, r)) + ... + Eq(y, 2) + Eq(x, y + 1) + + """ + r = sympify(r) + E = [] + for c1, (k1, v1) in enumerate(r): + for c2, (k2, v2) in enumerate(r): + if k1 in v2.free_symbols: + E.append((c1, c2)) + return [r[i] for i in topological_sort((range(len(r)), E))] + + +def cse_separate(r, e): + """Move expressions that are in the form (symbol, expr) out of the + expressions and sort them into the replacements using the reps_toposort. + + Examples + ======== + + >>> from sympy.simplify.cse_main import cse_separate + >>> from sympy.abc import x, y, z + >>> from sympy import cos, exp, cse, Eq, symbols + >>> x0, x1 = symbols('x:2') + >>> eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1)) + >>> cse([eq, Eq(x, z + 1), z - 2], postprocess=cse_separate) in [ + ... [[(x0, y + 1), (x, z + 1), (x1, x + 1)], + ... [x1 + exp(x1/x0) + cos(x0), z - 2]], + ... [[(x1, y + 1), (x, z + 1), (x0, x + 1)], + ... [x0 + exp(x0/x1) + cos(x1), z - 2]]] + ... + True + """ + d = sift(e, lambda w: w.is_Equality and w.lhs.is_Symbol) + r = r + [w.args for w in d[True]] + e = d[False] + return [reps_toposort(r), e] + + +def cse_release_variables(r, e): + """ + Return tuples giving ``(a, b)`` where ``a`` is a symbol and ``b`` is + either an expression or None. The value of None is used when a + symbol is no longer needed for subsequent expressions. + + Use of such output can reduce the memory footprint of lambdified + expressions that contain large, repeated subexpressions. + + Examples + ======== + + >>> from sympy import cse + >>> from sympy.simplify.cse_main import cse_release_variables + >>> from sympy.abc import x, y + >>> eqs = [(x + y - 1)**2, x, x + y, (x + y)/(2*x + 1) + (x + y - 1)**2, (2*x + 1)**(x + y)] + >>> defs, rvs = cse_release_variables(*cse(eqs)) + >>> for i in defs: + ... print(i) + ... + (x0, x + y) + (x1, (x0 - 1)**2) + (x2, 2*x + 1) + (_3, x0/x2 + x1) + (_4, x2**x0) + (x2, None) + (_0, x1) + (x1, None) + (_2, x0) + (x0, None) + (_1, x) + >>> print(rvs) + (_0, _1, _2, _3, _4) + """ + if not r: + return r, e + + s, p = zip(*r) + esyms = symbols('_:%d' % len(e)) + syms = list(esyms) + s = list(s) + in_use = set(s) + p = list(p) + # sort e so those with most sub-expressions appear first + e = [(e[i], syms[i]) for i in range(len(e))] + e, syms = zip(*sorted(e, + key=lambda x: -sum(p[s.index(i)].count_ops() + for i in x[0].free_symbols & in_use))) + syms = list(syms) + p += e + rv = [] + i = len(p) - 1 + while i >= 0: + _p = p.pop() + c = in_use & _p.free_symbols + if c: # sorting for canonical results + rv.extend([(s, None) for s in sorted(c, key=str)]) + if i >= len(r): + rv.append((syms.pop(), _p)) + else: + rv.append((s[i], _p)) + in_use -= c + i -= 1 + rv.reverse() + return rv, esyms + + +# ====end of cse postprocess idioms=========================== + + +def preprocess_for_cse(expr, optimizations): + """ Preprocess an expression to optimize for common subexpression + elimination. + + Parameters + ========== + + expr : SymPy expression + The target expression to optimize. + optimizations : list of (callable, callable) pairs + The (preprocessor, postprocessor) pairs. + + Returns + ======= + + expr : SymPy expression + The transformed expression. + """ + for pre, post in optimizations: + if pre is not None: + expr = pre(expr) + return expr + + +def postprocess_for_cse(expr, optimizations): + """Postprocess an expression after common subexpression elimination to + return the expression to canonical SymPy form. + + Parameters + ========== + + expr : SymPy expression + The target expression to transform. + optimizations : list of (callable, callable) pairs, optional + The (preprocessor, postprocessor) pairs. The postprocessors will be + applied in reversed order to undo the effects of the preprocessors + correctly. + + Returns + ======= + + expr : SymPy expression + The transformed expression. + """ + for pre, post in reversed(optimizations): + if post is not None: + expr = post(expr) + return expr + + +class FuncArgTracker: + """ + A class which manages a mapping from functions to arguments and an inverse + mapping from arguments to functions. + """ + + def __init__(self, funcs): + # To minimize the number of symbolic comparisons, all function arguments + # get assigned a value number. + self.value_numbers = {} + self.value_number_to_value = [] + + # Both of these maps use integer indices for arguments / functions. + self.arg_to_funcset = [] + self.func_to_argset = [] + + for func_i, func in enumerate(funcs): + func_argset = OrderedSet() + + for func_arg in func.args: + arg_number = self.get_or_add_value_number(func_arg) + func_argset.add(arg_number) + self.arg_to_funcset[arg_number].add(func_i) + + self.func_to_argset.append(func_argset) + + def get_args_in_value_order(self, argset): + """ + Return the list of arguments in sorted order according to their value + numbers. + """ + return [self.value_number_to_value[argn] for argn in sorted(argset)] + + def get_or_add_value_number(self, value): + """ + Return the value number for the given argument. + """ + nvalues = len(self.value_numbers) + value_number = self.value_numbers.setdefault(value, nvalues) + if value_number == nvalues: + self.value_number_to_value.append(value) + self.arg_to_funcset.append(OrderedSet()) + return value_number + + def stop_arg_tracking(self, func_i): + """ + Remove the function func_i from the argument to function mapping. + """ + for arg in self.func_to_argset[func_i]: + self.arg_to_funcset[arg].remove(func_i) + + + def get_common_arg_candidates(self, argset, min_func_i=0): + """Return a dict whose keys are function numbers. The entries of the dict are + the number of arguments said function has in common with + ``argset``. Entries have at least 2 items in common. All keys have + value at least ``min_func_i``. + """ + count_map = defaultdict(lambda: 0) + if not argset: + return count_map + + funcsets = [self.arg_to_funcset[arg] for arg in argset] + # As an optimization below, we handle the largest funcset separately from + # the others. + largest_funcset = max(funcsets, key=len) + + for funcset in funcsets: + if largest_funcset is funcset: + continue + for func_i in funcset: + if func_i >= min_func_i: + count_map[func_i] += 1 + + # We pick the smaller of the two containers (count_map, largest_funcset) + # to iterate over to reduce the number of iterations needed. + (smaller_funcs_container, + larger_funcs_container) = sorted( + [largest_funcset, count_map], + key=len) + + for func_i in smaller_funcs_container: + # Not already in count_map? It can't possibly be in the output, so + # skip it. + if count_map[func_i] < 1: + continue + + if func_i in larger_funcs_container: + count_map[func_i] += 1 + + return {k: v for k, v in count_map.items() if v >= 2} + + def get_subset_candidates(self, argset, restrict_to_funcset=None): + """ + Return a set of functions each of which whose argument list contains + ``argset``, optionally filtered only to contain functions in + ``restrict_to_funcset``. + """ + iarg = iter(argset) + + indices = OrderedSet( + fi for fi in self.arg_to_funcset[next(iarg)]) + + if restrict_to_funcset is not None: + indices &= restrict_to_funcset + + for arg in iarg: + indices &= self.arg_to_funcset[arg] + + return indices + + def update_func_argset(self, func_i, new_argset): + """ + Update a function with a new set of arguments. + """ + new_args = OrderedSet(new_argset) + old_args = self.func_to_argset[func_i] + + for deleted_arg in old_args - new_args: + self.arg_to_funcset[deleted_arg].remove(func_i) + for added_arg in new_args - old_args: + self.arg_to_funcset[added_arg].add(func_i) + + self.func_to_argset[func_i].clear() + self.func_to_argset[func_i].update(new_args) + + +class Unevaluated: + + def __init__(self, func, args): + self.func = func + self.args = args + + def __str__(self): + return "Uneval<{}>({})".format( + self.func, ", ".join(str(a) for a in self.args)) + + def as_unevaluated_basic(self): + return self.func(*self.args, evaluate=False) + + @property + def free_symbols(self): + return set().union(*[a.free_symbols for a in self.args]) + + __repr__ = __str__ + + +def match_common_args(func_class, funcs, opt_subs): + """ + Recognize and extract common subexpressions of function arguments within a + set of function calls. For instance, for the following function calls:: + + x + z + y + sin(x + y) + + this will extract a common subexpression of `x + y`:: + + w = x + y + w + z + sin(w) + + The function we work with is assumed to be associative and commutative. + + Parameters + ========== + + func_class: class + The function class (e.g. Add, Mul) + funcs: list of functions + A list of function calls. + opt_subs: dict + A dictionary of substitutions which this function may update. + """ + + # Sort to ensure that whole-function subexpressions come before the items + # that use them. + funcs = sorted(funcs, key=lambda f: len(f.args)) + arg_tracker = FuncArgTracker(funcs) + + changed = OrderedSet() + + for i in range(len(funcs)): + common_arg_candidates_counts = arg_tracker.get_common_arg_candidates( + arg_tracker.func_to_argset[i], min_func_i=i + 1) + + # Sort the candidates in order of match size. + # This makes us try combining smaller matches first. + common_arg_candidates = OrderedSet(sorted( + common_arg_candidates_counts.keys(), + key=lambda k: (common_arg_candidates_counts[k], k))) + + while common_arg_candidates: + j = common_arg_candidates.pop(last=False) + + com_args = arg_tracker.func_to_argset[i].intersection( + arg_tracker.func_to_argset[j]) + + if len(com_args) <= 1: + # This may happen if a set of common arguments was already + # combined in a previous iteration. + continue + + # For all sets, replace the common symbols by the function + # over them, to allow recursive matches. + + diff_i = arg_tracker.func_to_argset[i].difference(com_args) + if diff_i: + # com_func needs to be unevaluated to allow for recursive matches. + com_func = Unevaluated( + func_class, arg_tracker.get_args_in_value_order(com_args)) + com_func_number = arg_tracker.get_or_add_value_number(com_func) + arg_tracker.update_func_argset(i, diff_i | OrderedSet([com_func_number])) + changed.add(i) + else: + # Treat the whole expression as a CSE. + # + # The reason this needs to be done is somewhat subtle. Within + # tree_cse(), to_eliminate only contains expressions that are + # seen more than once. The problem is unevaluated expressions + # do not compare equal to the evaluated equivalent. So + # tree_cse() won't mark funcs[i] as a CSE if we use an + # unevaluated version. + com_func_number = arg_tracker.get_or_add_value_number(funcs[i]) + + diff_j = arg_tracker.func_to_argset[j].difference(com_args) + arg_tracker.update_func_argset(j, diff_j | OrderedSet([com_func_number])) + changed.add(j) + + for k in arg_tracker.get_subset_candidates( + com_args, common_arg_candidates): + diff_k = arg_tracker.func_to_argset[k].difference(com_args) + arg_tracker.update_func_argset(k, diff_k | OrderedSet([com_func_number])) + changed.add(k) + + if i in changed: + opt_subs[funcs[i]] = Unevaluated(func_class, + arg_tracker.get_args_in_value_order(arg_tracker.func_to_argset[i])) + + arg_tracker.stop_arg_tracking(i) + + +def opt_cse(exprs, order='canonical'): + """Find optimization opportunities in Adds, Muls, Pows and negative + coefficient Muls. + + Parameters + ========== + + exprs : list of SymPy expressions + The expressions to optimize. + order : string, 'none' or 'canonical' + The order by which Mul and Add arguments are processed. For large + expressions where speed is a concern, use the setting order='none'. + + Returns + ======= + + opt_subs : dictionary of expression substitutions + The expression substitutions which can be useful to optimize CSE. + + Examples + ======== + + >>> from sympy.simplify.cse_main import opt_cse + >>> from sympy.abc import x + >>> opt_subs = opt_cse([x**-2]) + >>> k, v = list(opt_subs.keys())[0], list(opt_subs.values())[0] + >>> print((k, v.as_unevaluated_basic())) + (x**(-2), 1/(x**2)) + """ + opt_subs = {} + + adds = OrderedSet() + muls = OrderedSet() + + seen_subexp = set() + collapsible_subexp = set() + + def _find_opts(expr): + + if not isinstance(expr, (Basic, Unevaluated)): + return + + if expr.is_Atom or expr.is_Order: + return + + if iterable(expr): + list(map(_find_opts, expr)) + return + + if expr in seen_subexp: + return expr + seen_subexp.add(expr) + + list(map(_find_opts, expr.args)) + + if not isinstance(expr, MatrixExpr) and expr.could_extract_minus_sign(): + # XXX -expr does not always work rigorously for some expressions + # containing UnevaluatedExpr. + # https://github.com/sympy/sympy/issues/24818 + if isinstance(expr, Add): + neg_expr = Add(*(-i for i in expr.args)) + else: + neg_expr = -expr + + if not neg_expr.is_Atom: + opt_subs[expr] = Unevaluated(Mul, (S.NegativeOne, neg_expr)) + seen_subexp.add(neg_expr) + expr = neg_expr + + if isinstance(expr, (Mul, MatMul)): + if len(expr.args) == 1: + collapsible_subexp.add(expr) + else: + muls.add(expr) + + elif isinstance(expr, (Add, MatAdd)): + if len(expr.args) == 1: + collapsible_subexp.add(expr) + else: + adds.add(expr) + + elif isinstance(expr, Inverse): + # Do not want to treat `Inverse` as a `MatPow` + pass + + elif isinstance(expr, (Pow, MatPow)): + base, exp = expr.base, expr.exp + if exp.could_extract_minus_sign(): + opt_subs[expr] = Unevaluated(Pow, (Pow(base, -exp), -1)) + + for e in exprs: + if isinstance(e, (Basic, Unevaluated)): + _find_opts(e) + + # Handle collapsing of multinary operations with single arguments + edges = [(s, s.args[0]) for s in collapsible_subexp + if s.args[0] in collapsible_subexp] + for e in reversed(topological_sort((collapsible_subexp, edges))): + opt_subs[e] = opt_subs.get(e.args[0], e.args[0]) + + # split muls into commutative + commutative_muls = OrderedSet() + for m in muls: + c, nc = m.args_cnc(cset=False) + if c: + c_mul = m.func(*c) + if nc: + if c_mul == 1: + new_obj = m.func(*nc) + else: + if isinstance(m, MatMul): + new_obj = m.func(c_mul, *nc, evaluate=False) + else: + new_obj = m.func(c_mul, m.func(*nc), evaluate=False) + opt_subs[m] = new_obj + if len(c) > 1: + commutative_muls.add(c_mul) + + match_common_args(Add, adds, opt_subs) + match_common_args(Mul, commutative_muls, opt_subs) + + return opt_subs + + +def tree_cse(exprs, symbols, opt_subs=None, order='canonical', ignore=()): + """Perform raw CSE on expression tree, taking opt_subs into account. + + Parameters + ========== + + exprs : list of SymPy expressions + The expressions to reduce. + symbols : infinite iterator yielding unique Symbols + The symbols used to label the common subexpressions which are pulled + out. + opt_subs : dictionary of expression substitutions + The expressions to be substituted before any CSE action is performed. + order : string, 'none' or 'canonical' + The order by which Mul and Add arguments are processed. For large + expressions where speed is a concern, use the setting order='none'. + ignore : iterable of Symbols + Substitutions containing any Symbol from ``ignore`` will be ignored. + """ + if opt_subs is None: + opt_subs = {} + + ## Find repeated sub-expressions + + to_eliminate = set() + + seen_subexp = set() + excluded_symbols = set() + + def _find_repeated(expr): + if not isinstance(expr, (Basic, Unevaluated)): + return + + if isinstance(expr, RootOf): + return + + if isinstance(expr, Basic) and ( + expr.is_Atom or + expr.is_Order or + isinstance(expr, (MatrixSymbol, MatrixElement))): + if expr.is_Symbol: + excluded_symbols.add(expr.name) + return + + if iterable(expr): + args = expr + + else: + if expr in seen_subexp: + for ign in ignore: + if ign in expr.free_symbols: + break + else: + to_eliminate.add(expr) + return + + seen_subexp.add(expr) + + if expr in opt_subs: + expr = opt_subs[expr] + + args = expr.args + + list(map(_find_repeated, args)) + + for e in exprs: + if isinstance(e, Basic): + _find_repeated(e) + + ## Rebuild tree + + # Remove symbols from the generator that conflict with names in the expressions. + symbols = (_ for _ in symbols if _.name not in excluded_symbols) + + replacements = [] + + subs = {} + + def _rebuild(expr): + if not isinstance(expr, (Basic, Unevaluated)): + return expr + + if not expr.args: + return expr + + if iterable(expr): + new_args = [_rebuild(arg) for arg in expr.args] + return expr.func(*new_args) + + if expr in subs: + return subs[expr] + + orig_expr = expr + if expr in opt_subs: + expr = opt_subs[expr] + + # If enabled, parse Muls and Adds arguments by order to ensure + # replacement order independent from hashes + if order != 'none': + if isinstance(expr, (Mul, MatMul)): + c, nc = expr.args_cnc() + if c == [1]: + args = nc + else: + args = list(ordered(c)) + nc + elif isinstance(expr, (Add, MatAdd)): + args = list(ordered(expr.args)) + else: + args = expr.args + else: + args = expr.args + + new_args = list(map(_rebuild, args)) + if isinstance(expr, Unevaluated) or new_args != args: + new_expr = expr.func(*new_args) + else: + new_expr = expr + + if orig_expr in to_eliminate: + try: + sym = next(symbols) + except StopIteration: + raise ValueError("Symbols iterator ran out of symbols.") + + if isinstance(orig_expr, MatrixExpr): + sym = MatrixSymbol(sym.name, orig_expr.rows, + orig_expr.cols) + + subs[orig_expr] = sym + replacements.append((sym, new_expr)) + return sym + + else: + return new_expr + + reduced_exprs = [] + for e in exprs: + if isinstance(e, Basic): + reduced_e = _rebuild(e) + else: + reduced_e = e + reduced_exprs.append(reduced_e) + return replacements, reduced_exprs + + +def cse(exprs, symbols=None, optimizations=None, postprocess=None, + order='canonical', ignore=(), list=True): + """ Perform common subexpression elimination on an expression. + + Parameters + ========== + + exprs : list of SymPy expressions, or a single SymPy expression + The expressions to reduce. + symbols : infinite iterator yielding unique Symbols + The symbols used to label the common subexpressions which are pulled + out. The ``numbered_symbols`` generator is useful. The default is a + stream of symbols of the form "x0", "x1", etc. This must be an + infinite iterator. + optimizations : list of (callable, callable) pairs + The (preprocessor, postprocessor) pairs of external optimization + functions. Optionally 'basic' can be passed for a set of predefined + basic optimizations. Such 'basic' optimizations were used by default + in old implementation, however they can be really slow on larger + expressions. Now, no pre or post optimizations are made by default. + postprocess : a function which accepts the two return values of cse and + returns the desired form of output from cse, e.g. if you want the + replacements reversed the function might be the following lambda: + lambda r, e: return reversed(r), e + order : string, 'none' or 'canonical' + The order by which Mul and Add arguments are processed. If set to + 'canonical', arguments will be canonically ordered. If set to 'none', + ordering will be faster but dependent on expressions hashes, thus + machine dependent and variable. For large expressions where speed is a + concern, use the setting order='none'. + ignore : iterable of Symbols + Substitutions containing any Symbol from ``ignore`` will be ignored. + list : bool, (default True) + Returns expression in list or else with same type as input (when False). + + Returns + ======= + + replacements : list of (Symbol, expression) pairs + All of the common subexpressions that were replaced. Subexpressions + earlier in this list might show up in subexpressions later in this + list. + reduced_exprs : list of SymPy expressions + The reduced expressions with all of the replacements above. + + Examples + ======== + + >>> from sympy import cse, SparseMatrix + >>> from sympy.abc import x, y, z, w + >>> cse(((w + x + y + z)*(w + y + z))/(w + x)**3) + ([(x0, y + z), (x1, w + x)], [(w + x0)*(x0 + x1)/x1**3]) + + + List of expressions with recursive substitutions: + + >>> m = SparseMatrix([x + y, x + y + z]) + >>> cse([(x+y)**2, x + y + z, y + z, x + z + y, m]) + ([(x0, x + y), (x1, x0 + z)], [x0**2, x1, y + z, x1, Matrix([ + [x0], + [x1]])]) + + Note: the type and mutability of input matrices is retained. + + >>> isinstance(_[1][-1], SparseMatrix) + True + + The user may disallow substitutions containing certain symbols: + + >>> cse([y**2*(x + 1), 3*y**2*(x + 1)], ignore=(y,)) + ([(x0, x + 1)], [x0*y**2, 3*x0*y**2]) + + The default return value for the reduced expression(s) is a list, even if there is only + one expression. The `list` flag preserves the type of the input in the output: + + >>> cse(x) + ([], [x]) + >>> cse(x, list=False) + ([], x) + """ + if not list: + return _cse_homogeneous(exprs, + symbols=symbols, optimizations=optimizations, + postprocess=postprocess, order=order, ignore=ignore) + + if isinstance(exprs, (int, float)): + exprs = sympify(exprs) + + # Handle the case if just one expression was passed. + if isinstance(exprs, (Basic, MatrixBase)): + exprs = [exprs] + + copy = exprs + temp = [] + for e in exprs: + if isinstance(e, (Matrix, ImmutableMatrix)): + temp.append(Tuple(*e.flat())) + elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)): + temp.append(Tuple(*e.todok().items())) + else: + temp.append(e) + exprs = temp + del temp + + if optimizations is None: + optimizations = [] + elif optimizations == 'basic': + optimizations = basic_optimizations + + # Preprocess the expressions to give us better optimization opportunities. + reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs] + + if symbols is None: + symbols = numbered_symbols(cls=Symbol) + else: + # In case we get passed an iterable with an __iter__ method instead of + # an actual iterator. + symbols = iter(symbols) + + # Find other optimization opportunities. + opt_subs = opt_cse(reduced_exprs, order) + + # Main CSE algorithm. + replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs, + order, ignore) + + # Postprocess the expressions to return the expressions to canonical form. + exprs = copy + replacements = [(sym, postprocess_for_cse(subtree, optimizations)) + for sym, subtree in replacements] + reduced_exprs = [postprocess_for_cse(e, optimizations) + for e in reduced_exprs] + + # Get the matrices back + for i, e in enumerate(exprs): + if isinstance(e, (Matrix, ImmutableMatrix)): + reduced_exprs[i] = Matrix(e.rows, e.cols, reduced_exprs[i]) + if isinstance(e, ImmutableMatrix): + reduced_exprs[i] = reduced_exprs[i].as_immutable() + elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)): + m = SparseMatrix(e.rows, e.cols, {}) + for k, v in reduced_exprs[i]: + m[k] = v + if isinstance(e, ImmutableSparseMatrix): + m = m.as_immutable() + reduced_exprs[i] = m + + if postprocess is None: + return replacements, reduced_exprs + + return postprocess(replacements, reduced_exprs) + + +def _cse_homogeneous(exprs, **kwargs): + """ + Same as ``cse`` but the ``reduced_exprs`` are returned + with the same type as ``exprs`` or a sympified version of the same. + + Parameters + ========== + + exprs : an Expr, iterable of Expr or dictionary with Expr values + the expressions in which repeated subexpressions will be identified + kwargs : additional arguments for the ``cse`` function + + Returns + ======= + + replacements : list of (Symbol, expression) pairs + All of the common subexpressions that were replaced. Subexpressions + earlier in this list might show up in subexpressions later in this + list. + reduced_exprs : list of SymPy expressions + The reduced expressions with all of the replacements above. + + Examples + ======== + + >>> from sympy.simplify.cse_main import cse + >>> from sympy import cos, Tuple, Matrix + >>> from sympy.abc import x + >>> output = lambda x: type(cse(x, list=False)[1]) + >>> output(1) + + >>> output('cos(x)') + + >>> output(cos(x)) + cos + >>> output(Tuple(1, x)) + + >>> output(Matrix([[1,0], [0,1]])) + + >>> output([1, x]) + + >>> output((1, x)) + + >>> output({1, x}) + + """ + if isinstance(exprs, str): + replacements, reduced_exprs = _cse_homogeneous( + sympify(exprs), **kwargs) + return replacements, repr(reduced_exprs) + if isinstance(exprs, (list, tuple, set)): + replacements, reduced_exprs = cse(exprs, **kwargs) + return replacements, type(exprs)(reduced_exprs) + if isinstance(exprs, dict): + keys = list(exprs.keys()) # In order to guarantee the order of the elements. + replacements, values = cse([exprs[k] for k in keys], **kwargs) + reduced_exprs = dict(zip(keys, values)) + return replacements, reduced_exprs + + try: + replacements, (reduced_exprs,) = cse(exprs, **kwargs) + except TypeError: # For example 'mpf' objects + return [], exprs + else: + return replacements, reduced_exprs diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/cse_opts.py b/.venv/lib/python3.13/site-packages/sympy/simplify/cse_opts.py new file mode 100644 index 0000000000000000000000000000000000000000..36a59857411de740ae47423442af88b118a3395d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/cse_opts.py @@ -0,0 +1,52 @@ +""" Optimizations of the expression tree representation for better CSE +opportunities. +""" +from sympy.core import Add, Basic, Mul +from sympy.core.singleton import S +from sympy.core.sorting import default_sort_key +from sympy.core.traversal import preorder_traversal + + +def sub_pre(e): + """ Replace y - x with -(x - y) if -1 can be extracted from y - x. + """ + # replacing Add, A, from which -1 can be extracted with -1*-A + adds = [a for a in e.atoms(Add) if a.could_extract_minus_sign()] + reps = {} + ignore = set() + for a in adds: + na = -a + if na.is_Mul: # e.g. MatExpr + ignore.add(a) + continue + reps[a] = Mul._from_args([S.NegativeOne, na]) + + e = e.xreplace(reps) + + # repeat again for persisting Adds but mark these with a leading 1, -1 + # e.g. y - x -> 1*-1*(x - y) + if isinstance(e, Basic): + negs = {} + for a in sorted(e.atoms(Add), key=default_sort_key): + if a in ignore: + continue + if a in reps: + negs[a] = reps[a] + elif a.could_extract_minus_sign(): + negs[a] = Mul._from_args([S.One, S.NegativeOne, -a]) + e = e.xreplace(negs) + return e + + +def sub_post(e): + """ Replace 1*-1*x with -x. + """ + replacements = [] + for node in preorder_traversal(e): + if isinstance(node, Mul) and \ + node.args[0] is S.One and node.args[1] is S.NegativeOne: + replacements.append((node, -Mul._from_args(node.args[2:]))) + for node, replacement in replacements: + e = e.xreplace({node: replacement}) + + return e diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/epathtools.py b/.venv/lib/python3.13/site-packages/sympy/simplify/epathtools.py new file mode 100644 index 0000000000000000000000000000000000000000..7be983ada63fd39d7d467acf9afd62b3a41a2d85 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/epathtools.py @@ -0,0 +1,352 @@ +"""Tools for manipulation of expressions using paths. """ + +from sympy.core import Basic + + +class EPath: + r""" + Manipulate expressions using paths. + + EPath grammar in EBNF notation:: + + literal ::= /[A-Za-z_][A-Za-z_0-9]*/ + number ::= /-?\d+/ + type ::= literal + attribute ::= literal "?" + all ::= "*" + slice ::= "[" number? (":" number? (":" number?)?)? "]" + range ::= all | slice + query ::= (type | attribute) ("|" (type | attribute))* + selector ::= range | query range? + path ::= "/" selector ("/" selector)* + + See the docstring of the epath() function. + + """ + + __slots__ = ("_path", "_epath") + + def __new__(cls, path): + """Construct new EPath. """ + if isinstance(path, EPath): + return path + + if not path: + raise ValueError("empty EPath") + + _path = path + + if path[0] == '/': + path = path[1:] + else: + raise NotImplementedError("non-root EPath") + + epath = [] + + for selector in path.split('/'): + selector = selector.strip() + + if not selector: + raise ValueError("empty selector") + + index = 0 + + for c in selector: + if c.isalnum() or c in ('_', '|', '?'): + index += 1 + else: + break + + attrs = [] + types = [] + + if index: + elements = selector[:index] + selector = selector[index:] + + for element in elements.split('|'): + element = element.strip() + + if not element: + raise ValueError("empty element") + + if element.endswith('?'): + attrs.append(element[:-1]) + else: + types.append(element) + + span = None + + if selector == '*': + pass + else: + if selector.startswith('['): + try: + i = selector.index(']') + except ValueError: + raise ValueError("expected ']', got EOL") + + _span, span = selector[1:i], [] + + if ':' not in _span: + span = int(_span) + else: + for elt in _span.split(':', 3): + if not elt: + span.append(None) + else: + span.append(int(elt)) + + span = slice(*span) + + selector = selector[i + 1:] + + if selector: + raise ValueError("trailing characters in selector") + + epath.append((attrs, types, span)) + + obj = object.__new__(cls) + + obj._path = _path + obj._epath = epath + + return obj + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self._path) + + def _get_ordered_args(self, expr): + """Sort ``expr.args`` using printing order. """ + if expr.is_Add: + return expr.as_ordered_terms() + elif expr.is_Mul: + return expr.as_ordered_factors() + else: + return expr.args + + def _hasattrs(self, expr, attrs) -> bool: + """Check if ``expr`` has any of ``attrs``. """ + return all(hasattr(expr, attr) for attr in attrs) + + def _hastypes(self, expr, types): + """Check if ``expr`` is any of ``types``. """ + _types = [ cls.__name__ for cls in expr.__class__.mro() ] + return bool(set(_types).intersection(types)) + + def _has(self, expr, attrs, types): + """Apply ``_hasattrs`` and ``_hastypes`` to ``expr``. """ + if not (attrs or types): + return True + + if attrs and self._hasattrs(expr, attrs): + return True + + if types and self._hastypes(expr, types): + return True + + return False + + def apply(self, expr, func, args=None, kwargs=None): + """ + Modify parts of an expression selected by a path. + + Examples + ======== + + >>> from sympy.simplify.epathtools import EPath + >>> from sympy import sin, cos, E + >>> from sympy.abc import x, y, z, t + + >>> path = EPath("/*/[0]/Symbol") + >>> expr = [((x, 1), 2), ((3, y), z)] + + >>> path.apply(expr, lambda expr: expr**2) + [((x**2, 1), 2), ((3, y**2), z)] + + >>> path = EPath("/*/*/Symbol") + >>> expr = t + sin(x + 1) + cos(x + y + E) + + >>> path.apply(expr, lambda expr: 2*expr) + t + sin(2*x + 1) + cos(2*x + 2*y + E) + + """ + def _apply(path, expr, func): + if not path: + return func(expr) + else: + selector, path = path[0], path[1:] + attrs, types, span = selector + + if isinstance(expr, Basic): + if not expr.is_Atom: + args, basic = self._get_ordered_args(expr), True + else: + return expr + elif hasattr(expr, '__iter__'): + args, basic = expr, False + else: + return expr + + args = list(args) + + if span is not None: + if isinstance(span, slice): + indices = range(*span.indices(len(args))) + else: + indices = [span] + else: + indices = range(len(args)) + + for i in indices: + try: + arg = args[i] + except IndexError: + continue + + if self._has(arg, attrs, types): + args[i] = _apply(path, arg, func) + + if basic: + return expr.func(*args) + else: + return expr.__class__(args) + + _args, _kwargs = args or (), kwargs or {} + _func = lambda expr: func(expr, *_args, **_kwargs) + + return _apply(self._epath, expr, _func) + + def select(self, expr): + """ + Retrieve parts of an expression selected by a path. + + Examples + ======== + + >>> from sympy.simplify.epathtools import EPath + >>> from sympy import sin, cos, E + >>> from sympy.abc import x, y, z, t + + >>> path = EPath("/*/[0]/Symbol") + >>> expr = [((x, 1), 2), ((3, y), z)] + + >>> path.select(expr) + [x, y] + + >>> path = EPath("/*/*/Symbol") + >>> expr = t + sin(x + 1) + cos(x + y + E) + + >>> path.select(expr) + [x, x, y] + + """ + result = [] + + def _select(path, expr): + if not path: + result.append(expr) + else: + selector, path = path[0], path[1:] + attrs, types, span = selector + + if isinstance(expr, Basic): + args = self._get_ordered_args(expr) + elif hasattr(expr, '__iter__'): + args = expr + else: + return + + if span is not None: + if isinstance(span, slice): + args = args[span] + else: + try: + args = [args[span]] + except IndexError: + return + + for arg in args: + if self._has(arg, attrs, types): + _select(path, arg) + + _select(self._epath, expr) + return result + + +def epath(path, expr=None, func=None, args=None, kwargs=None): + r""" + Manipulate parts of an expression selected by a path. + + Explanation + =========== + + This function allows to manipulate large nested expressions in single + line of code, utilizing techniques to those applied in XML processing + standards (e.g. XPath). + + If ``func`` is ``None``, :func:`epath` retrieves elements selected by + the ``path``. Otherwise it applies ``func`` to each matching element. + + Note that it is more efficient to create an EPath object and use the select + and apply methods of that object, since this will compile the path string + only once. This function should only be used as a convenient shortcut for + interactive use. + + This is the supported syntax: + + * select all: ``/*`` + Equivalent of ``for arg in args:``. + * select slice: ``/[0]`` or ``/[1:5]`` or ``/[1:5:2]`` + Supports standard Python's slice syntax. + * select by type: ``/list`` or ``/list|tuple`` + Emulates ``isinstance()``. + * select by attribute: ``/__iter__?`` + Emulates ``hasattr()``. + + Parameters + ========== + + path : str | EPath + A path as a string or a compiled EPath. + expr : Basic | iterable + An expression or a container of expressions. + func : callable (optional) + A callable that will be applied to matching parts. + args : tuple (optional) + Additional positional arguments to ``func``. + kwargs : dict (optional) + Additional keyword arguments to ``func``. + + Examples + ======== + + >>> from sympy.simplify.epathtools import epath + >>> from sympy import sin, cos, E + >>> from sympy.abc import x, y, z, t + + >>> path = "/*/[0]/Symbol" + >>> expr = [((x, 1), 2), ((3, y), z)] + + >>> epath(path, expr) + [x, y] + >>> epath(path, expr, lambda expr: expr**2) + [((x**2, 1), 2), ((3, y**2), z)] + + >>> path = "/*/*/Symbol" + >>> expr = t + sin(x + 1) + cos(x + y + E) + + >>> epath(path, expr) + [x, x, y] + >>> epath(path, expr, lambda expr: 2*expr) + t + sin(2*x + 1) + cos(2*x + 2*y + E) + + """ + _epath = EPath(path) + + if expr is None: + return _epath + if func is None: + return _epath.select(expr) + else: + return _epath.apply(expr, func, args, kwargs) diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/fu.py b/.venv/lib/python3.13/site-packages/sympy/simplify/fu.py new file mode 100644 index 0000000000000000000000000000000000000000..a26706edca98385df0009a8ee41476a17d36420c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/fu.py @@ -0,0 +1,2112 @@ +from collections import defaultdict + +from sympy.core.add import Add +from sympy.core.cache import cacheit +from sympy.core.expr import Expr +from sympy.core.exprtools import Factors, gcd_terms, factor_terms +from sympy.core.function import expand_mul +from sympy.core.mul import Mul +from sympy.core.numbers import pi, I +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.sorting import ordered +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify +from sympy.core.traversal import bottom_up +from sympy.functions.combinatorial.factorials import binomial +from sympy.functions.elementary.hyperbolic import ( + cosh, sinh, tanh, coth, sech, csch, HyperbolicFunction) +from sympy.functions.elementary.trigonometric import ( + cos, sin, tan, cot, sec, csc, sqrt, TrigonometricFunction) +from sympy.ntheory.factor_ import perfect_power +from sympy.polys.polytools import factor +from sympy.strategies.tree import greedy +from sympy.strategies.core import identity, debug + +from sympy import SYMPY_DEBUG + + +# ================== Fu-like tools =========================== + + +def TR0(rv): + """Simplification of rational polynomials, trying to simplify + the expression, e.g. combine things like 3*x + 2*x, etc.... + """ + # although it would be nice to use cancel, it doesn't work + # with noncommutatives + return rv.normal().factor().expand() + + +def TR1(rv): + """Replace sec, csc with 1/cos, 1/sin + + Examples + ======== + + >>> from sympy.simplify.fu import TR1, sec, csc + >>> from sympy.abc import x + >>> TR1(2*csc(x) + sec(x)) + 1/cos(x) + 2/sin(x) + """ + + def f(rv): + if isinstance(rv, sec): + a = rv.args[0] + return S.One/cos(a) + elif isinstance(rv, csc): + a = rv.args[0] + return S.One/sin(a) + return rv + + return bottom_up(rv, f) + + +def TR2(rv): + """Replace tan and cot with sin/cos and cos/sin + + Examples + ======== + + >>> from sympy.simplify.fu import TR2 + >>> from sympy.abc import x + >>> from sympy import tan, cot, sin, cos + >>> TR2(tan(x)) + sin(x)/cos(x) + >>> TR2(cot(x)) + cos(x)/sin(x) + >>> TR2(tan(tan(x) - sin(x)/cos(x))) + 0 + + """ + + def f(rv): + if isinstance(rv, tan): + a = rv.args[0] + return sin(a)/cos(a) + elif isinstance(rv, cot): + a = rv.args[0] + return cos(a)/sin(a) + return rv + + return bottom_up(rv, f) + + +def TR2i(rv, half=False): + """Converts ratios involving sin and cos as follows:: + sin(x)/cos(x) -> tan(x) + sin(x)/(cos(x) + 1) -> tan(x/2) if half=True + + Examples + ======== + + >>> from sympy.simplify.fu import TR2i + >>> from sympy.abc import x, a + >>> from sympy import sin, cos + >>> TR2i(sin(x)/cos(x)) + tan(x) + + Powers of the numerator and denominator are also recognized + + >>> TR2i(sin(x)**2/(cos(x) + 1)**2, half=True) + tan(x/2)**2 + + The transformation does not take place unless assumptions allow + (i.e. the base must be positive or the exponent must be an integer + for both numerator and denominator) + + >>> TR2i(sin(x)**a/(cos(x) + 1)**a) + sin(x)**a/(cos(x) + 1)**a + + """ + + def f(rv): + if not rv.is_Mul: + return rv + + n, d = rv.as_numer_denom() + if n.is_Atom or d.is_Atom: + return rv + + def ok(k, e): + # initial filtering of factors + return ( + (e.is_integer or k.is_positive) and ( + k.func in (sin, cos) or (half and + k.is_Add and + len(k.args) >= 2 and + any(any(isinstance(ai, cos) or ai.is_Pow and ai.base is cos + for ai in Mul.make_args(a)) for a in k.args)))) + + n = n.as_powers_dict() + ndone = [(k, n.pop(k)) for k in list(n.keys()) if not ok(k, n[k])] + if not n: + return rv + + d = d.as_powers_dict() + ddone = [(k, d.pop(k)) for k in list(d.keys()) if not ok(k, d[k])] + if not d: + return rv + + # factoring if necessary + + def factorize(d, ddone): + newk = [] + for k in d: + if k.is_Add and len(k.args) > 1: + knew = factor(k) if half else factor_terms(k) + if knew != k: + newk.append((k, knew)) + if newk: + for i, (k, knew) in enumerate(newk): + del d[k] + newk[i] = knew + newk = Mul(*newk).as_powers_dict() + for k in newk: + v = d[k] + newk[k] + if ok(k, v): + d[k] = v + else: + ddone.append((k, v)) + del newk + factorize(n, ndone) + factorize(d, ddone) + + # joining + t = [] + for k in n: + if isinstance(k, sin): + a = cos(k.args[0], evaluate=False) + if a in d and d[a] == n[k]: + t.append(tan(k.args[0])**n[k]) + n[k] = d[a] = None + elif half: + a1 = 1 + a + if a1 in d and d[a1] == n[k]: + t.append((tan(k.args[0]/2))**n[k]) + n[k] = d[a1] = None + elif isinstance(k, cos): + a = sin(k.args[0], evaluate=False) + if a in d and d[a] == n[k]: + t.append(tan(k.args[0])**-n[k]) + n[k] = d[a] = None + elif half and k.is_Add and k.args[0] is S.One and \ + isinstance(k.args[1], cos): + a = sin(k.args[1].args[0], evaluate=False) + if a in d and d[a] == n[k] and (d[a].is_integer or \ + a.is_positive): + t.append(tan(a.args[0]/2)**-n[k]) + n[k] = d[a] = None + + if t: + rv = Mul(*(t + [b**e for b, e in n.items() if e]))/\ + Mul(*[b**e for b, e in d.items() if e]) + rv *= Mul(*[b**e for b, e in ndone])/Mul(*[b**e for b, e in ddone]) + + return rv + + return bottom_up(rv, f) + + +def TR3(rv): + """Induced formula: example sin(-a) = -sin(a) + + Examples + ======== + + >>> from sympy.simplify.fu import TR3 + >>> from sympy.abc import x, y + >>> from sympy import pi + >>> from sympy import cos + >>> TR3(cos(y - x*(y - x))) + cos(x*(x - y) + y) + >>> cos(pi/2 + x) + -sin(x) + >>> cos(30*pi/2 + x) + -cos(x) + + """ + from sympy.simplify.simplify import signsimp + + # Negative argument (already automatic for funcs like sin(-x) -> -sin(x) + # but more complicated expressions can use it, too). Also, trig angles + # between pi/4 and pi/2 are not reduced to an angle between 0 and pi/4. + # The following are automatically handled: + # Argument of type: pi/2 +/- angle + # Argument of type: pi +/- angle + # Argument of type : 2k*pi +/- angle + + def f(rv): + if not isinstance(rv, TrigonometricFunction): + return rv + rv = rv.func(signsimp(rv.args[0])) + if not isinstance(rv, TrigonometricFunction): + return rv + if (rv.args[0] - S.Pi/4).is_positive is (S.Pi/2 - rv.args[0]).is_positive is True: + fmap = {cos: sin, sin: cos, tan: cot, cot: tan, sec: csc, csc: sec} + rv = fmap[type(rv)](S.Pi/2 - rv.args[0]) + return rv + + # touch numbers iside of trig functions to let them automatically update + rv = rv.replace( + lambda x: isinstance(x, TrigonometricFunction), + lambda x: x.replace( + lambda n: n.is_number and n.is_Mul, + lambda n: n.func(*n.args))) + + return bottom_up(rv, f) + + +def TR4(rv): + """Identify values of special angles. + + a= 0 pi/6 pi/4 pi/3 pi/2 + ---------------------------------------------------- + sin(a) 0 1/2 sqrt(2)/2 sqrt(3)/2 1 + cos(a) 1 sqrt(3)/2 sqrt(2)/2 1/2 0 + tan(a) 0 sqt(3)/3 1 sqrt(3) -- + + Examples + ======== + + >>> from sympy import pi + >>> from sympy import cos, sin, tan, cot + >>> for s in (0, pi/6, pi/4, pi/3, pi/2): + ... print('%s %s %s %s' % (cos(s), sin(s), tan(s), cot(s))) + ... + 1 0 0 zoo + sqrt(3)/2 1/2 sqrt(3)/3 sqrt(3) + sqrt(2)/2 sqrt(2)/2 1 1 + 1/2 sqrt(3)/2 sqrt(3) sqrt(3)/3 + 0 1 zoo 0 + """ + # special values at 0, pi/6, pi/4, pi/3, pi/2 already handled + return rv.replace( + lambda x: + isinstance(x, TrigonometricFunction) and + (r:=x.args[0]/pi).is_Rational and r.q in (1, 2, 3, 4, 6), + lambda x: + x.func(x.args[0].func(*x.args[0].args))) + + +def _TR56(rv, f, g, h, max, pow): + """Helper for TR5 and TR6 to replace f**2 with h(g**2) + + Options + ======= + + max : controls size of exponent that can appear on f + e.g. if max=4 then f**4 will be changed to h(g**2)**2. + pow : controls whether the exponent must be a perfect power of 2 + e.g. if pow=True (and max >= 6) then f**6 will not be changed + but f**8 will be changed to h(g**2)**4 + + >>> from sympy.simplify.fu import _TR56 as T + >>> from sympy.abc import x + >>> from sympy import sin, cos + >>> h = lambda x: 1 - x + >>> T(sin(x)**3, sin, cos, h, 4, False) + (1 - cos(x)**2)*sin(x) + >>> T(sin(x)**6, sin, cos, h, 6, False) + (1 - cos(x)**2)**3 + >>> T(sin(x)**6, sin, cos, h, 6, True) + sin(x)**6 + >>> T(sin(x)**8, sin, cos, h, 10, True) + (1 - cos(x)**2)**4 + """ + + def _f(rv): + # I'm not sure if this transformation should target all even powers + # or only those expressible as powers of 2. Also, should it only + # make the changes in powers that appear in sums -- making an isolated + # change is not going to allow a simplification as far as I can tell. + if not (rv.is_Pow and rv.base.func == f): + return rv + if not rv.exp.is_real: + return rv + + if (rv.exp < 0) == True: + return rv + if (rv.exp > max) == True: + return rv + if rv.exp == 1: + return rv + if rv.exp == 2: + return h(g(rv.base.args[0])**2) + else: + if rv.exp % 2 == 1: + e = rv.exp//2 + return f(rv.base.args[0])*h(g(rv.base.args[0])**2)**e + elif rv.exp == 4: + e = 2 + elif not pow: + if rv.exp % 2: + return rv + e = rv.exp//2 + else: + p = perfect_power(rv.exp) + if not p: + return rv + e = rv.exp//2 + return h(g(rv.base.args[0])**2)**e + + return bottom_up(rv, _f) + + +def TR5(rv, max=4, pow=False): + """Replacement of sin**2 with 1 - cos(x)**2. + + See _TR56 docstring for advanced use of ``max`` and ``pow``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR5 + >>> from sympy.abc import x + >>> from sympy import sin + >>> TR5(sin(x)**2) + 1 - cos(x)**2 + >>> TR5(sin(x)**-2) # unchanged + sin(x)**(-2) + >>> TR5(sin(x)**4) + (1 - cos(x)**2)**2 + """ + return _TR56(rv, sin, cos, lambda x: 1 - x, max=max, pow=pow) + + +def TR6(rv, max=4, pow=False): + """Replacement of cos**2 with 1 - sin(x)**2. + + See _TR56 docstring for advanced use of ``max`` and ``pow``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR6 + >>> from sympy.abc import x + >>> from sympy import cos + >>> TR6(cos(x)**2) + 1 - sin(x)**2 + >>> TR6(cos(x)**-2) #unchanged + cos(x)**(-2) + >>> TR6(cos(x)**4) + (1 - sin(x)**2)**2 + """ + return _TR56(rv, cos, sin, lambda x: 1 - x, max=max, pow=pow) + + +def TR7(rv): + """Lowering the degree of cos(x)**2. + + Examples + ======== + + >>> from sympy.simplify.fu import TR7 + >>> from sympy.abc import x + >>> from sympy import cos + >>> TR7(cos(x)**2) + cos(2*x)/2 + 1/2 + >>> TR7(cos(x)**2 + 1) + cos(2*x)/2 + 3/2 + + """ + + def f(rv): + if not (rv.is_Pow and rv.base.func == cos and rv.exp == 2): + return rv + return (1 + cos(2*rv.base.args[0]))/2 + + return bottom_up(rv, f) + + +def TR8(rv, first=True): + """Converting products of ``cos`` and/or ``sin`` to a sum or + difference of ``cos`` and or ``sin`` terms. + + Examples + ======== + + >>> from sympy.simplify.fu import TR8 + >>> from sympy import cos, sin + >>> TR8(cos(2)*cos(3)) + cos(5)/2 + cos(1)/2 + >>> TR8(cos(2)*sin(3)) + sin(5)/2 + sin(1)/2 + >>> TR8(sin(2)*sin(3)) + -cos(5)/2 + cos(1)/2 + """ + + def f(rv): + if not ( + rv.is_Mul or + rv.is_Pow and + rv.base.func in (cos, sin) and + (rv.exp.is_integer or rv.base.is_positive)): + return rv + + if first: + n, d = [expand_mul(i) for i in rv.as_numer_denom()] + newn = TR8(n, first=False) + newd = TR8(d, first=False) + if newn != n or newd != d: + rv = gcd_terms(newn/newd) + if rv.is_Mul and rv.args[0].is_Rational and \ + len(rv.args) == 2 and rv.args[1].is_Add: + rv = Mul(*rv.as_coeff_Mul()) + return rv + + args = {cos: [], sin: [], None: []} + for a in Mul.make_args(rv): + if a.func in (cos, sin): + args[type(a)].append(a.args[0]) + elif (a.is_Pow and a.exp.is_Integer and a.exp > 0 and \ + a.base.func in (cos, sin)): + # XXX this is ok but pathological expression could be handled + # more efficiently as in TRmorrie + args[type(a.base)].extend([a.base.args[0]]*a.exp) + else: + args[None].append(a) + c = args[cos] + s = args[sin] + if not (c and s or len(c) > 1 or len(s) > 1): + return rv + + args = args[None] + n = min(len(c), len(s)) + for i in range(n): + a1 = s.pop() + a2 = c.pop() + args.append((sin(a1 + a2) + sin(a1 - a2))/2) + while len(c) > 1: + a1 = c.pop() + a2 = c.pop() + args.append((cos(a1 + a2) + cos(a1 - a2))/2) + if c: + args.append(cos(c.pop())) + while len(s) > 1: + a1 = s.pop() + a2 = s.pop() + args.append((-cos(a1 + a2) + cos(a1 - a2))/2) + if s: + args.append(sin(s.pop())) + return TR8(expand_mul(Mul(*args))) + + return bottom_up(rv, f) + + +def TR9(rv): + """Sum of ``cos`` or ``sin`` terms as a product of ``cos`` or ``sin``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR9 + >>> from sympy import cos, sin + >>> TR9(cos(1) + cos(2)) + 2*cos(1/2)*cos(3/2) + >>> TR9(cos(1) + 2*sin(1) + 2*sin(2)) + cos(1) + 4*sin(3/2)*cos(1/2) + + If no change is made by TR9, no re-arrangement of the + expression will be made. For example, though factoring + of common term is attempted, if the factored expression + was not changed, the original expression will be returned: + + >>> TR9(cos(3) + cos(3)*cos(2)) + cos(3) + cos(2)*cos(3) + + """ + + def f(rv): + if not rv.is_Add: + return rv + + def do(rv, first=True): + # cos(a)+/-cos(b) can be combined into a product of cosines and + # sin(a)+/-sin(b) can be combined into a product of cosine and + # sine. + # + # If there are more than two args, the pairs which "work" will + # have a gcd extractable and the remaining two terms will have + # the above structure -- all pairs must be checked to find the + # ones that work. args that don't have a common set of symbols + # are skipped since this doesn't lead to a simpler formula and + # also has the arbitrariness of combining, for example, the x + # and y term instead of the y and z term in something like + # cos(x) + cos(y) + cos(z). + + if not rv.is_Add: + return rv + + args = list(ordered(rv.args)) + if len(args) != 2: + hit = False + for i in range(len(args)): + ai = args[i] + if ai is None: + continue + for j in range(i + 1, len(args)): + aj = args[j] + if aj is None: + continue + was = ai + aj + new = do(was) + if new != was: + args[i] = new # update in place + args[j] = None + hit = True + break # go to next i + if hit: + rv = Add(*[_f for _f in args if _f]) + if rv.is_Add: + rv = do(rv) + + return rv + + # two-arg Add + split = trig_split(*args) + if not split: + return rv + gcd, n1, n2, a, b, iscos = split + + # application of rule if possible + if iscos: + if n1 == n2: + return gcd*n1*2*cos((a + b)/2)*cos((a - b)/2) + if n1 < 0: + a, b = b, a + return -2*gcd*sin((a + b)/2)*sin((a - b)/2) + else: + if n1 == n2: + return gcd*n1*2*sin((a + b)/2)*cos((a - b)/2) + if n1 < 0: + a, b = b, a + return 2*gcd*cos((a + b)/2)*sin((a - b)/2) + + return process_common_addends(rv, do) # DON'T sift by free symbols + + return bottom_up(rv, f) + + +def TR10(rv, first=True): + """Separate sums in ``cos`` and ``sin``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR10 + >>> from sympy.abc import a, b, c + >>> from sympy import cos, sin + >>> TR10(cos(a + b)) + -sin(a)*sin(b) + cos(a)*cos(b) + >>> TR10(sin(a + b)) + sin(a)*cos(b) + sin(b)*cos(a) + >>> TR10(sin(a + b + c)) + (-sin(a)*sin(b) + cos(a)*cos(b))*sin(c) + \ + (sin(a)*cos(b) + sin(b)*cos(a))*cos(c) + """ + + def f(rv): + if rv.func not in (cos, sin): + return rv + + f = rv.func + arg = rv.args[0] + if arg.is_Add: + if first: + args = list(ordered(arg.args)) + else: + args = list(arg.args) + a = args.pop() + b = Add._from_args(args) + if b.is_Add: + if f == sin: + return sin(a)*TR10(cos(b), first=False) + \ + cos(a)*TR10(sin(b), first=False) + else: + return cos(a)*TR10(cos(b), first=False) - \ + sin(a)*TR10(sin(b), first=False) + else: + if f == sin: + return sin(a)*cos(b) + cos(a)*sin(b) + else: + return cos(a)*cos(b) - sin(a)*sin(b) + return rv + + return bottom_up(rv, f) + + +def TR10i(rv): + """Sum of products to function of sum. + + Examples + ======== + + >>> from sympy.simplify.fu import TR10i + >>> from sympy import cos, sin, sqrt + >>> from sympy.abc import x + + >>> TR10i(cos(1)*cos(3) + sin(1)*sin(3)) + cos(2) + >>> TR10i(cos(1)*sin(3) + sin(1)*cos(3) + cos(3)) + cos(3) + sin(4) + >>> TR10i(sqrt(2)*cos(x)*x + sqrt(6)*sin(x)*x) + 2*sqrt(2)*x*sin(x + pi/6) + + """ + def f(rv): + if not rv.is_Add: + return rv + + def do(rv, first=True): + # args which can be expressed as A*(cos(a)*cos(b)+/-sin(a)*sin(b)) + # or B*(cos(a)*sin(b)+/-cos(b)*sin(a)) can be combined into + # A*f(a+/-b) where f is either sin or cos. + # + # If there are more than two args, the pairs which "work" will have + # a gcd extractable and the remaining two terms will have the above + # structure -- all pairs must be checked to find the ones that + # work. + + if not rv.is_Add: + return rv + + args = list(ordered(rv.args)) + if len(args) != 2: + hit = False + for i in range(len(args)): + ai = args[i] + if ai is None: + continue + for j in range(i + 1, len(args)): + aj = args[j] + if aj is None: + continue + was = ai + aj + new = do(was) + if new != was: + args[i] = new # update in place + args[j] = None + hit = True + break # go to next i + if hit: + rv = Add(*[_f for _f in args if _f]) + if rv.is_Add: + rv = do(rv) + + return rv + + # two-arg Add + split = trig_split(*args, two=True) + if not split: + return rv + gcd, n1, n2, a, b, same = split + + # identify and get c1 to be cos then apply rule if possible + if same: # coscos, sinsin + gcd = n1*gcd + if n1 == n2: + return gcd*cos(a - b) + return gcd*cos(a + b) + else: #cossin, cossin + gcd = n1*gcd + if n1 == n2: + return gcd*sin(a + b) + return gcd*sin(b - a) + + rv = process_common_addends( + rv, do, lambda x: tuple(ordered(x.free_symbols))) + + # need to check for inducible pairs in ratio of sqrt(3):1 that + # appeared in different lists when sorting by coefficient + while rv.is_Add: + byrad = defaultdict(list) + for a in rv.args: + hit = 0 + if a.is_Mul: + for ai in a.args: + if ai.is_Pow and ai.exp is S.Half and \ + ai.base.is_Integer: + byrad[ai].append(a) + hit = 1 + break + if not hit: + byrad[S.One].append(a) + + # no need to check all pairs -- just check for the onees + # that have the right ratio + args = [] + for a in byrad: + for b in [_ROOT3()*a, _invROOT3()]: + if b in byrad: + for i in range(len(byrad[a])): + if byrad[a][i] is None: + continue + for j in range(len(byrad[b])): + if byrad[b][j] is None: + continue + was = Add(byrad[a][i] + byrad[b][j]) + new = do(was) + if new != was: + args.append(new) + byrad[a][i] = None + byrad[b][j] = None + break + if args: + rv = Add(*(args + [Add(*[_f for _f in v if _f]) + for v in byrad.values()])) + else: + rv = do(rv) # final pass to resolve any new inducible pairs + break + + return rv + + return bottom_up(rv, f) + + +def TR11(rv, base=None): + """Function of double angle to product. The ``base`` argument can be used + to indicate what is the un-doubled argument, e.g. if 3*pi/7 is the base + then cosine and sine functions with argument 6*pi/7 will be replaced. + + Examples + ======== + + >>> from sympy.simplify.fu import TR11 + >>> from sympy import cos, sin, pi + >>> from sympy.abc import x + >>> TR11(sin(2*x)) + 2*sin(x)*cos(x) + >>> TR11(cos(2*x)) + -sin(x)**2 + cos(x)**2 + >>> TR11(sin(4*x)) + 4*(-sin(x)**2 + cos(x)**2)*sin(x)*cos(x) + >>> TR11(sin(4*x/3)) + 4*(-sin(x/3)**2 + cos(x/3)**2)*sin(x/3)*cos(x/3) + + If the arguments are simply integers, no change is made + unless a base is provided: + + >>> TR11(cos(2)) + cos(2) + >>> TR11(cos(4), 2) + -sin(2)**2 + cos(2)**2 + + There is a subtle issue here in that autosimplification will convert + some higher angles to lower angles + + >>> cos(6*pi/7) + cos(3*pi/7) + -cos(pi/7) + cos(3*pi/7) + + The 6*pi/7 angle is now pi/7 but can be targeted with TR11 by supplying + the 3*pi/7 base: + + >>> TR11(_, 3*pi/7) + -sin(3*pi/7)**2 + cos(3*pi/7)**2 + cos(3*pi/7) + + """ + + def f(rv): + if rv.func not in (cos, sin): + return rv + + if base: + f = rv.func + t = f(base*2) + co = S.One + if t.is_Mul: + co, t = t.as_coeff_Mul() + if t.func not in (cos, sin): + return rv + if rv.args[0] == t.args[0]: + c = cos(base) + s = sin(base) + if f is cos: + return (c**2 - s**2)/co + else: + return 2*c*s/co + return rv + + elif not rv.args[0].is_Number: + # make a change if the leading coefficient's numerator is + # divisible by 2 + c, m = rv.args[0].as_coeff_Mul(rational=True) + if c.p % 2 == 0: + arg = c.p//2*m/c.q + c = TR11(cos(arg)) + s = TR11(sin(arg)) + if rv.func == sin: + rv = 2*s*c + else: + rv = c**2 - s**2 + return rv + + return bottom_up(rv, f) + + +def _TR11(rv): + """ + Helper for TR11 to find half-arguments for sin in factors of + num/den that appear in cos or sin factors in the den/num. + + Examples + ======== + + >>> from sympy.simplify.fu import TR11, _TR11 + >>> from sympy import cos, sin + >>> from sympy.abc import x + >>> TR11(sin(x/3)/(cos(x/6))) + sin(x/3)/cos(x/6) + >>> _TR11(sin(x/3)/(cos(x/6))) + 2*sin(x/6) + >>> TR11(sin(x/6)/(sin(x/3))) + sin(x/6)/sin(x/3) + >>> _TR11(sin(x/6)/(sin(x/3))) + 1/(2*cos(x/6)) + + """ + def f(rv): + if not isinstance(rv, Expr): + return rv + + def sincos_args(flat): + # find arguments of sin and cos that + # appears as bases in args of flat + # and have Integer exponents + args = defaultdict(set) + for fi in Mul.make_args(flat): + b, e = fi.as_base_exp() + if e.is_Integer and e > 0: + if b.func in (cos, sin): + args[type(b)].add(b.args[0]) + return args + num_args, den_args = map(sincos_args, rv.as_numer_denom()) + def handle_match(rv, num_args, den_args): + # for arg in sin args of num_args, look for arg/2 + # in den_args and pass this half-angle to TR11 + # for handling in rv + for narg in num_args[sin]: + half = narg/2 + if half in den_args[cos]: + func = cos + elif half in den_args[sin]: + func = sin + else: + continue + rv = TR11(rv, half) + den_args[func].remove(half) + return rv + # sin in num, sin or cos in den + rv = handle_match(rv, num_args, den_args) + # sin in den, sin or cos in num + rv = handle_match(rv, den_args, num_args) + return rv + + return bottom_up(rv, f) + + +def TR12(rv, first=True): + """Separate sums in ``tan``. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import tan + >>> from sympy.simplify.fu import TR12 + >>> TR12(tan(x + y)) + (tan(x) + tan(y))/(-tan(x)*tan(y) + 1) + """ + + def f(rv): + if not rv.func == tan: + return rv + + arg = rv.args[0] + if arg.is_Add: + if first: + args = list(ordered(arg.args)) + else: + args = list(arg.args) + a = args.pop() + b = Add._from_args(args) + if b.is_Add: + tb = TR12(tan(b), first=False) + else: + tb = tan(b) + return (tan(a) + tb)/(1 - tan(a)*tb) + return rv + + return bottom_up(rv, f) + + +def TR12i(rv): + """Combine tan arguments as + (tan(y) + tan(x))/(tan(x)*tan(y) - 1) -> -tan(x + y). + + Examples + ======== + + >>> from sympy.simplify.fu import TR12i + >>> from sympy import tan + >>> from sympy.abc import a, b, c + >>> ta, tb, tc = [tan(i) for i in (a, b, c)] + >>> TR12i((ta + tb)/(-ta*tb + 1)) + tan(a + b) + >>> TR12i((ta + tb)/(ta*tb - 1)) + -tan(a + b) + >>> TR12i((-ta - tb)/(ta*tb - 1)) + tan(a + b) + >>> eq = (ta + tb)/(-ta*tb + 1)**2*(-3*ta - 3*tc)/(2*(ta*tc - 1)) + >>> TR12i(eq.expand()) + -3*tan(a + b)*tan(a + c)/(2*(tan(a) + tan(b) - 1)) + """ + def f(rv): + if not (rv.is_Add or rv.is_Mul or rv.is_Pow): + return rv + + n, d = rv.as_numer_denom() + if not d.args or not n.args: + return rv + + dok = {} + + def ok(di): + m = as_f_sign_1(di) + if m: + g, f, s = m + if s is S.NegativeOne and f.is_Mul and len(f.args) == 2 and \ + all(isinstance(fi, tan) for fi in f.args): + return g, f + + d_args = list(Mul.make_args(d)) + for i, di in enumerate(d_args): + m = ok(di) + if m: + g, t = m + s = Add(*[_.args[0] for _ in t.args]) + dok[s] = S.One + d_args[i] = g + continue + if di.is_Add: + di = factor(di) + if di.is_Mul: + d_args.extend(di.args) + d_args[i] = S.One + elif di.is_Pow and (di.exp.is_integer or di.base.is_positive): + m = ok(di.base) + if m: + g, t = m + s = Add(*[_.args[0] for _ in t.args]) + dok[s] = di.exp + d_args[i] = g**di.exp + else: + di = factor(di) + if di.is_Mul: + d_args.extend(di.args) + d_args[i] = S.One + if not dok: + return rv + + def ok(ni): + if ni.is_Add and len(ni.args) == 2: + a, b = ni.args + if isinstance(a, tan) and isinstance(b, tan): + return a, b + n_args = list(Mul.make_args(factor_terms(n))) + hit = False + for i, ni in enumerate(n_args): + m = ok(ni) + if not m: + m = ok(-ni) + if m: + n_args[i] = S.NegativeOne + else: + if ni.is_Add: + ni = factor(ni) + if ni.is_Mul: + n_args.extend(ni.args) + n_args[i] = S.One + continue + elif ni.is_Pow and ( + ni.exp.is_integer or ni.base.is_positive): + m = ok(ni.base) + if m: + n_args[i] = S.One + else: + ni = factor(ni) + if ni.is_Mul: + n_args.extend(ni.args) + n_args[i] = S.One + continue + else: + continue + else: + n_args[i] = S.One + hit = True + s = Add(*[_.args[0] for _ in m]) + ed = dok[s] + newed = ed.extract_additively(S.One) + if newed is not None: + if newed: + dok[s] = newed + else: + dok.pop(s) + n_args[i] *= -tan(s) + + if hit: + rv = Mul(*n_args)/Mul(*d_args)/Mul(*[(Add(*[ + tan(a) for a in i.args]) - 1)**e for i, e in dok.items()]) + + return rv + + return bottom_up(rv, f) + + +def TR13(rv): + """Change products of ``tan`` or ``cot``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR13 + >>> from sympy import tan, cot + >>> TR13(tan(3)*tan(2)) + -tan(2)/tan(5) - tan(3)/tan(5) + 1 + >>> TR13(cot(3)*cot(2)) + cot(2)*cot(5) + 1 + cot(3)*cot(5) + """ + + def f(rv): + if not rv.is_Mul: + return rv + + # XXX handle products of powers? or let power-reducing handle it? + args = {tan: [], cot: [], None: []} + for a in Mul.make_args(rv): + if a.func in (tan, cot): + args[type(a)].append(a.args[0]) + else: + args[None].append(a) + t = args[tan] + c = args[cot] + if len(t) < 2 and len(c) < 2: + return rv + args = args[None] + while len(t) > 1: + t1 = t.pop() + t2 = t.pop() + args.append(1 - (tan(t1)/tan(t1 + t2) + tan(t2)/tan(t1 + t2))) + if t: + args.append(tan(t.pop())) + while len(c) > 1: + t1 = c.pop() + t2 = c.pop() + args.append(1 + cot(t1)*cot(t1 + t2) + cot(t2)*cot(t1 + t2)) + if c: + args.append(cot(c.pop())) + return Mul(*args) + + return bottom_up(rv, f) + + +def TRmorrie(rv): + """Returns cos(x)*cos(2*x)*...*cos(2**(k-1)*x) -> sin(2**k*x)/(2**k*sin(x)) + + Examples + ======== + + >>> from sympy.simplify.fu import TRmorrie, TR8, TR3 + >>> from sympy.abc import x + >>> from sympy import Mul, cos, pi + >>> TRmorrie(cos(x)*cos(2*x)) + sin(4*x)/(4*sin(x)) + >>> TRmorrie(7*Mul(*[cos(x) for x in range(10)])) + 7*sin(12)*sin(16)*cos(5)*cos(7)*cos(9)/(64*sin(1)*sin(3)) + + Sometimes autosimplification will cause a power to be + not recognized. e.g. in the following, cos(4*pi/7) automatically + simplifies to -cos(3*pi/7) so only 2 of the 3 terms are + recognized: + + >>> TRmorrie(cos(pi/7)*cos(2*pi/7)*cos(4*pi/7)) + -sin(3*pi/7)*cos(3*pi/7)/(4*sin(pi/7)) + + A touch by TR8 resolves the expression to a Rational + + >>> TR8(_) + -1/8 + + In this case, if eq is unsimplified, the answer is obtained + directly: + + >>> eq = cos(pi/9)*cos(2*pi/9)*cos(3*pi/9)*cos(4*pi/9) + >>> TRmorrie(eq) + 1/16 + + But if angles are made canonical with TR3 then the answer + is not simplified without further work: + + >>> TR3(eq) + sin(pi/18)*cos(pi/9)*cos(2*pi/9)/2 + >>> TRmorrie(_) + sin(pi/18)*sin(4*pi/9)/(8*sin(pi/9)) + >>> TR8(_) + cos(7*pi/18)/(16*sin(pi/9)) + >>> TR3(_) + 1/16 + + The original expression would have resolve to 1/16 directly with TR8, + however: + + >>> TR8(eq) + 1/16 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Morrie%27s_law + + """ + + def f(rv, first=True): + if not rv.is_Mul: + return rv + if first: + n, d = rv.as_numer_denom() + return f(n, 0)/f(d, 0) + + args = defaultdict(list) + coss = {} + other = [] + for c in rv.args: + b, e = c.as_base_exp() + if e.is_Integer and isinstance(b, cos): + co, a = b.args[0].as_coeff_Mul() + args[a].append(co) + coss[b] = e + else: + other.append(c) + + new = [] + for a in args: + c = args[a] + c.sort() + while c: + k = 0 + cc = ci = c[0] + while cc in c: + k += 1 + cc *= 2 + if k > 1: + newarg = sin(2**k*ci*a)/2**k/sin(ci*a) + # see how many times this can be taken + take = None + ccs = [] + for i in range(k): + cc /= 2 + key = cos(a*cc, evaluate=False) + ccs.append(cc) + take = min(coss[key], take or coss[key]) + # update exponent counts + for i in range(k): + cc = ccs.pop() + key = cos(a*cc, evaluate=False) + coss[key] -= take + if not coss[key]: + c.remove(cc) + new.append(newarg**take) + else: + b = cos(c.pop(0)*a) + other.append(b**coss[b]) + + if new: + rv = Mul(*(new + other + [ + cos(k*a, evaluate=False) for a in args for k in args[a]])) + + return rv + + return bottom_up(rv, f) + + +def TR14(rv, first=True): + """Convert factored powers of sin and cos identities into simpler + expressions. + + Examples + ======== + + >>> from sympy.simplify.fu import TR14 + >>> from sympy.abc import x, y + >>> from sympy import cos, sin + >>> TR14((cos(x) - 1)*(cos(x) + 1)) + -sin(x)**2 + >>> TR14((sin(x) - 1)*(sin(x) + 1)) + -cos(x)**2 + >>> p1 = (cos(x) + 1)*(cos(x) - 1) + >>> p2 = (cos(y) - 1)*2*(cos(y) + 1) + >>> p3 = (3*(cos(y) - 1))*(3*(cos(y) + 1)) + >>> TR14(p1*p2*p3*(x - 1)) + -18*(x - 1)*sin(x)**2*sin(y)**4 + + """ + + def f(rv): + if not rv.is_Mul: + return rv + + if first: + # sort them by location in numerator and denominator + # so the code below can just deal with positive exponents + n, d = rv.as_numer_denom() + if d is not S.One: + newn = TR14(n, first=False) + newd = TR14(d, first=False) + if newn != n or newd != d: + rv = newn/newd + return rv + + other = [] + process = [] + for a in rv.args: + if a.is_Pow: + b, e = a.as_base_exp() + if not (e.is_integer or b.is_positive): + other.append(a) + continue + a = b + else: + e = S.One + m = as_f_sign_1(a) + if not m or m[1].func not in (cos, sin): + if e is S.One: + other.append(a) + else: + other.append(a**e) + continue + g, f, si = m + process.append((g, e.is_Number, e, f, si, a)) + + # sort them to get like terms next to each other + process = list(ordered(process)) + + # keep track of whether there was any change + nother = len(other) + + # access keys + keys = (g, t, e, f, si, a) = list(range(6)) + + while process: + A = process.pop(0) + if process: + B = process[0] + + if A[e].is_Number and B[e].is_Number: + # both exponents are numbers + if A[f] == B[f]: + if A[si] != B[si]: + B = process.pop(0) + take = min(A[e], B[e]) + + # reinsert any remainder + # the B will likely sort after A so check it first + if B[e] != take: + rem = [B[i] for i in keys] + rem[e] -= take + process.insert(0, rem) + elif A[e] != take: + rem = [A[i] for i in keys] + rem[e] -= take + process.insert(0, rem) + + if isinstance(A[f], cos): + t = sin + else: + t = cos + other.append((-A[g]*B[g]*t(A[f].args[0])**2)**take) + continue + + elif A[e] == B[e]: + # both exponents are equal symbols + if A[f] == B[f]: + if A[si] != B[si]: + B = process.pop(0) + take = A[e] + if isinstance(A[f], cos): + t = sin + else: + t = cos + other.append((-A[g]*B[g]*t(A[f].args[0])**2)**take) + continue + + # either we are done or neither condition above applied + other.append(A[a]**A[e]) + + if len(other) != nother: + rv = Mul(*other) + + return rv + + return bottom_up(rv, f) + + +def TR15(rv, max=4, pow=False): + """Convert sin(x)**-2 to 1 + cot(x)**2. + + See _TR56 docstring for advanced use of ``max`` and ``pow``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR15 + >>> from sympy.abc import x + >>> from sympy import sin + >>> TR15(1 - 1/sin(x)**2) + -cot(x)**2 + + """ + + def f(rv): + if not (isinstance(rv, Pow) and isinstance(rv.base, sin)): + return rv + + e = rv.exp + if e % 2 == 1: + return TR15(rv.base**(e + 1))/rv.base + + ia = 1/rv + a = _TR56(ia, sin, cot, lambda x: 1 + x, max=max, pow=pow) + if a != ia: + rv = a + return rv + + return bottom_up(rv, f) + + +def TR16(rv, max=4, pow=False): + """Convert cos(x)**-2 to 1 + tan(x)**2. + + See _TR56 docstring for advanced use of ``max`` and ``pow``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR16 + >>> from sympy.abc import x + >>> from sympy import cos + >>> TR16(1 - 1/cos(x)**2) + -tan(x)**2 + + """ + + def f(rv): + if not (isinstance(rv, Pow) and isinstance(rv.base, cos)): + return rv + + e = rv.exp + if e % 2 == 1: + return TR15(rv.base**(e + 1))/rv.base + + ia = 1/rv + a = _TR56(ia, cos, tan, lambda x: 1 + x, max=max, pow=pow) + if a != ia: + rv = a + return rv + + return bottom_up(rv, f) + + +def TR111(rv): + """Convert f(x)**-i to g(x)**i where either ``i`` is an integer + or the base is positive and f, g are: tan, cot; sin, csc; or cos, sec. + + Examples + ======== + + >>> from sympy.simplify.fu import TR111 + >>> from sympy.abc import x + >>> from sympy import tan + >>> TR111(1 - 1/tan(x)**2) + 1 - cot(x)**2 + + """ + + def f(rv): + if not ( + isinstance(rv, Pow) and + (rv.base.is_positive or rv.exp.is_integer and rv.exp.is_negative)): + return rv + + if isinstance(rv.base, tan): + return cot(rv.base.args[0])**-rv.exp + elif isinstance(rv.base, sin): + return csc(rv.base.args[0])**-rv.exp + elif isinstance(rv.base, cos): + return sec(rv.base.args[0])**-rv.exp + return rv + + return bottom_up(rv, f) + + +def TR22(rv, max=4, pow=False): + """Convert tan(x)**2 to sec(x)**2 - 1 and cot(x)**2 to csc(x)**2 - 1. + + See _TR56 docstring for advanced use of ``max`` and ``pow``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR22 + >>> from sympy.abc import x + >>> from sympy import tan, cot + >>> TR22(1 + tan(x)**2) + sec(x)**2 + >>> TR22(1 + cot(x)**2) + csc(x)**2 + + """ + + def f(rv): + if not (isinstance(rv, Pow) and rv.base.func in (cot, tan)): + return rv + + rv = _TR56(rv, tan, sec, lambda x: x - 1, max=max, pow=pow) + rv = _TR56(rv, cot, csc, lambda x: x - 1, max=max, pow=pow) + return rv + + return bottom_up(rv, f) + + +def TRpower(rv): + """Convert sin(x)**n and cos(x)**n with positive n to sums. + + Examples + ======== + + >>> from sympy.simplify.fu import TRpower + >>> from sympy.abc import x + >>> from sympy import cos, sin + >>> TRpower(sin(x)**6) + -15*cos(2*x)/32 + 3*cos(4*x)/16 - cos(6*x)/32 + 5/16 + >>> TRpower(sin(x)**3*cos(2*x)**4) + (3*sin(x)/4 - sin(3*x)/4)*(cos(4*x)/2 + cos(8*x)/8 + 3/8) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/List_of_trigonometric_identities#Power-reduction_formulae + + """ + + def f(rv): + if not (isinstance(rv, Pow) and isinstance(rv.base, (sin, cos))): + return rv + b, n = rv.as_base_exp() + x = b.args[0] + if n.is_Integer and n.is_positive: + if n.is_odd and isinstance(b, cos): + rv = 2**(1-n)*Add(*[binomial(n, k)*cos((n - 2*k)*x) + for k in range((n + 1)/2)]) + elif n.is_odd and isinstance(b, sin): + rv = 2**(1-n)*S.NegativeOne**((n-1)/2)*Add(*[binomial(n, k)* + S.NegativeOne**k*sin((n - 2*k)*x) for k in range((n + 1)/2)]) + elif n.is_even and isinstance(b, cos): + rv = 2**(1-n)*Add(*[binomial(n, k)*cos((n - 2*k)*x) + for k in range(n/2)]) + elif n.is_even and isinstance(b, sin): + rv = 2**(1-n)*S.NegativeOne**(n/2)*Add(*[binomial(n, k)* + S.NegativeOne**k*cos((n - 2*k)*x) for k in range(n/2)]) + if n.is_even: + rv += 2**(-n)*binomial(n, n/2) + return rv + + return bottom_up(rv, f) + + +def L(rv): + """Return count of trigonometric functions in expression. + + Examples + ======== + + >>> from sympy.simplify.fu import L + >>> from sympy.abc import x + >>> from sympy import cos, sin + >>> L(cos(x)+sin(x)) + 2 + """ + return S(rv.count(TrigonometricFunction)) + + +# ============== end of basic Fu-like tools ===================== + +if SYMPY_DEBUG: + (TR0, TR1, TR2, TR3, TR4, TR5, TR6, TR7, TR8, TR9, TR10, TR11, TR12, TR13, + TR2i, TRmorrie, TR14, TR15, TR16, TR12i, TR111, TR22 + )= list(map(debug, + (TR0, TR1, TR2, TR3, TR4, TR5, TR6, TR7, TR8, TR9, TR10, TR11, TR12, TR13, + TR2i, TRmorrie, TR14, TR15, TR16, TR12i, TR111, TR22))) + + +# tuples are chains -- (f, g) -> lambda x: g(f(x)) +# lists are choices -- [f, g] -> lambda x: min(f(x), g(x), key=objective) + +CTR1 = [(TR5, TR0), (TR6, TR0), identity] + +CTR2 = (TR11, [(TR5, TR0), (TR6, TR0), TR0]) + +CTR3 = [(TRmorrie, TR8, TR0), (TRmorrie, TR8, TR10i, TR0), identity] + +CTR4 = [(TR4, TR10i), identity] + +RL1 = (TR4, TR3, TR4, TR12, TR4, TR13, TR4, TR0) + + +# XXX it's a little unclear how this one is to be implemented +# see Fu paper of reference, page 7. What is the Union symbol referring to? +# The diagram shows all these as one chain of transformations, but the +# text refers to them being applied independently. Also, a break +# if L starts to increase has not been implemented. +RL2 = [ + (TR4, TR3, TR10, TR4, TR3, TR11), + (TR5, TR7, TR11, TR4), + (CTR3, CTR1, TR9, CTR2, TR4, TR9, TR9, CTR4), + identity, + ] + + +def fu(rv, measure=lambda x: (L(x), x.count_ops())): + """Attempt to simplify expression by using transformation rules given + in the algorithm by Fu et al. + + :func:`fu` will try to minimize the objective function ``measure``. + By default this first minimizes the number of trig terms and then minimizes + the number of total operations. + + Examples + ======== + + >>> from sympy.simplify.fu import fu + >>> from sympy import cos, sin, tan, pi, S, sqrt + >>> from sympy.abc import x, y, a, b + + >>> fu(sin(50)**2 + cos(50)**2 + sin(pi/6)) + 3/2 + >>> fu(sqrt(6)*cos(x) + sqrt(2)*sin(x)) + 2*sqrt(2)*sin(x + pi/3) + + CTR1 example + + >>> eq = sin(x)**4 - cos(y)**2 + sin(y)**2 + 2*cos(x)**2 + >>> fu(eq) + cos(x)**4 - 2*cos(y)**2 + 2 + + CTR2 example + + >>> fu(S.Half - cos(2*x)/2) + sin(x)**2 + + CTR3 example + + >>> fu(sin(a)*(cos(b) - sin(b)) + cos(a)*(sin(b) + cos(b))) + sqrt(2)*sin(a + b + pi/4) + + CTR4 example + + >>> fu(sqrt(3)*cos(x)/2 + sin(x)/2) + sin(x + pi/3) + + Example 1 + + >>> fu(1-sin(2*x)**2/4-sin(y)**2-cos(x)**4) + -cos(x)**2 + cos(y)**2 + + Example 2 + + >>> fu(cos(4*pi/9)) + sin(pi/18) + >>> fu(cos(pi/9)*cos(2*pi/9)*cos(3*pi/9)*cos(4*pi/9)) + 1/16 + + Example 3 + + >>> fu(tan(7*pi/18)+tan(5*pi/18)-sqrt(3)*tan(5*pi/18)*tan(7*pi/18)) + -sqrt(3) + + Objective function example + + >>> fu(sin(x)/cos(x)) # default objective function + tan(x) + >>> fu(sin(x)/cos(x), measure=lambda x: -x.count_ops()) # maximize op count + sin(x)/cos(x) + + References + ========== + + .. [1] https://www.sciencedirect.com/science/article/pii/S0895717706001609 + """ + fRL1 = greedy(RL1, measure) + fRL2 = greedy(RL2, measure) + + was = rv + rv = sympify(rv) + if not isinstance(rv, Expr): + return rv.func(*[fu(a, measure=measure) for a in rv.args]) + rv = TR1(rv) + if rv.has(tan, cot): + rv1 = fRL1(rv) + if (measure(rv1) < measure(rv)): + rv = rv1 + if rv.has(tan, cot): + rv = TR2(rv) + if rv.has(sin, cos): + rv1 = fRL2(rv) + rv2 = TR8(TRmorrie(rv1)) + rv = min([was, rv, rv1, rv2], key=measure) + return min(TR2i(rv), rv, key=measure) + + +def process_common_addends(rv, do, key2=None, key1=True): + """Apply ``do`` to addends of ``rv`` that (if ``key1=True``) share at least + a common absolute value of their coefficient and the value of ``key2`` when + applied to the argument. If ``key1`` is False ``key2`` must be supplied and + will be the only key applied. + """ + + # collect by absolute value of coefficient and key2 + absc = defaultdict(list) + if key1: + for a in rv.args: + c, a = a.as_coeff_Mul() + if c < 0: + c = -c + a = -a # put the sign on `a` + absc[(c, key2(a) if key2 else 1)].append(a) + elif key2: + for a in rv.args: + absc[(S.One, key2(a))].append(a) + else: + raise ValueError('must have at least one key') + + args = [] + hit = False + for k in absc: + v = absc[k] + c, _ = k + if len(v) > 1: + e = Add(*v, evaluate=False) + new = do(e) + if new != e: + e = new + hit = True + args.append(c*e) + else: + args.append(c*v[0]) + if hit: + rv = Add(*args) + + return rv + + +fufuncs = ''' + TR0 TR1 TR2 TR3 TR4 TR5 TR6 TR7 TR8 TR9 TR10 TR10i TR11 + TR12 TR13 L TR2i TRmorrie TR12i + TR14 TR15 TR16 TR111 TR22'''.split() +FU = dict(list(zip(fufuncs, list(map(locals().get, fufuncs))))) + + +@cacheit +def _ROOT2(): + return sqrt(2) + + +@cacheit +def _ROOT3(): + return sqrt(3) + + +@cacheit +def _invROOT3(): + return 1/sqrt(3) + + +def trig_split(a, b, two=False): + """Return the gcd, s1, s2, a1, a2, bool where + + If two is False (default) then:: + a + b = gcd*(s1*f(a1) + s2*f(a2)) where f = cos if bool else sin + else: + if bool, a + b was +/- cos(a1)*cos(a2) +/- sin(a1)*sin(a2) and equals + n1*gcd*cos(a - b) if n1 == n2 else + n1*gcd*cos(a + b) + else a + b was +/- cos(a1)*sin(a2) +/- sin(a1)*cos(a2) and equals + n1*gcd*sin(a + b) if n1 = n2 else + n1*gcd*sin(b - a) + + Examples + ======== + + >>> from sympy.simplify.fu import trig_split + >>> from sympy.abc import x, y, z + >>> from sympy import cos, sin, sqrt + + >>> trig_split(cos(x), cos(y)) + (1, 1, 1, x, y, True) + >>> trig_split(2*cos(x), -2*cos(y)) + (2, 1, -1, x, y, True) + >>> trig_split(cos(x)*sin(y), cos(y)*sin(y)) + (sin(y), 1, 1, x, y, True) + + >>> trig_split(cos(x), -sqrt(3)*sin(x), two=True) + (2, 1, -1, x, pi/6, False) + >>> trig_split(cos(x), sin(x), two=True) + (sqrt(2), 1, 1, x, pi/4, False) + >>> trig_split(cos(x), -sin(x), two=True) + (sqrt(2), 1, -1, x, pi/4, False) + >>> trig_split(sqrt(2)*cos(x), -sqrt(6)*sin(x), two=True) + (2*sqrt(2), 1, -1, x, pi/6, False) + >>> trig_split(-sqrt(6)*cos(x), -sqrt(2)*sin(x), two=True) + (-2*sqrt(2), 1, 1, x, pi/3, False) + >>> trig_split(cos(x)/sqrt(6), sin(x)/sqrt(2), two=True) + (sqrt(6)/3, 1, 1, x, pi/6, False) + >>> trig_split(-sqrt(6)*cos(x)*sin(y), -sqrt(2)*sin(x)*sin(y), two=True) + (-2*sqrt(2)*sin(y), 1, 1, x, pi/3, False) + + >>> trig_split(cos(x), sin(x)) + >>> trig_split(cos(x), sin(z)) + >>> trig_split(2*cos(x), -sin(x)) + >>> trig_split(cos(x), -sqrt(3)*sin(x)) + >>> trig_split(cos(x)*cos(y), sin(x)*sin(z)) + >>> trig_split(cos(x)*cos(y), sin(x)*sin(y)) + >>> trig_split(-sqrt(6)*cos(x), sqrt(2)*sin(x)*sin(y), two=True) + """ + a, b = [Factors(i) for i in (a, b)] + ua, ub = a.normal(b) + gcd = a.gcd(b).as_expr() + n1 = n2 = 1 + if S.NegativeOne in ua.factors: + ua = ua.quo(S.NegativeOne) + n1 = -n1 + elif S.NegativeOne in ub.factors: + ub = ub.quo(S.NegativeOne) + n2 = -n2 + a, b = [i.as_expr() for i in (ua, ub)] + + def pow_cos_sin(a, two): + """Return ``a`` as a tuple (r, c, s) such that + ``a = (r or 1)*(c or 1)*(s or 1)``. + + Three arguments are returned (radical, c-factor, s-factor) as + long as the conditions set by ``two`` are met; otherwise None is + returned. If ``two`` is True there will be one or two non-None + values in the tuple: c and s or c and r or s and r or s or c with c + being a cosine function (if possible) else a sine, and s being a sine + function (if possible) else oosine. If ``two`` is False then there + will only be a c or s term in the tuple. + + ``two`` also require that either two cos and/or sin be present (with + the condition that if the functions are the same the arguments are + different or vice versa) or that a single cosine or a single sine + be present with an optional radical. + + If the above conditions dictated by ``two`` are not met then None + is returned. + """ + c = s = None + co = S.One + if a.is_Mul: + co, a = a.as_coeff_Mul() + if len(a.args) > 2 or not two: + return None + if a.is_Mul: + args = list(a.args) + else: + args = [a] + a = args.pop(0) + if isinstance(a, cos): + c = a + elif isinstance(a, sin): + s = a + elif a.is_Pow and a.exp is S.Half: # autoeval doesn't allow -1/2 + co *= a + else: + return None + if args: + b = args[0] + if isinstance(b, cos): + if c: + s = b + else: + c = b + elif isinstance(b, sin): + if s: + c = b + else: + s = b + elif b.is_Pow and b.exp is S.Half: + co *= b + else: + return None + return co if co is not S.One else None, c, s + elif isinstance(a, cos): + c = a + elif isinstance(a, sin): + s = a + if c is None and s is None: + return + co = co if co is not S.One else None + return co, c, s + + # get the parts + m = pow_cos_sin(a, two) + if m is None: + return + coa, ca, sa = m + m = pow_cos_sin(b, two) + if m is None: + return + cob, cb, sb = m + + # check them + if (not ca) and cb or ca and isinstance(ca, sin): + coa, ca, sa, cob, cb, sb = cob, cb, sb, coa, ca, sa + n1, n2 = n2, n1 + if not two: # need cos(x) and cos(y) or sin(x) and sin(y) + c = ca or sa + s = cb or sb + if not isinstance(c, s.func): + return None + return gcd, n1, n2, c.args[0], s.args[0], isinstance(c, cos) + else: + if not coa and not cob: + if (ca and cb and sa and sb): + if isinstance(ca, sa.func) is not isinstance(cb, sb.func): + return + args = {j.args for j in (ca, sa)} + if not all(i.args in args for i in (cb, sb)): + return + return gcd, n1, n2, ca.args[0], sa.args[0], isinstance(ca, sa.func) + if ca and sa or cb and sb or \ + two and (ca is None and sa is None or cb is None and sb is None): + return + c = ca or sa + s = cb or sb + if c.args != s.args: + return + if not coa: + coa = S.One + if not cob: + cob = S.One + if coa is cob: + gcd *= _ROOT2() + return gcd, n1, n2, c.args[0], pi/4, False + elif coa/cob == _ROOT3(): + gcd *= 2*cob + return gcd, n1, n2, c.args[0], pi/3, False + elif coa/cob == _invROOT3(): + gcd *= 2*coa + return gcd, n1, n2, c.args[0], pi/6, False + + +def as_f_sign_1(e): + """If ``e`` is a sum that can be written as ``g*(a + s)`` where + ``s`` is ``+/-1``, return ``g``, ``a``, and ``s`` where ``a`` does + not have a leading negative coefficient. + + Examples + ======== + + >>> from sympy.simplify.fu import as_f_sign_1 + >>> from sympy.abc import x + >>> as_f_sign_1(x + 1) + (1, x, 1) + >>> as_f_sign_1(x - 1) + (1, x, -1) + >>> as_f_sign_1(-x + 1) + (-1, x, -1) + >>> as_f_sign_1(-x - 1) + (-1, x, 1) + >>> as_f_sign_1(2*x + 2) + (2, x, 1) + """ + if not e.is_Add or len(e.args) != 2: + return + # exact match + a, b = e.args + if a in (S.NegativeOne, S.One): + g = S.One + if b.is_Mul and b.args[0].is_Number and b.args[0] < 0: + a, b = -a, -b + g = -g + return g, b, a + # gcd match + a, b = [Factors(i) for i in e.args] + ua, ub = a.normal(b) + gcd = a.gcd(b).as_expr() + if S.NegativeOne in ua.factors: + ua = ua.quo(S.NegativeOne) + n1 = -1 + n2 = 1 + elif S.NegativeOne in ub.factors: + ub = ub.quo(S.NegativeOne) + n1 = 1 + n2 = -1 + else: + n1 = n2 = 1 + a, b = [i.as_expr() for i in (ua, ub)] + if a is S.One: + a, b = b, a + n1, n2 = n2, n1 + if n1 == -1: + gcd = -gcd + n2 = -n2 + + if b is S.One: + return gcd, a, n2 + + +def _osborne(e, d): + """Replace all hyperbolic functions with trig functions using + the Osborne rule. + + Notes + ===== + + ``d`` is a dummy variable to prevent automatic evaluation + of trigonometric/hyperbolic functions. + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hyperbolic_function + """ + + def f(rv): + if not isinstance(rv, HyperbolicFunction): + return rv + a = rv.args[0] + a = a*d if not a.is_Add else Add._from_args([i*d for i in a.args]) + if isinstance(rv, sinh): + return I*sin(a) + elif isinstance(rv, cosh): + return cos(a) + elif isinstance(rv, tanh): + return I*tan(a) + elif isinstance(rv, coth): + return cot(a)/I + elif isinstance(rv, sech): + return sec(a) + elif isinstance(rv, csch): + return csc(a)/I + else: + raise NotImplementedError('unhandled %s' % rv.func) + + return bottom_up(e, f) + + +def _osbornei(e, d): + """Replace all trig functions with hyperbolic functions using + the Osborne rule. + + Notes + ===== + + ``d`` is a dummy variable to prevent automatic evaluation + of trigonometric/hyperbolic functions. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hyperbolic_function + """ + + def f(rv): + if not isinstance(rv, TrigonometricFunction): + return rv + const, x = rv.args[0].as_independent(d, as_Add=True) + a = x.xreplace({d: S.One}) + const*I + if isinstance(rv, sin): + return sinh(a)/I + elif isinstance(rv, cos): + return cosh(a) + elif isinstance(rv, tan): + return tanh(a)/I + elif isinstance(rv, cot): + return coth(a)*I + elif isinstance(rv, sec): + return sech(a) + elif isinstance(rv, csc): + return csch(a)*I + else: + raise NotImplementedError('unhandled %s' % rv.func) + + return bottom_up(e, f) + + +def hyper_as_trig(rv): + """Return an expression containing hyperbolic functions in terms + of trigonometric functions. Any trigonometric functions initially + present are replaced with Dummy symbols and the function to undo + the masking and the conversion back to hyperbolics is also returned. It + should always be true that:: + + t, f = hyper_as_trig(expr) + expr == f(t) + + Examples + ======== + + >>> from sympy.simplify.fu import hyper_as_trig, fu + >>> from sympy.abc import x + >>> from sympy import cosh, sinh + >>> eq = sinh(x)**2 + cosh(x)**2 + >>> t, f = hyper_as_trig(eq) + >>> f(fu(t)) + cosh(2*x) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hyperbolic_function + """ + from sympy.simplify.simplify import signsimp + from sympy.simplify.radsimp import collect + + # mask off trig functions + trigs = rv.atoms(TrigonometricFunction) + reps = [(t, Dummy()) for t in trigs] + masked = rv.xreplace(dict(reps)) + + # get inversion substitutions in place + reps = [(v, k) for k, v in reps] + + d = Dummy() + + return _osborne(masked, d), lambda x: collect(signsimp( + _osbornei(x, d).xreplace(dict(reps))), S.ImaginaryUnit) + + +def sincos_to_sum(expr): + """Convert products and powers of sin and cos to sums. + + Explanation + =========== + + Applied power reduction TRpower first, then expands products, and + converts products to sums with TR8. + + Examples + ======== + + >>> from sympy.simplify.fu import sincos_to_sum + >>> from sympy.abc import x + >>> from sympy import cos, sin + >>> sincos_to_sum(16*sin(x)**3*cos(2*x)**2) + 7*sin(x) - 5*sin(3*x) + 3*sin(5*x) - sin(7*x) + """ + + if not expr.has(cos, sin): + return expr + else: + return TR8(expand_mul(TRpower(expr))) diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/gammasimp.py b/.venv/lib/python3.13/site-packages/sympy/simplify/gammasimp.py new file mode 100644 index 0000000000000000000000000000000000000000..aec20c56eb60efb8e1aadfb5bff3d1ba1ab51869 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/gammasimp.py @@ -0,0 +1,493 @@ +from sympy.core import Function, S, Mul, Pow, Add +from sympy.core.sorting import ordered, default_sort_key +from sympy.core.function import expand_func +from sympy.core.symbol import Dummy +from sympy.functions import gamma, sqrt, sin +from sympy.polys import factor, cancel +from sympy.utilities.iterables import sift, uniq + + +def gammasimp(expr): + r""" + Simplify expressions with gamma functions. + + Explanation + =========== + + This function takes as input an expression containing gamma + functions or functions that can be rewritten in terms of gamma + functions and tries to minimize the number of those functions and + reduce the size of their arguments. + + The algorithm works by rewriting all gamma functions as expressions + involving rising factorials (Pochhammer symbols) and applies + recurrence relations and other transformations applicable to rising + factorials, to reduce their arguments, possibly letting the resulting + rising factorial to cancel. Rising factorials with the second argument + being an integer are expanded into polynomial forms and finally all + other rising factorial are rewritten in terms of gamma functions. + + Then the following two steps are performed. + + 1. Reduce the number of gammas by applying the reflection theorem + gamma(x)*gamma(1-x) == pi/sin(pi*x). + 2. Reduce the number of gammas by applying the multiplication theorem + gamma(x)*gamma(x+1/n)*...*gamma(x+(n-1)/n) == C*gamma(n*x). + + It then reduces the number of prefactors by absorbing them into gammas + where possible and expands gammas with rational argument. + + All transformation rules can be found (or were derived from) here: + + .. [1] https://functions.wolfram.com/GammaBetaErf/Pochhammer/17/01/02/ + .. [2] https://functions.wolfram.com/GammaBetaErf/Pochhammer/27/01/0005/ + + Examples + ======== + + >>> from sympy.simplify import gammasimp + >>> from sympy import gamma, Symbol + >>> from sympy.abc import x + >>> n = Symbol('n', integer = True) + + >>> gammasimp(gamma(x)/gamma(x - 3)) + (x - 3)*(x - 2)*(x - 1) + >>> gammasimp(gamma(n + 3)) + gamma(n + 3) + + """ + + expr = expr.rewrite(gamma) + + # compute_ST will be looking for Functions and we don't want + # it looking for non-gamma functions: issue 22606 + # so we mask free, non-gamma functions + f = expr.atoms(Function) + # take out gammas + gammas = {i for i in f if isinstance(i, gamma)} + if not gammas: + return expr # avoid side effects like factoring + f -= gammas + # keep only those without bound symbols + f = f & expr.as_dummy().atoms(Function) + if f: + dum, fun, simp = zip(*[ + (Dummy(), fi, fi.func(*[ + _gammasimp(a, as_comb=False) for a in fi.args])) + for fi in ordered(f)]) + d = expr.xreplace(dict(zip(fun, dum))) + return _gammasimp(d, as_comb=False).xreplace(dict(zip(dum, simp))) + + return _gammasimp(expr, as_comb=False) + + +def _gammasimp(expr, as_comb): + """ + Helper function for gammasimp and combsimp. + + Explanation + =========== + + Simplifies expressions written in terms of gamma function. If + as_comb is True, it tries to preserve integer arguments. See + docstring of gammasimp for more information. This was part of + combsimp() in combsimp.py. + """ + expr = expr.replace(gamma, + lambda n: _rf(1, (n - 1).expand())) + + if as_comb: + expr = expr.replace(_rf, + lambda a, b: gamma(b + 1)) + else: + expr = expr.replace(_rf, + lambda a, b: gamma(a + b)/gamma(a)) + + def rule_gamma(expr, level=0): + """ Simplify products of gamma functions further. """ + + if expr.is_Atom: + return expr + + def gamma_rat(x): + # helper to simplify ratios of gammas + was = x.count(gamma) + xx = x.replace(gamma, lambda n: _rf(1, (n - 1).expand() + ).replace(_rf, lambda a, b: gamma(a + b)/gamma(a))) + if xx.count(gamma) < was: + x = xx + return x + + def gamma_factor(x): + # return True if there is a gamma factor in shallow args + if isinstance(x, gamma): + return True + if x.is_Add or x.is_Mul: + return any(gamma_factor(xi) for xi in x.args) + if x.is_Pow and (x.exp.is_integer or x.base.is_positive): + return gamma_factor(x.base) + return False + + # recursion step + if level == 0: + expr = expr.func(*[rule_gamma(x, level + 1) for x in expr.args]) + level += 1 + + if not expr.is_Mul: + return expr + + # non-commutative step + if level == 1: + args, nc = expr.args_cnc() + if not args: + return expr + if nc: + return rule_gamma(Mul._from_args(args), level + 1)*Mul._from_args(nc) + level += 1 + + # pure gamma handling, not factor absorption + if level == 2: + T, F = sift(expr.args, gamma_factor, binary=True) + gamma_ind = Mul(*F) + d = Mul(*T) + + nd, dd = d.as_numer_denom() + for ipass in range(2): + args = list(ordered(Mul.make_args(nd))) + for i, ni in enumerate(args): + if ni.is_Add: + ni, dd = Add(*[ + rule_gamma(gamma_rat(a/dd), level + 1) for a in ni.args] + ).as_numer_denom() + args[i] = ni + if not dd.has(gamma): + break + nd = Mul(*args) + if ipass == 0 and not gamma_factor(nd): + break + nd, dd = dd, nd # now process in reversed order + expr = gamma_ind*nd/dd + if not (expr.is_Mul and (gamma_factor(dd) or gamma_factor(nd))): + return expr + level += 1 + + # iteration until constant + if level == 3: + while True: + was = expr + expr = rule_gamma(expr, 4) + if expr == was: + return expr + + numer_gammas = [] + denom_gammas = [] + numer_others = [] + denom_others = [] + def explicate(p): + if p is S.One: + return None, [] + b, e = p.as_base_exp() + if e.is_Integer: + if isinstance(b, gamma): + return True, [b.args[0]]*e + else: + return False, [b]*e + else: + return False, [p] + + newargs = list(ordered(expr.args)) + while newargs: + n, d = newargs.pop().as_numer_denom() + isg, l = explicate(n) + if isg: + numer_gammas.extend(l) + elif isg is False: + numer_others.extend(l) + isg, l = explicate(d) + if isg: + denom_gammas.extend(l) + elif isg is False: + denom_others.extend(l) + + # =========== level 2 work: pure gamma manipulation ========= + + if not as_comb: + # Try to reduce the number of gamma factors by applying the + # reflection formula gamma(x)*gamma(1-x) = pi/sin(pi*x) + for gammas, numer, denom in [( + numer_gammas, numer_others, denom_others), + (denom_gammas, denom_others, numer_others)]: + new = [] + while gammas: + g1 = gammas.pop() + if g1.is_integer: + new.append(g1) + continue + for i, g2 in enumerate(gammas): + n = g1 + g2 - 1 + if not n.is_Integer: + continue + numer.append(S.Pi) + denom.append(sin(S.Pi*g1)) + gammas.pop(i) + if n > 0: + numer.extend(1 - g1 + k for k in range(n)) + elif n < 0: + denom.extend(-g1 - k for k in range(-n)) + break + else: + new.append(g1) + # /!\ updating IN PLACE + gammas[:] = new + + # Try to reduce the number of gammas by using the duplication + # theorem to cancel an upper and lower: gamma(2*s)/gamma(s) = + # 2**(2*s + 1)/(4*sqrt(pi))*gamma(s + 1/2). Although this could + # be done with higher argument ratios like gamma(3*x)/gamma(x), + # this would not reduce the number of gammas as in this case. + for ng, dg, no, do in [(numer_gammas, denom_gammas, numer_others, + denom_others), + (denom_gammas, numer_gammas, denom_others, + numer_others)]: + + while True: + for x in ng: + for y in dg: + n = x - 2*y + if n.is_Integer: + break + else: + continue + break + else: + break + ng.remove(x) + dg.remove(y) + if n > 0: + no.extend(2*y + k for k in range(n)) + elif n < 0: + do.extend(2*y - 1 - k for k in range(-n)) + ng.append(y + S.Half) + no.append(2**(2*y - 1)) + do.append(sqrt(S.Pi)) + + # Try to reduce the number of gamma factors by applying the + # multiplication theorem (used when n gammas with args differing + # by 1/n mod 1 are encountered). + # + # run of 2 with args differing by 1/2 + # + # >>> gammasimp(gamma(x)*gamma(x+S.Half)) + # 2*sqrt(2)*2**(-2*x - 1/2)*sqrt(pi)*gamma(2*x) + # + # run of 3 args differing by 1/3 (mod 1) + # + # >>> gammasimp(gamma(x)*gamma(x+S(1)/3)*gamma(x+S(2)/3)) + # 6*3**(-3*x - 1/2)*pi*gamma(3*x) + # >>> gammasimp(gamma(x)*gamma(x+S(1)/3)*gamma(x+S(5)/3)) + # 2*3**(-3*x - 1/2)*pi*(3*x + 2)*gamma(3*x) + # + def _run(coeffs): + # find runs in coeffs such that the difference in terms (mod 1) + # of t1, t2, ..., tn is 1/n + u = list(uniq(coeffs)) + for i in range(len(u)): + dj = ([((u[j] - u[i]) % 1, j) for j in range(i + 1, len(u))]) + for one, j in dj: + if one.p == 1 and one.q != 1: + n = one.q + got = [i] + get = list(range(1, n)) + for d, j in dj: + m = n*d + if m.is_Integer and m in get: + get.remove(m) + got.append(j) + if not get: + break + else: + continue + for i, j in enumerate(got): + c = u[j] + coeffs.remove(c) + got[i] = c + return one.q, got[0], got[1:] + + def _mult_thm(gammas, numer, denom): + # pull off and analyze the leading coefficient from each gamma arg + # looking for runs in those Rationals + + # expr -> coeff + resid -> rats[resid] = coeff + rats = {} + for g in gammas: + c, resid = g.as_coeff_Add() + rats.setdefault(resid, []).append(c) + + # look for runs in Rationals for each resid + keys = sorted(rats, key=default_sort_key) + for resid in keys: + coeffs = sorted(rats[resid]) + new = [] + while True: + run = _run(coeffs) + if run is None: + break + + # process the sequence that was found: + # 1) convert all the gamma functions to have the right + # argument (could be off by an integer) + # 2) append the factors corresponding to the theorem + # 3) append the new gamma function + + n, ui, other = run + + # (1) + for u in other: + con = resid + u - 1 + for k in range(int(u - ui)): + numer.append(con - k) + + con = n*(resid + ui) # for (2) and (3) + + # (2) + numer.append((2*S.Pi)**(S(n - 1)/2)* + n**(S.Half - con)) + # (3) + new.append(con) + + # restore resid to coeffs + rats[resid] = [resid + c for c in coeffs] + new + + # rebuild the gamma arguments + g = [] + for resid in keys: + g += rats[resid] + # /!\ updating IN PLACE + gammas[:] = g + + for l, numer, denom in [(numer_gammas, numer_others, denom_others), + (denom_gammas, denom_others, numer_others)]: + _mult_thm(l, numer, denom) + + # =========== level >= 2 work: factor absorption ========= + + if level >= 2: + # Try to absorb factors into the gammas: x*gamma(x) -> gamma(x + 1) + # and gamma(x)/(x - 1) -> gamma(x - 1) + # This code (in particular repeated calls to find_fuzzy) can be very + # slow. + def find_fuzzy(l, x): + if not l: + return + S1, T1 = compute_ST(x) + for y in l: + S2, T2 = inv[y] + if T1 != T2 or (not S1.intersection(S2) and + (S1 != set() or S2 != set())): + continue + # XXX we want some simplification (e.g. cancel or + # simplify) but no matter what it's slow. + a = len(cancel(x/y).free_symbols) + b = len(x.free_symbols) + c = len(y.free_symbols) + # TODO is there a better heuristic? + if a == 0 and (b > 0 or c > 0): + return y + + # We thus try to avoid expensive calls by building the following + # "invariants": For every factor or gamma function argument + # - the set of free symbols S + # - the set of functional components T + # We will only try to absorb if T1==T2 and (S1 intersect S2 != emptyset + # or S1 == S2 == emptyset) + inv = {} + + def compute_ST(expr): + if expr in inv: + return inv[expr] + return (expr.free_symbols, expr.atoms(Function).union( + {e.exp for e in expr.atoms(Pow)})) + + def update_ST(expr): + inv[expr] = compute_ST(expr) + for expr in numer_gammas + denom_gammas + numer_others + denom_others: + update_ST(expr) + + for gammas, numer, denom in [( + numer_gammas, numer_others, denom_others), + (denom_gammas, denom_others, numer_others)]: + new = [] + while gammas: + g = gammas.pop() + cont = True + while cont: + cont = False + y = find_fuzzy(numer, g) + if y is not None: + numer.remove(y) + if y != g: + numer.append(y/g) + update_ST(y/g) + g += 1 + cont = True + y = find_fuzzy(denom, g - 1) + if y is not None: + denom.remove(y) + if y != g - 1: + numer.append((g - 1)/y) + update_ST((g - 1)/y) + g -= 1 + cont = True + new.append(g) + # /!\ updating IN PLACE + gammas[:] = new + + # =========== rebuild expr ================================== + + return Mul(*[gamma(g) for g in numer_gammas]) \ + / Mul(*[gamma(g) for g in denom_gammas]) \ + * Mul(*numer_others) / Mul(*denom_others) + + was = factor(expr) + # (for some reason we cannot use Basic.replace in this case) + expr = rule_gamma(was) + if expr != was: + expr = factor(expr) + + expr = expr.replace(gamma, + lambda n: expand_func(gamma(n)) if n.is_Rational else gamma(n)) + + return expr + + +class _rf(Function): + @classmethod + def eval(cls, a, b): + if b.is_Integer: + if not b: + return S.One + + n = int(b) + + if n > 0: + return Mul(*[a + i for i in range(n)]) + elif n < 0: + return 1/Mul(*[a - i for i in range(1, -n + 1)]) + else: + if b.is_Add: + c, _b = b.as_coeff_Add() + + if c.is_Integer: + if c > 0: + return _rf(a, _b)*_rf(a + _b, c) + elif c < 0: + return _rf(a, _b)/_rf(a + _b + c, -c) + + if a.is_Add: + c, _a = a.as_coeff_Add() + + if c.is_Integer: + if c > 0: + return _rf(_a, b)*_rf(_a + b, c)/_rf(_a, c) + elif c < 0: + return _rf(_a, b)*_rf(_a + c, -c)/_rf(_a + b + c, -c) diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/hyperexpand.py b/.venv/lib/python3.13/site-packages/sympy/simplify/hyperexpand.py new file mode 100644 index 0000000000000000000000000000000000000000..c070aa2e44b92794107b3e33df897813a54307b9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/hyperexpand.py @@ -0,0 +1,2494 @@ +""" +Expand Hypergeometric (and Meijer G) functions into named +special functions. + +The algorithm for doing this uses a collection of lookup tables of +hypergeometric functions, and various of their properties, to expand +many hypergeometric functions in terms of special functions. + +It is based on the following paper: + Kelly B. Roach. Meijer G Function Representations. + In: Proceedings of the 1997 International Symposium on Symbolic and + Algebraic Computation, pages 205-211, New York, 1997. ACM. + +It is described in great(er) detail in the Sphinx documentation. +""" +# SUMMARY OF EXTENSIONS FOR MEIJER G FUNCTIONS +# +# o z**rho G(ap, bq; z) = G(ap + rho, bq + rho; z) +# +# o denote z*d/dz by D +# +# o It is helpful to keep in mind that ap and bq play essentially symmetric +# roles: G(1/z) has slightly altered parameters, with ap and bq interchanged. +# +# o There are four shift operators: +# A_J = b_J - D, J = 1, ..., n +# B_J = 1 - a_j + D, J = 1, ..., m +# C_J = -b_J + D, J = m+1, ..., q +# D_J = a_J - 1 - D, J = n+1, ..., p +# +# A_J, C_J increment b_J +# B_J, D_J decrement a_J +# +# o The corresponding four inverse-shift operators are defined if there +# is no cancellation. Thus e.g. an index a_J (upper or lower) can be +# incremented if a_J != b_i for i = 1, ..., q. +# +# o Order reduction: if b_j - a_i is a non-negative integer, where +# j <= m and i > n, the corresponding quotient of gamma functions reduces +# to a polynomial. Hence the G function can be expressed using a G-function +# of lower order. +# Similarly if j > m and i <= n. +# +# Secondly, there are paired index theorems [Adamchik, The evaluation of +# integrals of Bessel functions via G-function identities]. Suppose there +# are three parameters a, b, c, where a is an a_i, i <= n, b is a b_j, +# j <= m and c is a denominator parameter (i.e. a_i, i > n or b_j, j > m). +# Suppose further all three differ by integers. +# Then the order can be reduced. +# TODO work this out in detail. +# +# o An index quadruple is called suitable if its order cannot be reduced. +# If there exists a sequence of shift operators transforming one index +# quadruple into another, we say one is reachable from the other. +# +# o Deciding if one index quadruple is reachable from another is tricky. For +# this reason, we use hand-built routines to match and instantiate formulas. +# +from collections import defaultdict +from itertools import product +from functools import reduce +from math import prod + +from sympy import SYMPY_DEBUG +from sympy.core import (S, Dummy, symbols, sympify, Tuple, expand, I, pi, Mul, + EulerGamma, oo, zoo, expand_func, Add, nan, Expr, Rational) +from sympy.core.mod import Mod +from sympy.core.sorting import default_sort_key +from sympy.functions import (exp, sqrt, root, log, lowergamma, cos, + besseli, gamma, uppergamma, expint, erf, sin, besselj, Ei, Ci, Si, Shi, + sinh, cosh, Chi, fresnels, fresnelc, polar_lift, exp_polar, floor, ceiling, + rf, factorial, lerchphi, Piecewise, re, elliptic_k, elliptic_e) +from sympy.functions.elementary.complexes import polarify, unpolarify +from sympy.functions.special.hyper import (hyper, HyperRep_atanh, + HyperRep_power1, HyperRep_power2, HyperRep_log1, HyperRep_asin1, + HyperRep_asin2, HyperRep_sqrts1, HyperRep_sqrts2, HyperRep_log2, + HyperRep_cosasin, HyperRep_sinasin, meijerg) +from sympy.matrices import Matrix, eye, zeros +from sympy.polys import apart, poly, Poly +from sympy.series import residue +from sympy.simplify.powsimp import powdenest +from sympy.utilities.iterables import sift + +# function to define "buckets" +def _mod1(x): + # TODO see if this can work as Mod(x, 1); this will require + # different handling of the "buckets" since these need to + # be sorted and that fails when there is a mixture of + # integers and expressions with parameters. With the current + # Mod behavior, Mod(k, 1) == Mod(1, 1) == 0 if k is an integer. + # Although the sorting can be done with Basic.compare, this may + # still require different handling of the sorted buckets. + if x.is_Number: + return Mod(x, 1) + c, x = x.as_coeff_Add() + return Mod(c, 1) + x + + +# leave add formulae at the top for easy reference +def add_formulae(formulae): + """ Create our knowledge base. """ + a, b, c, z = symbols('a b c, z', cls=Dummy) + + def add(ap, bq, res): + func = Hyper_Function(ap, bq) + formulae.append(Formula(func, z, res, (a, b, c))) + + def addb(ap, bq, B, C, M): + func = Hyper_Function(ap, bq) + formulae.append(Formula(func, z, None, (a, b, c), B, C, M)) + + # Luke, Y. L. (1969), The Special Functions and Their Approximations, + # Volume 1, section 6.2 + + # 0F0 + add((), (), exp(z)) + + # 1F0 + add((a, ), (), HyperRep_power1(-a, z)) + + # 2F1 + addb((a, a - S.Half), (2*a, ), + Matrix([HyperRep_power2(a, z), + HyperRep_power2(a + S.Half, z)/2]), + Matrix([[1, 0]]), + Matrix([[(a - S.Half)*z/(1 - z), (S.Half - a)*z/(1 - z)], + [a/(1 - z), a*(z - 2)/(1 - z)]])) + addb((1, 1), (2, ), + Matrix([HyperRep_log1(z), 1]), Matrix([[-1/z, 0]]), + Matrix([[0, z/(z - 1)], [0, 0]])) + addb((S.Half, 1), (S('3/2'), ), + Matrix([HyperRep_atanh(z), 1]), + Matrix([[1, 0]]), + Matrix([[Rational(-1, 2), 1/(1 - z)/2], [0, 0]])) + addb((S.Half, S.Half), (S('3/2'), ), + Matrix([HyperRep_asin1(z), HyperRep_power1(Rational(-1, 2), z)]), + Matrix([[1, 0]]), + Matrix([[Rational(-1, 2), S.Half], [0, z/(1 - z)/2]])) + addb((a, S.Half + a), (S.Half, ), + Matrix([HyperRep_sqrts1(-a, z), -HyperRep_sqrts2(-a - S.Half, z)]), + Matrix([[1, 0]]), + Matrix([[0, -a], + [z*(-2*a - 1)/2/(1 - z), S.Half - z*(-2*a - 1)/(1 - z)]])) + + # A. P. Prudnikov, Yu. A. Brychkov and O. I. Marichev (1990). + # Integrals and Series: More Special Functions, Vol. 3,. + # Gordon and Breach Science Publisher + addb([a, -a], [S.Half], + Matrix([HyperRep_cosasin(a, z), HyperRep_sinasin(a, z)]), + Matrix([[1, 0]]), + Matrix([[0, -a], [a*z/(1 - z), 1/(1 - z)/2]])) + addb([1, 1], [3*S.Half], + Matrix([HyperRep_asin2(z), 1]), Matrix([[1, 0]]), + Matrix([[(z - S.Half)/(1 - z), 1/(1 - z)/2], [0, 0]])) + + # Complete elliptic integrals K(z) and E(z), both a 2F1 function + addb([S.Half, S.Half], [S.One], + Matrix([elliptic_k(z), elliptic_e(z)]), + Matrix([[2/pi, 0]]), + Matrix([[Rational(-1, 2), -1/(2*z-2)], + [Rational(-1, 2), S.Half]])) + addb([Rational(-1, 2), S.Half], [S.One], + Matrix([elliptic_k(z), elliptic_e(z)]), + Matrix([[0, 2/pi]]), + Matrix([[Rational(-1, 2), -1/(2*z-2)], + [Rational(-1, 2), S.Half]])) + + # 3F2 + addb([Rational(-1, 2), 1, 1], [S.Half, 2], + Matrix([z*HyperRep_atanh(z), HyperRep_log1(z), 1]), + Matrix([[Rational(-2, 3), -S.One/(3*z), Rational(2, 3)]]), + Matrix([[S.Half, 0, z/(1 - z)/2], + [0, 0, z/(z - 1)], + [0, 0, 0]])) + # actually the formula for 3/2 is much nicer ... + addb([Rational(-1, 2), 1, 1], [2, 2], + Matrix([HyperRep_power1(S.Half, z), HyperRep_log2(z), 1]), + Matrix([[Rational(4, 9) - 16/(9*z), 4/(3*z), 16/(9*z)]]), + Matrix([[z/2/(z - 1), 0, 0], [1/(2*(z - 1)), 0, S.Half], [0, 0, 0]])) + + # 1F1 + addb([1], [b], Matrix([z**(1 - b) * exp(z) * lowergamma(b - 1, z), 1]), + Matrix([[b - 1, 0]]), Matrix([[1 - b + z, 1], [0, 0]])) + addb([a], [2*a], + Matrix([z**(S.Half - a)*exp(z/2)*besseli(a - S.Half, z/2) + * gamma(a + S.Half)/4**(S.Half - a), + z**(S.Half - a)*exp(z/2)*besseli(a + S.Half, z/2) + * gamma(a + S.Half)/4**(S.Half - a)]), + Matrix([[1, 0]]), + Matrix([[z/2, z/2], [z/2, (z/2 - 2*a)]])) + mz = polar_lift(-1)*z + addb([a], [a + 1], + Matrix([mz**(-a)*a*lowergamma(a, mz), a*exp(z)]), + Matrix([[1, 0]]), + Matrix([[-a, 1], [0, z]])) + # This one is redundant. + add([Rational(-1, 2)], [S.Half], exp(z) - sqrt(pi*z)*(-I)*erf(I*sqrt(z))) + + # Added to get nice results for Laplace transform of Fresnel functions + # https://functions.wolfram.com/07.22.03.6437.01 + # Basic rule + #add([1], [Rational(3, 4), Rational(5, 4)], + # sqrt(pi) * (cos(2*sqrt(polar_lift(-1)*z))*fresnelc(2*root(polar_lift(-1)*z,4)/sqrt(pi)) + + # sin(2*sqrt(polar_lift(-1)*z))*fresnels(2*root(polar_lift(-1)*z,4)/sqrt(pi))) + # / (2*root(polar_lift(-1)*z,4))) + # Manually tuned rule + addb([1], [Rational(3, 4), Rational(5, 4)], + Matrix([ sqrt(pi)*(I*sinh(2*sqrt(z))*fresnels(2*root(z, 4)*exp(I*pi/4)/sqrt(pi)) + + cosh(2*sqrt(z))*fresnelc(2*root(z, 4)*exp(I*pi/4)/sqrt(pi))) + * exp(-I*pi/4)/(2*root(z, 4)), + sqrt(pi)*root(z, 4)*(sinh(2*sqrt(z))*fresnelc(2*root(z, 4)*exp(I*pi/4)/sqrt(pi)) + + I*cosh(2*sqrt(z))*fresnels(2*root(z, 4)*exp(I*pi/4)/sqrt(pi))) + *exp(-I*pi/4)/2, + 1 ]), + Matrix([[1, 0, 0]]), + Matrix([[Rational(-1, 4), 1, Rational(1, 4)], + [ z, Rational(1, 4), 0], + [ 0, 0, 0]])) + + # 2F2 + addb([S.Half, a], [Rational(3, 2), a + 1], + Matrix([a/(2*a - 1)*(-I)*sqrt(pi/z)*erf(I*sqrt(z)), + a/(2*a - 1)*(polar_lift(-1)*z)**(-a)* + lowergamma(a, polar_lift(-1)*z), + a/(2*a - 1)*exp(z)]), + Matrix([[1, -1, 0]]), + Matrix([[Rational(-1, 2), 0, 1], [0, -a, 1], [0, 0, z]])) + # We make a "basis" of four functions instead of three, and give EulerGamma + # an extra slot (it could just be a coefficient to 1). The advantage is + # that this way Polys will not see multivariate polynomials (it treats + # EulerGamma as an indeterminate), which is *way* faster. + addb([1, 1], [2, 2], + Matrix([Ei(z) - log(z), exp(z), 1, EulerGamma]), + Matrix([[1/z, 0, 0, -1/z]]), + Matrix([[0, 1, -1, 0], [0, z, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])) + + # 0F1 + add((), (S.Half, ), cosh(2*sqrt(z))) + addb([], [b], + Matrix([gamma(b)*z**((1 - b)/2)*besseli(b - 1, 2*sqrt(z)), + gamma(b)*z**(1 - b/2)*besseli(b, 2*sqrt(z))]), + Matrix([[1, 0]]), Matrix([[0, 1], [z, (1 - b)]])) + + # 0F3 + x = 4*z**Rational(1, 4) + + def fp(a, z): + return besseli(a, x) + besselj(a, x) + + def fm(a, z): + return besseli(a, x) - besselj(a, x) + + # TODO branching + addb([], [S.Half, a, a + S.Half], + Matrix([fp(2*a - 1, z), fm(2*a, z)*z**Rational(1, 4), + fm(2*a - 1, z)*sqrt(z), fp(2*a, z)*z**Rational(3, 4)]) + * 2**(-2*a)*gamma(2*a)*z**((1 - 2*a)/4), + Matrix([[1, 0, 0, 0]]), + Matrix([[0, 1, 0, 0], + [0, S.Half - a, 1, 0], + [0, 0, S.Half, 1], + [z, 0, 0, 1 - a]])) + x = 2*(4*z)**Rational(1, 4)*exp_polar(I*pi/4) + addb([], [a, a + S.Half, 2*a], + (2*sqrt(polar_lift(-1)*z))**(1 - 2*a)*gamma(2*a)**2 * + Matrix([besselj(2*a - 1, x)*besseli(2*a - 1, x), + x*(besseli(2*a, x)*besselj(2*a - 1, x) + - besseli(2*a - 1, x)*besselj(2*a, x)), + x**2*besseli(2*a, x)*besselj(2*a, x), + x**3*(besseli(2*a, x)*besselj(2*a - 1, x) + + besseli(2*a - 1, x)*besselj(2*a, x))]), + Matrix([[1, 0, 0, 0]]), + Matrix([[0, Rational(1, 4), 0, 0], + [0, (1 - 2*a)/2, Rational(-1, 2), 0], + [0, 0, 1 - 2*a, Rational(1, 4)], + [-32*z, 0, 0, 1 - a]])) + + # 1F2 + addb([a], [a - S.Half, 2*a], + Matrix([z**(S.Half - a)*besseli(a - S.Half, sqrt(z))**2, + z**(1 - a)*besseli(a - S.Half, sqrt(z)) + *besseli(a - Rational(3, 2), sqrt(z)), + z**(Rational(3, 2) - a)*besseli(a - Rational(3, 2), sqrt(z))**2]), + Matrix([[-gamma(a + S.Half)**2/4**(S.Half - a), + 2*gamma(a - S.Half)*gamma(a + S.Half)/4**(1 - a), + 0]]), + Matrix([[1 - 2*a, 1, 0], [z/2, S.Half - a, S.Half], [0, z, 0]])) + addb([S.Half], [b, 2 - b], + pi*(1 - b)/sin(pi*b)* + Matrix([besseli(1 - b, sqrt(z))*besseli(b - 1, sqrt(z)), + sqrt(z)*(besseli(-b, sqrt(z))*besseli(b - 1, sqrt(z)) + + besseli(1 - b, sqrt(z))*besseli(b, sqrt(z))), + besseli(-b, sqrt(z))*besseli(b, sqrt(z))]), + Matrix([[1, 0, 0]]), + Matrix([[b - 1, S.Half, 0], + [z, 0, z], + [0, S.Half, -b]])) + addb([S.Half], [Rational(3, 2), Rational(3, 2)], + Matrix([Shi(2*sqrt(z))/2/sqrt(z), sinh(2*sqrt(z))/2/sqrt(z), + cosh(2*sqrt(z))]), + Matrix([[1, 0, 0]]), + Matrix([[Rational(-1, 2), S.Half, 0], [0, Rational(-1, 2), S.Half], [0, 2*z, 0]])) + + # FresnelS + # Basic rule + #add([Rational(3, 4)], [Rational(3, 2),Rational(7, 4)], 6*fresnels( exp(pi*I/4)*root(z,4)*2/sqrt(pi) ) / ( pi * (exp(pi*I/4)*root(z,4)*2/sqrt(pi))**3 ) ) + # Manually tuned rule + addb([Rational(3, 4)], [Rational(3, 2), Rational(7, 4)], + Matrix( + [ fresnels( + exp( + pi*I/4)*root( + z, 4)*2/sqrt( + pi) ) / ( + pi * (exp(pi*I/4)*root(z, 4)*2/sqrt(pi))**3 ), + sinh(2*sqrt(z))/sqrt(z), + cosh(2*sqrt(z)) ]), + Matrix([[6, 0, 0]]), + Matrix([[Rational(-3, 4), Rational(1, 16), 0], + [ 0, Rational(-1, 2), 1], + [ 0, z, 0]])) + + # FresnelC + # Basic rule + #add([Rational(1, 4)], [S.Half,Rational(5, 4)], fresnelc( exp(pi*I/4)*root(z,4)*2/sqrt(pi) ) / ( exp(pi*I/4)*root(z,4)*2/sqrt(pi) ) ) + # Manually tuned rule + addb([Rational(1, 4)], [S.Half, Rational(5, 4)], + Matrix( + [ sqrt( + pi)*exp( + -I*pi/4)*fresnelc( + 2*root(z, 4)*exp(I*pi/4)/sqrt(pi))/(2*root(z, 4)), + cosh(2*sqrt(z)), + sinh(2*sqrt(z))*sqrt(z) ]), + Matrix([[1, 0, 0]]), + Matrix([[Rational(-1, 4), Rational(1, 4), 0 ], + [ 0, 0, 1 ], + [ 0, z, S.Half]])) + + # 2F3 + # XXX with this five-parameter formula is pretty slow with the current + # Formula.find_instantiations (creates 2!*3!*3**(2+3) ~ 3000 + # instantiations ... But it's not too bad. + addb([a, a + S.Half], [2*a, b, 2*a - b + 1], + gamma(b)*gamma(2*a - b + 1) * (sqrt(z)/2)**(1 - 2*a) * + Matrix([besseli(b - 1, sqrt(z))*besseli(2*a - b, sqrt(z)), + sqrt(z)*besseli(b, sqrt(z))*besseli(2*a - b, sqrt(z)), + sqrt(z)*besseli(b - 1, sqrt(z))*besseli(2*a - b + 1, sqrt(z)), + besseli(b, sqrt(z))*besseli(2*a - b + 1, sqrt(z))]), + Matrix([[1, 0, 0, 0]]), + Matrix([[0, S.Half, S.Half, 0], + [z/2, 1 - b, 0, z/2], + [z/2, 0, b - 2*a, z/2], + [0, S.Half, S.Half, -2*a]])) + # (C/f above comment about eulergamma in the basis). + addb([1, 1], [2, 2, Rational(3, 2)], + Matrix([Chi(2*sqrt(z)) - log(2*sqrt(z)), + cosh(2*sqrt(z)), sqrt(z)*sinh(2*sqrt(z)), 1, EulerGamma]), + Matrix([[1/z, 0, 0, 0, -1/z]]), + Matrix([[0, S.Half, 0, Rational(-1, 2), 0], + [0, 0, 1, 0, 0], + [0, z, S.Half, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]])) + + # 3F3 + # This is rule: https://functions.wolfram.com/07.31.03.0134.01 + # Initial reason to add it was a nice solution for + # integrate(erf(a*z)/z**2, z) and same for erfc and erfi. + # Basic rule + # add([1, 1, a], [2, 2, a+1], (a/(z*(a-1)**2)) * + # (1 - (-z)**(1-a) * (gamma(a) - uppergamma(a,-z)) + # - (a-1) * (EulerGamma + uppergamma(0,-z) + log(-z)) + # - exp(z))) + # Manually tuned rule + addb([1, 1, a], [2, 2, a+1], + Matrix([a*(log(-z) + expint(1, -z) + EulerGamma)/(z*(a**2 - 2*a + 1)), + a*(-z)**(-a)*(gamma(a) - uppergamma(a, -z))/(a - 1)**2, + a*exp(z)/(a**2 - 2*a + 1), + a/(z*(a**2 - 2*a + 1))]), + Matrix([[1-a, 1, -1/z, 1]]), + Matrix([[-1,0,-1/z,1], + [0,-a,1,0], + [0,0,z,0], + [0,0,0,-1]])) + + +def add_meijerg_formulae(formulae): + a, b, c, z = list(map(Dummy, 'abcz')) + rho = Dummy('rho') + + def add(an, ap, bm, bq, B, C, M, matcher): + formulae.append(MeijerFormula(an, ap, bm, bq, z, [a, b, c, rho], + B, C, M, matcher)) + + def detect_uppergamma(func): + x = func.an[0] + y, z = func.bm + swapped = False + if not _mod1((x - y).simplify()): + swapped = True + (y, z) = (z, y) + if _mod1((x - z).simplify()) or x - z > 0: + return None + l = [y, x] + if swapped: + l = [x, y] + return {rho: y, a: x - y}, G_Function([x], [], l, []) + + add([a + rho], [], [rho, a + rho], [], + Matrix([gamma(1 - a)*z**rho*exp(z)*uppergamma(a, z), + gamma(1 - a)*z**(a + rho)]), + Matrix([[1, 0]]), + Matrix([[rho + z, -1], [0, a + rho]]), + detect_uppergamma) + + def detect_3113(func): + """https://functions.wolfram.com/07.34.03.0984.01""" + x = func.an[0] + u, v, w = func.bm + if _mod1((u - v).simplify()) == 0: + if _mod1((v - w).simplify()) == 0: + return + sig = (S.Half, S.Half, S.Zero) + x1, x2, y = u, v, w + else: + if _mod1((x - u).simplify()) == 0: + sig = (S.Half, S.Zero, S.Half) + x1, y, x2 = u, v, w + else: + sig = (S.Zero, S.Half, S.Half) + y, x1, x2 = u, v, w + + if (_mod1((x - x1).simplify()) != 0 or + _mod1((x - x2).simplify()) != 0 or + _mod1((x - y).simplify()) != S.Half or + x - x1 > 0 or x - x2 > 0): + return + + return {a: x}, G_Function([x], [], [x - S.Half + t for t in sig], []) + + s = sin(2*sqrt(z)) + c_ = cos(2*sqrt(z)) + S_ = Si(2*sqrt(z)) - pi/2 + C = Ci(2*sqrt(z)) + add([a], [], [a, a, a - S.Half], [], + Matrix([sqrt(pi)*z**(a - S.Half)*(c_*S_ - s*C), + sqrt(pi)*z**a*(s*S_ + c_*C), + sqrt(pi)*z**a]), + Matrix([[-2, 0, 0]]), + Matrix([[a - S.Half, -1, 0], [z, a, S.Half], [0, 0, a]]), + detect_3113) + + +def make_simp(z): + """ Create a function that simplifies rational functions in ``z``. """ + + def simp(expr): + """ Efficiently simplify the rational function ``expr``. """ + numer, denom = expr.as_numer_denom() + numer = numer.expand() + # denom = denom.expand() # is this needed? + c, numer, denom = poly(numer, z).cancel(poly(denom, z)) + return c * numer.as_expr() / denom.as_expr() + + return simp + + +def debug(*args): + if SYMPY_DEBUG: + for a in args: + print(a, end="") + print() + + +class Hyper_Function(Expr): + """ A generalized hypergeometric function. """ + + def __new__(cls, ap, bq): + obj = super().__new__(cls) + obj.ap = Tuple(*list(map(expand, ap))) + obj.bq = Tuple(*list(map(expand, bq))) + return obj + + @property + def args(self): + return (self.ap, self.bq) + + @property + def sizes(self): + return (len(self.ap), len(self.bq)) + + @property + def gamma(self): + """ + Number of upper parameters that are negative integers + + This is a transformation invariant. + """ + return sum(bool(x.is_integer and x.is_negative) for x in self.ap) + + def _hashable_content(self): + return super()._hashable_content() + (self.ap, + self.bq) + + def __call__(self, arg): + return hyper(self.ap, self.bq, arg) + + def build_invariants(self): + """ + Compute the invariant vector. + + Explanation + =========== + + The invariant vector is: + (gamma, ((s1, n1), ..., (sk, nk)), ((t1, m1), ..., (tr, mr))) + where gamma is the number of integer a < 0, + s1 < ... < sk + nl is the number of parameters a_i congruent to sl mod 1 + t1 < ... < tr + ml is the number of parameters b_i congruent to tl mod 1 + + If the index pair contains parameters, then this is not truly an + invariant, since the parameters cannot be sorted uniquely mod1. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import Hyper_Function + >>> from sympy import S + >>> ap = (S.Half, S.One/3, S(-1)/2, -2) + >>> bq = (1, 2) + + Here gamma = 1, + k = 3, s1 = 0, s2 = 1/3, s3 = 1/2 + n1 = 1, n2 = 1, n2 = 2 + r = 1, t1 = 0 + m1 = 2: + + >>> Hyper_Function(ap, bq).build_invariants() + (1, ((0, 1), (1/3, 1), (1/2, 2)), ((0, 2),)) + """ + abuckets, bbuckets = sift(self.ap, _mod1), sift(self.bq, _mod1) + + def tr(bucket): + bucket = list(bucket.items()) + if not any(isinstance(x[0], Mod) for x in bucket): + bucket.sort(key=lambda x: default_sort_key(x[0])) + bucket = tuple([(mod, len(values)) for mod, values in bucket if + values]) + return bucket + + return (self.gamma, tr(abuckets), tr(bbuckets)) + + def difficulty(self, func): + """ Estimate how many steps it takes to reach ``func`` from self. + Return -1 if impossible. """ + if self.gamma != func.gamma: + return -1 + oabuckets, obbuckets, abuckets, bbuckets = [sift(params, _mod1) for + params in (self.ap, self.bq, func.ap, func.bq)] + + diff = 0 + for bucket, obucket in [(abuckets, oabuckets), (bbuckets, obbuckets)]: + for mod in set(list(bucket.keys()) + list(obucket.keys())): + if (mod not in bucket) or (mod not in obucket) \ + or len(bucket[mod]) != len(obucket[mod]): + return -1 + l1 = list(bucket[mod]) + l2 = list(obucket[mod]) + l1.sort() + l2.sort() + for i, j in zip(l1, l2): + diff += abs(i - j) + + return diff + + def _is_suitable_origin(self): + """ + Decide if ``self`` is a suitable origin. + + Explanation + =========== + + A function is a suitable origin iff: + * none of the ai equals bj + n, with n a non-negative integer + * none of the ai is zero + * none of the bj is a non-positive integer + + Note that this gives meaningful results only when none of the indices + are symbolic. + + """ + for a in self.ap: + for b in self.bq: + if (a - b).is_integer and (a - b).is_negative is False: + return False + for a in self.ap: + if a == 0: + return False + for b in self.bq: + if b.is_integer and b.is_nonpositive: + return False + return True + + +class G_Function(Expr): + """ A Meijer G-function. """ + + def __new__(cls, an, ap, bm, bq): + obj = super().__new__(cls) + obj.an = Tuple(*list(map(expand, an))) + obj.ap = Tuple(*list(map(expand, ap))) + obj.bm = Tuple(*list(map(expand, bm))) + obj.bq = Tuple(*list(map(expand, bq))) + return obj + + @property + def args(self): + return (self.an, self.ap, self.bm, self.bq) + + def _hashable_content(self): + return super()._hashable_content() + self.args + + def __call__(self, z): + return meijerg(self.an, self.ap, self.bm, self.bq, z) + + def compute_buckets(self): + """ + Compute buckets for the fours sets of parameters. + + Explanation + =========== + + We guarantee that any two equal Mod objects returned are actually the + same, and that the buckets are sorted by real part (an and bq + descendending, bm and ap ascending). + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import G_Function + >>> from sympy.abc import y + >>> from sympy import S + + >>> a, b = [1, 3, 2, S(3)/2], [1 + y, y, 2, y + 3] + >>> G_Function(a, b, [2], [y]).compute_buckets() + ({0: [3, 2, 1], 1/2: [3/2]}, + {0: [2], y: [y, y + 1, y + 3]}, {0: [2]}, {y: [y]}) + + """ + dicts = pan, pap, pbm, pbq = [defaultdict(list) for i in range(4)] + for dic, lis in zip(dicts, (self.an, self.ap, self.bm, self.bq)): + for x in lis: + dic[_mod1(x)].append(x) + + for dic, flip in zip(dicts, (True, False, False, True)): + for m, items in dic.items(): + x0 = items[0] + items.sort(key=lambda x: x - x0, reverse=flip) + dic[m] = items + + return tuple([dict(w) for w in dicts]) + + @property + def signature(self): + return (len(self.an), len(self.ap), len(self.bm), len(self.bq)) + + +# Dummy variable. +_x = Dummy('x') + +class Formula: + """ + This class represents hypergeometric formulae. + + Explanation + =========== + + Its data members are: + - z, the argument + - closed_form, the closed form expression + - symbols, the free symbols (parameters) in the formula + - func, the function + - B, C, M (see _compute_basis) + + Examples + ======== + + >>> from sympy.abc import a, b, z + >>> from sympy.simplify.hyperexpand import Formula, Hyper_Function + >>> func = Hyper_Function((a/2, a/3 + b, (1+a)/2), (a, b, (a+b)/7)) + >>> f = Formula(func, z, None, [a, b]) + + """ + + def _compute_basis(self, closed_form): + """ + Compute a set of functions B=(f1, ..., fn), a nxn matrix M + and a 1xn matrix C such that: + closed_form = C B + z d/dz B = M B. + """ + afactors = [_x + a for a in self.func.ap] + bfactors = [_x + b - 1 for b in self.func.bq] + expr = _x*Mul(*bfactors) - self.z*Mul(*afactors) + poly = Poly(expr, _x) + + n = poly.degree() - 1 + b = [closed_form] + for _ in range(n): + b.append(self.z*b[-1].diff(self.z)) + + self.B = Matrix(b) + self.C = Matrix([[1] + [0]*n]) + + m = eye(n) + m = m.col_insert(0, zeros(n, 1)) + l = poly.all_coeffs()[1:] + l.reverse() + self.M = m.row_insert(n, -Matrix([l])/poly.all_coeffs()[0]) + + def __init__(self, func, z, res, symbols, B=None, C=None, M=None): + z = sympify(z) + res = sympify(res) + symbols = [x for x in sympify(symbols) if func.has(x)] + + self.z = z + self.symbols = symbols + self.B = B + self.C = C + self.M = M + self.func = func + + # TODO with symbolic parameters, it could be advantageous + # (for prettier answers) to compute a basis only *after* + # instantiation + if res is not None: + self._compute_basis(res) + + @property + def closed_form(self): + return reduce(lambda s,m: s+m[0]*m[1], zip(self.C, self.B), S.Zero) + + def find_instantiations(self, func): + """ + Find substitutions of the free symbols that match ``func``. + + Return the substitution dictionaries as a list. Note that the returned + instantiations need not actually match, or be valid! + + """ + from sympy.solvers import solve + ap = func.ap + bq = func.bq + if len(ap) != len(self.func.ap) or len(bq) != len(self.func.bq): + raise TypeError('Cannot instantiate other number of parameters') + symbol_values = [] + for a in self.symbols: + if a in self.func.ap.args: + symbol_values.append(ap) + elif a in self.func.bq.args: + symbol_values.append(bq) + else: + raise ValueError("At least one of the parameters of the " + "formula must be equal to %s" % (a,)) + base_repl = [dict(list(zip(self.symbols, values))) + for values in product(*symbol_values)] + abuckets, bbuckets = [sift(params, _mod1) for params in [ap, bq]] + a_inv, b_inv = [{a: len(vals) for a, vals in bucket.items()} + for bucket in [abuckets, bbuckets]] + critical_values = [[0] for _ in self.symbols] + result = [] + _n = Dummy() + for repl in base_repl: + symb_a, symb_b = [sift(params, lambda x: _mod1(x.xreplace(repl))) + for params in [self.func.ap, self.func.bq]] + for bucket, obucket in [(abuckets, symb_a), (bbuckets, symb_b)]: + for mod in set(list(bucket.keys()) + list(obucket.keys())): + if (mod not in bucket) or (mod not in obucket) \ + or len(bucket[mod]) != len(obucket[mod]): + break + for a, vals in zip(self.symbols, critical_values): + if repl[a].free_symbols: + continue + exprs = [expr for expr in obucket[mod] if expr.has(a)] + repl0 = repl.copy() + repl0[a] += _n + for expr in exprs: + for target in bucket[mod]: + n0, = solve(expr.xreplace(repl0) - target, _n) + if n0.free_symbols: + raise ValueError("Value should not be true") + vals.append(n0) + else: + values = [] + for a, vals in zip(self.symbols, critical_values): + a0 = repl[a] + min_ = floor(min(vals)) + max_ = ceiling(max(vals)) + values.append([a0 + n for n in range(min_, max_ + 1)]) + result.extend(dict(list(zip(self.symbols, l))) for l in product(*values)) + return result + + + + +class FormulaCollection: + """ A collection of formulae to use as origins. """ + + def __init__(self): + """ Doing this globally at module init time is a pain ... """ + self.symbolic_formulae = {} + self.concrete_formulae = {} + self.formulae = [] + + add_formulae(self.formulae) + + # Now process the formulae into a helpful form. + # These dicts are indexed by (p, q). + + for f in self.formulae: + sizes = f.func.sizes + if len(f.symbols) > 0: + self.symbolic_formulae.setdefault(sizes, []).append(f) + else: + inv = f.func.build_invariants() + self.concrete_formulae.setdefault(sizes, {})[inv] = f + + def lookup_origin(self, func): + """ + Given the suitable target ``func``, try to find an origin in our + knowledge base. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import (FormulaCollection, + ... Hyper_Function) + >>> f = FormulaCollection() + >>> f.lookup_origin(Hyper_Function((), ())).closed_form + exp(_z) + >>> f.lookup_origin(Hyper_Function([1], ())).closed_form + HyperRep_power1(-1, _z) + + >>> from sympy import S + >>> i = Hyper_Function([S('1/4'), S('3/4 + 4')], [S.Half]) + >>> f.lookup_origin(i).closed_form + HyperRep_sqrts1(-1/4, _z) + """ + inv = func.build_invariants() + sizes = func.sizes + if sizes in self.concrete_formulae and \ + inv in self.concrete_formulae[sizes]: + return self.concrete_formulae[sizes][inv] + + # We don't have a concrete formula. Try to instantiate. + if sizes not in self.symbolic_formulae: + return None # Too bad... + + possible = [] + for f in self.symbolic_formulae[sizes]: + repls = f.find_instantiations(func) + for repl in repls: + func2 = f.func.xreplace(repl) + if not func2._is_suitable_origin(): + continue + diff = func2.difficulty(func) + if diff == -1: + continue + possible.append((diff, repl, f, func2)) + + # find the nearest origin + possible.sort(key=lambda x: x[0]) + for _, repl, f, func2 in possible: + f2 = Formula(func2, f.z, None, [], f.B.subs(repl), + f.C.subs(repl), f.M.subs(repl)) + if not any(e.has(S.NaN, oo, -oo, zoo) for e in [f2.B, f2.M, f2.C]): + return f2 + + return None + + +class MeijerFormula: + """ + This class represents a Meijer G-function formula. + + Its data members are: + - z, the argument + - symbols, the free symbols (parameters) in the formula + - func, the function + - B, C, M (c/f ordinary Formula) + """ + + def __init__(self, an, ap, bm, bq, z, symbols, B, C, M, matcher): + an, ap, bm, bq = [Tuple(*list(map(expand, w))) for w in [an, ap, bm, bq]] + self.func = G_Function(an, ap, bm, bq) + self.z = z + self.symbols = symbols + self._matcher = matcher + self.B = B + self.C = C + self.M = M + + @property + def closed_form(self): + return reduce(lambda s,m: s+m[0]*m[1], zip(self.C, self.B), S.Zero) + + def try_instantiate(self, func): + """ + Try to instantiate the current formula to (almost) match func. + This uses the _matcher passed on init. + """ + if func.signature != self.func.signature: + return None + res = self._matcher(func) + if res is not None: + subs, newfunc = res + return MeijerFormula(newfunc.an, newfunc.ap, newfunc.bm, newfunc.bq, + self.z, [], + self.B.subs(subs), self.C.subs(subs), + self.M.subs(subs), None) + + +class MeijerFormulaCollection: + """ + This class holds a collection of meijer g formulae. + """ + + def __init__(self): + formulae = [] + add_meijerg_formulae(formulae) + self.formulae = defaultdict(list) + for formula in formulae: + self.formulae[formula.func.signature].append(formula) + self.formulae = dict(self.formulae) + + def lookup_origin(self, func): + """ Try to find a formula that matches func. """ + if func.signature not in self.formulae: + return None + for formula in self.formulae[func.signature]: + res = formula.try_instantiate(func) + if res is not None: + return res + + +class Operator: + """ + Base class for operators to be applied to our functions. + + Explanation + =========== + + These operators are differential operators. They are by convention + expressed in the variable D = z*d/dz (although this base class does + not actually care). + Note that when the operator is applied to an object, we typically do + *not* blindly differentiate but instead use a different representation + of the z*d/dz operator (see make_derivative_operator). + + To subclass from this, define a __init__ method that initializes a + self._poly variable. This variable stores a polynomial. By convention + the generator is z*d/dz, and acts to the right of all coefficients. + + Thus this poly + x**2 + 2*z*x + 1 + represents the differential operator + (z*d/dz)**2 + 2*z**2*d/dz. + + This class is used only in the implementation of the hypergeometric + function expansion algorithm. + """ + + def apply(self, obj, op): + """ + Apply ``self`` to the object ``obj``, where the generator is ``op``. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import Operator + >>> from sympy.polys.polytools import Poly + >>> from sympy.abc import x, y, z + >>> op = Operator() + >>> op._poly = Poly(x**2 + z*x + y, x) + >>> op.apply(z**7, lambda f: f.diff(z)) + y*z**7 + 7*z**7 + 42*z**5 + """ + coeffs = self._poly.all_coeffs() + coeffs.reverse() + diffs = [obj] + for c in coeffs[1:]: + diffs.append(op(diffs[-1])) + r = coeffs[0]*diffs[0] + for c, d in zip(coeffs[1:], diffs[1:]): + r += c*d + return r + + +class MultOperator(Operator): + """ Simply multiply by a "constant" """ + + def __init__(self, p): + self._poly = Poly(p, _x) + + +class ShiftA(Operator): + """ Increment an upper index. """ + + def __init__(self, ai): + ai = sympify(ai) + if ai == 0: + raise ValueError('Cannot increment zero upper index.') + self._poly = Poly(_x/ai + 1, _x) + + def __str__(self): + return '' % (1/self._poly.all_coeffs()[0]) + + +class ShiftB(Operator): + """ Decrement a lower index. """ + + def __init__(self, bi): + bi = sympify(bi) + if bi == 1: + raise ValueError('Cannot decrement unit lower index.') + self._poly = Poly(_x/(bi - 1) + 1, _x) + + def __str__(self): + return '' % (1/self._poly.all_coeffs()[0] + 1) + + +class UnShiftA(Operator): + """ Decrement an upper index. """ + + def __init__(self, ap, bq, i, z): + """ Note: i counts from zero! """ + ap, bq, i = list(map(sympify, [ap, bq, i])) + + self._ap = ap + self._bq = bq + self._i = i + + ap = list(ap) + bq = list(bq) + ai = ap.pop(i) - 1 + + if ai == 0: + raise ValueError('Cannot decrement unit upper index.') + + m = Poly(z*ai, _x) + for a in ap: + m *= Poly(_x + a, _x) + + A = Dummy('A') + n = D = Poly(ai*A - ai, A) + for b in bq: + n *= D + (b - 1).as_poly(A) + + b0 = -n.nth(0) + if b0 == 0: + raise ValueError('Cannot decrement upper index: ' + 'cancels with lower') + + n = Poly(Poly(n.all_coeffs()[:-1], A).as_expr().subs(A, _x/ai + 1), _x) + + self._poly = Poly((n - m)/b0, _x) + + def __str__(self): + return '' % (self._i, + self._ap, self._bq) + + +class UnShiftB(Operator): + """ Increment a lower index. """ + + def __init__(self, ap, bq, i, z): + """ Note: i counts from zero! """ + ap, bq, i = list(map(sympify, [ap, bq, i])) + + self._ap = ap + self._bq = bq + self._i = i + + ap = list(ap) + bq = list(bq) + bi = bq.pop(i) + 1 + + if bi == 0: + raise ValueError('Cannot increment -1 lower index.') + + m = Poly(_x*(bi - 1), _x) + for b in bq: + m *= Poly(_x + b - 1, _x) + + B = Dummy('B') + D = Poly((bi - 1)*B - bi + 1, B) + n = Poly(z, B) + for a in ap: + n *= (D + a.as_poly(B)) + + b0 = n.nth(0) + if b0 == 0: + raise ValueError('Cannot increment index: cancels with upper') + + n = Poly(Poly(n.all_coeffs()[:-1], B).as_expr().subs( + B, _x/(bi - 1) + 1), _x) + + self._poly = Poly((m - n)/b0, _x) + + def __str__(self): + return '' % (self._i, + self._ap, self._bq) + + +class MeijerShiftA(Operator): + """ Increment an upper b index. """ + + def __init__(self, bi): + bi = sympify(bi) + self._poly = Poly(bi - _x, _x) + + def __str__(self): + return '' % (self._poly.all_coeffs()[1]) + + +class MeijerShiftB(Operator): + """ Decrement an upper a index. """ + + def __init__(self, bi): + bi = sympify(bi) + self._poly = Poly(1 - bi + _x, _x) + + def __str__(self): + return '' % (1 - self._poly.all_coeffs()[1]) + + +class MeijerShiftC(Operator): + """ Increment a lower b index. """ + + def __init__(self, bi): + bi = sympify(bi) + self._poly = Poly(-bi + _x, _x) + + def __str__(self): + return '' % (-self._poly.all_coeffs()[1]) + + +class MeijerShiftD(Operator): + """ Decrement a lower a index. """ + + def __init__(self, bi): + bi = sympify(bi) + self._poly = Poly(bi - 1 - _x, _x) + + def __str__(self): + return '' % (self._poly.all_coeffs()[1] + 1) + + +class MeijerUnShiftA(Operator): + """ Decrement an upper b index. """ + + def __init__(self, an, ap, bm, bq, i, z): + """ Note: i counts from zero! """ + an, ap, bm, bq, i = list(map(sympify, [an, ap, bm, bq, i])) + + self._an = an + self._ap = ap + self._bm = bm + self._bq = bq + self._i = i + + an = list(an) + ap = list(ap) + bm = list(bm) + bq = list(bq) + bi = bm.pop(i) - 1 + + m = Poly(1, _x) * prod(Poly(b - _x, _x) for b in bm) * prod(Poly(_x - b, _x) for b in bq) + + A = Dummy('A') + D = Poly(bi - A, A) + n = Poly(z, A) * prod((D + 1 - a) for a in an) * prod((-D + a - 1) for a in ap) + + b0 = n.nth(0) + if b0 == 0: + raise ValueError('Cannot decrement upper b index (cancels)') + + n = Poly(Poly(n.all_coeffs()[:-1], A).as_expr().subs(A, bi - _x), _x) + + self._poly = Poly((m - n)/b0, _x) + + def __str__(self): + return '' % (self._i, + self._an, self._ap, self._bm, self._bq) + + +class MeijerUnShiftB(Operator): + """ Increment an upper a index. """ + + def __init__(self, an, ap, bm, bq, i, z): + """ Note: i counts from zero! """ + an, ap, bm, bq, i = list(map(sympify, [an, ap, bm, bq, i])) + + self._an = an + self._ap = ap + self._bm = bm + self._bq = bq + self._i = i + + an = list(an) + ap = list(ap) + bm = list(bm) + bq = list(bq) + ai = an.pop(i) + 1 + + m = Poly(z, _x) + for a in an: + m *= Poly(1 - a + _x, _x) + for a in ap: + m *= Poly(a - 1 - _x, _x) + + B = Dummy('B') + D = Poly(B + ai - 1, B) + n = Poly(1, B) + for b in bm: + n *= (-D + b) + for b in bq: + n *= (D - b) + + b0 = n.nth(0) + if b0 == 0: + raise ValueError('Cannot increment upper a index (cancels)') + + n = Poly(Poly(n.all_coeffs()[:-1], B).as_expr().subs( + B, 1 - ai + _x), _x) + + self._poly = Poly((m - n)/b0, _x) + + def __str__(self): + return '' % (self._i, + self._an, self._ap, self._bm, self._bq) + + +class MeijerUnShiftC(Operator): + """ Decrement a lower b index. """ + # XXX this is "essentially" the same as MeijerUnShiftA. This "essentially" + # can be made rigorous using the functional equation G(1/z) = G'(z), + # where G' denotes a G function of slightly altered parameters. + # However, sorting out the details seems harder than just coding it + # again. + + def __init__(self, an, ap, bm, bq, i, z): + """ Note: i counts from zero! """ + an, ap, bm, bq, i = list(map(sympify, [an, ap, bm, bq, i])) + + self._an = an + self._ap = ap + self._bm = bm + self._bq = bq + self._i = i + + an = list(an) + ap = list(ap) + bm = list(bm) + bq = list(bq) + bi = bq.pop(i) - 1 + + m = Poly(1, _x) + for b in bm: + m *= Poly(b - _x, _x) + for b in bq: + m *= Poly(_x - b, _x) + + C = Dummy('C') + D = Poly(bi + C, C) + n = Poly(z, C) + for a in an: + n *= (D + 1 - a) + for a in ap: + n *= (-D + a - 1) + + b0 = n.nth(0) + if b0 == 0: + raise ValueError('Cannot decrement lower b index (cancels)') + + n = Poly(Poly(n.all_coeffs()[:-1], C).as_expr().subs(C, _x - bi), _x) + + self._poly = Poly((m - n)/b0, _x) + + def __str__(self): + return '' % (self._i, + self._an, self._ap, self._bm, self._bq) + + +class MeijerUnShiftD(Operator): + """ Increment a lower a index. """ + # XXX This is essentially the same as MeijerUnShiftA. + # See comment at MeijerUnShiftC. + + def __init__(self, an, ap, bm, bq, i, z): + """ Note: i counts from zero! """ + an, ap, bm, bq, i = list(map(sympify, [an, ap, bm, bq, i])) + + self._an = an + self._ap = ap + self._bm = bm + self._bq = bq + self._i = i + + an = list(an) + ap = list(ap) + bm = list(bm) + bq = list(bq) + ai = ap.pop(i) + 1 + + m = Poly(z, _x) + for a in an: + m *= Poly(1 - a + _x, _x) + for a in ap: + m *= Poly(a - 1 - _x, _x) + + B = Dummy('B') # - this is the shift operator `D_I` + D = Poly(ai - 1 - B, B) + n = Poly(1, B) + for b in bm: + n *= (-D + b) + for b in bq: + n *= (D - b) + + b0 = n.nth(0) + if b0 == 0: + raise ValueError('Cannot increment lower a index (cancels)') + + n = Poly(Poly(n.all_coeffs()[:-1], B).as_expr().subs( + B, ai - 1 - _x), _x) + + self._poly = Poly((m - n)/b0, _x) + + def __str__(self): + return '' % (self._i, + self._an, self._ap, self._bm, self._bq) + + +class ReduceOrder(Operator): + """ Reduce Order by cancelling an upper and a lower index. """ + + def __new__(cls, ai, bj): + """ For convenience if reduction is not possible, return None. """ + ai = sympify(ai) + bj = sympify(bj) + n = ai - bj + if not n.is_Integer or n < 0: + return None + if bj.is_integer and bj.is_nonpositive: + return None + + expr = Operator.__new__(cls) + + p = S.One + for k in range(n): + p *= (_x + bj + k)/(bj + k) + + expr._poly = Poly(p, _x) + expr._a = ai + expr._b = bj + + return expr + + @classmethod + def _meijer(cls, b, a, sign): + """ Cancel b + sign*s and a + sign*s + This is for meijer G functions. """ + b = sympify(b) + a = sympify(a) + n = b - a + if n.is_negative or not n.is_Integer: + return None + + expr = Operator.__new__(cls) + + p = S.One + for k in range(n): + p *= (sign*_x + a + k) + + expr._poly = Poly(p, _x) + if sign == -1: + expr._a = b + expr._b = a + else: + expr._b = Add(1, a - 1, evaluate=False) + expr._a = Add(1, b - 1, evaluate=False) + + return expr + + @classmethod + def meijer_minus(cls, b, a): + return cls._meijer(b, a, -1) + + @classmethod + def meijer_plus(cls, a, b): + return cls._meijer(1 - a, 1 - b, 1) + + def __str__(self): + return '' % \ + (self._a, self._b) + + +def _reduce_order(ap, bq, gen, key): + """ Order reduction algorithm used in Hypergeometric and Meijer G """ + ap = list(ap) + bq = list(bq) + + ap.sort(key=key) + bq.sort(key=key) + + nap = [] + # we will edit bq in place + operators = [] + for a in ap: + op = None + for i in range(len(bq)): + op = gen(a, bq[i]) + if op is not None: + bq.pop(i) + break + if op is None: + nap.append(a) + else: + operators.append(op) + + return nap, bq, operators + + +def reduce_order(func): + """ + Given the hypergeometric function ``func``, find a sequence of operators to + reduces order as much as possible. + + Explanation + =========== + + Return (newfunc, [operators]), where applying the operators to the + hypergeometric function newfunc yields func. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import reduce_order, Hyper_Function + >>> reduce_order(Hyper_Function((1, 2), (3, 4))) + (Hyper_Function((1, 2), (3, 4)), []) + >>> reduce_order(Hyper_Function((1,), (1,))) + (Hyper_Function((), ()), []) + >>> reduce_order(Hyper_Function((2, 4), (3, 3))) + (Hyper_Function((2,), (3,)), []) + """ + nap, nbq, operators = _reduce_order(func.ap, func.bq, ReduceOrder, default_sort_key) + + return Hyper_Function(Tuple(*nap), Tuple(*nbq)), operators + + +def reduce_order_meijer(func): + """ + Given the Meijer G function parameters, ``func``, find a sequence of + operators that reduces order as much as possible. + + Return newfunc, [operators]. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import (reduce_order_meijer, + ... G_Function) + >>> reduce_order_meijer(G_Function([3, 4], [5, 6], [3, 4], [1, 2]))[0] + G_Function((4, 3), (5, 6), (3, 4), (2, 1)) + >>> reduce_order_meijer(G_Function([3, 4], [5, 6], [3, 4], [1, 8]))[0] + G_Function((3,), (5, 6), (3, 4), (1,)) + >>> reduce_order_meijer(G_Function([3, 4], [5, 6], [7, 5], [1, 5]))[0] + G_Function((3,), (), (), (1,)) + >>> reduce_order_meijer(G_Function([3, 4], [5, 6], [7, 5], [5, 3]))[0] + G_Function((), (), (), ()) + """ + + nan, nbq, ops1 = _reduce_order(func.an, func.bq, ReduceOrder.meijer_plus, + lambda x: default_sort_key(-x)) + nbm, nap, ops2 = _reduce_order(func.bm, func.ap, ReduceOrder.meijer_minus, + default_sort_key) + + return G_Function(nan, nap, nbm, nbq), ops1 + ops2 + + +def make_derivative_operator(M, z): + """ Create a derivative operator, to be passed to Operator.apply. """ + def doit(C): + r = z*C.diff(z) + C*M + r = r.applyfunc(make_simp(z)) + return r + return doit + + +def apply_operators(obj, ops, op): + """ + Apply the list of operators ``ops`` to object ``obj``, substituting + ``op`` for the generator. + """ + res = obj + for o in reversed(ops): + res = o.apply(res, op) + return res + + +def devise_plan(target, origin, z): + """ + Devise a plan (consisting of shift and un-shift operators) to be applied + to the hypergeometric function ``target`` to yield ``origin``. + Returns a list of operators. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import devise_plan, Hyper_Function + >>> from sympy.abc import z + + Nothing to do: + + >>> devise_plan(Hyper_Function((1, 2), ()), Hyper_Function((1, 2), ()), z) + [] + >>> devise_plan(Hyper_Function((), (1, 2)), Hyper_Function((), (1, 2)), z) + [] + + Very simple plans: + + >>> devise_plan(Hyper_Function((2,), ()), Hyper_Function((1,), ()), z) + [] + >>> devise_plan(Hyper_Function((), (2,)), Hyper_Function((), (1,)), z) + [] + + Several buckets: + + >>> from sympy import S + >>> devise_plan(Hyper_Function((1, S.Half), ()), + ... Hyper_Function((2, S('3/2')), ()), z) #doctest: +NORMALIZE_WHITESPACE + [, + ] + + A slightly more complicated plan: + + >>> devise_plan(Hyper_Function((1, 3), ()), Hyper_Function((2, 2), ()), z) + [, ] + + Another more complicated plan: (note that the ap have to be shifted first!) + + >>> devise_plan(Hyper_Function((1, -1), (2,)), Hyper_Function((3, -2), (4,)), z) + [, , + , + , ] + """ + abuckets, bbuckets, nabuckets, nbbuckets = [sift(params, _mod1) for + params in (target.ap, target.bq, origin.ap, origin.bq)] + + if len(list(abuckets.keys())) != len(list(nabuckets.keys())) or \ + len(list(bbuckets.keys())) != len(list(nbbuckets.keys())): + raise ValueError('%s not reachable from %s' % (target, origin)) + + ops = [] + + def do_shifts(fro, to, inc, dec): + ops = [] + for i in range(len(fro)): + if to[i] - fro[i] > 0: + sh = inc + ch = 1 + else: + sh = dec + ch = -1 + + while to[i] != fro[i]: + ops += [sh(fro, i)] + fro[i] += ch + + return ops + + def do_shifts_a(nal, nbk, al, aother, bother): + """ Shift us from (nal, nbk) to (al, nbk). """ + return do_shifts(nal, al, lambda p, i: ShiftA(p[i]), + lambda p, i: UnShiftA(p + aother, nbk + bother, i, z)) + + def do_shifts_b(nal, nbk, bk, aother, bother): + """ Shift us from (nal, nbk) to (nal, bk). """ + return do_shifts(nbk, bk, + lambda p, i: UnShiftB(nal + aother, p + bother, i, z), + lambda p, i: ShiftB(p[i])) + + for r in sorted(list(abuckets.keys()) + list(bbuckets.keys()), key=default_sort_key): + al = () + nal = () + bk = () + nbk = () + if r in abuckets: + al = abuckets[r] + nal = nabuckets[r] + if r in bbuckets: + bk = bbuckets[r] + nbk = nbbuckets[r] + if len(al) != len(nal) or len(bk) != len(nbk): + raise ValueError('%s not reachable from %s' % (target, origin)) + + al, nal, bk, nbk = [sorted(w, key=default_sort_key) + for w in [al, nal, bk, nbk]] + + def others(dic, key): + l = [] + for k in dic: + if k != key: + l.extend(dic[k]) + return l + aother = others(nabuckets, r) + bother = others(nbbuckets, r) + + if len(al) == 0: + # there can be no complications, just shift the bs as we please + ops += do_shifts_b([], nbk, bk, aother, bother) + elif len(bk) == 0: + # there can be no complications, just shift the as as we please + ops += do_shifts_a(nal, [], al, aother, bother) + else: + namax = nal[-1] + amax = al[-1] + + if nbk[0] - namax <= 0 or bk[0] - amax <= 0: + raise ValueError('Non-suitable parameters.') + + if namax - amax > 0: + # we are going to shift down - first do the as, then the bs + ops += do_shifts_a(nal, nbk, al, aother, bother) + ops += do_shifts_b(al, nbk, bk, aother, bother) + else: + # we are going to shift up - first do the bs, then the as + ops += do_shifts_b(nal, nbk, bk, aother, bother) + ops += do_shifts_a(nal, bk, al, aother, bother) + + nabuckets[r] = al + nbbuckets[r] = bk + + ops.reverse() + return ops + + +def try_shifted_sum(func, z): + """ Try to recognise a hypergeometric sum that starts from k > 0. """ + abuckets, bbuckets = sift(func.ap, _mod1), sift(func.bq, _mod1) + if len(abuckets[S.Zero]) != 1: + return None + r = abuckets[S.Zero][0] + if r <= 0: + return None + if S.Zero not in bbuckets: + return None + l = list(bbuckets[S.Zero]) + l.sort() + k = l[0] + if k <= 0: + return None + + nap = list(func.ap) + nap.remove(r) + nbq = list(func.bq) + nbq.remove(k) + k -= 1 + nap = [x - k for x in nap] + nbq = [x - k for x in nbq] + + ops = [] + for n in range(r - 1): + ops.append(ShiftA(n + 1)) + ops.reverse() + + fac = factorial(k)/z**k + fac *= Mul(*[rf(b, k) for b in nbq]) + fac /= Mul(*[rf(a, k) for a in nap]) + + ops += [MultOperator(fac)] + + p = 0 + for n in range(k): + m = z**n/factorial(n) + m *= Mul(*[rf(a, n) for a in nap]) + m /= Mul(*[rf(b, n) for b in nbq]) + p += m + + return Hyper_Function(nap, nbq), ops, -p + + +def try_polynomial(func, z): + """ Recognise polynomial cases. Returns None if not such a case. + Requires order to be fully reduced. """ + abuckets, bbuckets = sift(func.ap, _mod1), sift(func.bq, _mod1) + a0 = abuckets[S.Zero] + b0 = bbuckets[S.Zero] + a0.sort() + b0.sort() + al0 = [x for x in a0 if x <= 0] + bl0 = [x for x in b0 if x <= 0] + + if bl0 and all(a < bl0[-1] for a in al0): + return oo + if not al0: + return None + + a = al0[-1] + fac = 1 + res = S.One + for n in Tuple(*list(range(-a))): + fac *= z + fac /= n + 1 + fac *= Mul(*[a + n for a in func.ap]) + fac /= Mul(*[b + n for b in func.bq]) + res += fac + return res + + +def try_lerchphi(func): + """ + Try to find an expression for Hyper_Function ``func`` in terms of Lerch + Transcendents. + + Return None if no such expression can be found. + """ + # This is actually quite simple, and is described in Roach's paper, + # section 18. + # We don't need to implement the reduction to polylog here, this + # is handled by expand_func. + + # First we need to figure out if the summation coefficient is a rational + # function of the summation index, and construct that rational function. + abuckets, bbuckets = sift(func.ap, _mod1), sift(func.bq, _mod1) + + paired = {} + for key, value in abuckets.items(): + if key != 0 and key not in bbuckets: + return None + bvalue = bbuckets[key] + paired[key] = (list(value), list(bvalue)) + bbuckets.pop(key, None) + if bbuckets != {}: + return None + if S.Zero not in abuckets: + return None + aints, bints = paired[S.Zero] + # Account for the additional n! in denominator + paired[S.Zero] = (aints, bints + [1]) + + t = Dummy('t') + numer = S.One + denom = S.One + for key, (avalue, bvalue) in paired.items(): + if len(avalue) != len(bvalue): + return None + # Note that since order has been reduced fully, all the b are + # bigger than all the a they differ from by an integer. In particular + # if there are any negative b left, this function is not well-defined. + for a, b in zip(avalue, bvalue): + if (a - b).is_positive: + k = a - b + numer *= rf(b + t, k) + denom *= rf(b, k) + else: + k = b - a + numer *= rf(a, k) + denom *= rf(a + t, k) + + # Now do a partial fraction decomposition. + # We assemble two structures: a list monomials of pairs (a, b) representing + # a*t**b (b a non-negative integer), and a dict terms, where + # terms[a] = [(b, c)] means that there is a term b/(t-a)**c. + part = apart(numer/denom, t) + args = Add.make_args(part) + monomials = [] + terms = {} + for arg in args: + numer, denom = arg.as_numer_denom() + if not denom.has(t): + p = Poly(numer, t) + if not p.is_monomial: + raise TypeError("p should be monomial") + ((b, ), a) = p.LT() + monomials += [(a/denom, b)] + continue + if numer.has(t): + raise NotImplementedError('Need partial fraction decomposition' + ' with linear denominators') + indep, [dep] = denom.as_coeff_mul(t) + n = 1 + if dep.is_Pow: + n = dep.exp + dep = dep.base + if dep == t: + a = 0 + elif dep.is_Add: + a, tmp = dep.as_independent(t) + b = 1 + if tmp != t: + b, _ = tmp.as_independent(t) + if dep != b*t + a: + raise NotImplementedError('unrecognised form %s' % dep) + a /= b + indep *= b**n + else: + raise NotImplementedError('unrecognised form of partial fraction') + terms.setdefault(a, []).append((numer/indep, n)) + + # Now that we have this information, assemble our formula. All the + # monomials yield rational functions and go into one basis element. + # The terms[a] are related by differentiation. If the largest exponent is + # n, we need lerchphi(z, k, a) for k = 1, 2, ..., n. + # deriv maps a basis to its derivative, expressed as a C(z)-linear + # combination of other basis elements. + deriv = {} + coeffs = {} + z = Dummy('z') + monomials.sort(key=lambda x: x[1]) + mon = {0: 1/(1 - z)} + if monomials: + for k in range(monomials[-1][1]): + mon[k + 1] = z*mon[k].diff(z) + for a, n in monomials: + coeffs.setdefault(S.One, []).append(a*mon[n]) + for a, l in terms.items(): + for c, k in l: + coeffs.setdefault(lerchphi(z, k, a), []).append(c) + l.sort(key=lambda x: x[1]) + for k in range(2, l[-1][1] + 1): + deriv[lerchphi(z, k, a)] = [(-a, lerchphi(z, k, a)), + (1, lerchphi(z, k - 1, a))] + deriv[lerchphi(z, 1, a)] = [(-a, lerchphi(z, 1, a)), + (1/(1 - z), S.One)] + trans = {} + for n, b in enumerate([S.One] + list(deriv.keys())): + trans[b] = n + basis = [expand_func(b) for (b, _) in sorted(trans.items(), + key=lambda x:x[1])] + B = Matrix(basis) + C = Matrix([[0]*len(B)]) + for b, c in coeffs.items(): + C[trans[b]] = Add(*c) + M = zeros(len(B)) + for b, l in deriv.items(): + for c, b2 in l: + M[trans[b], trans[b2]] = c + return Formula(func, z, None, [], B, C, M) + + +def build_hypergeometric_formula(func): + """ + Create a formula object representing the hypergeometric function ``func``. + + """ + # We know that no `ap` are negative integers, otherwise "detect poly" + # would have kicked in. However, `ap` could be empty. In this case we can + # use a different basis. + # I'm not aware of a basis that works in all cases. + z = Dummy('z') + if func.ap: + afactors = [_x + a for a in func.ap] + bfactors = [_x + b - 1 for b in func.bq] + expr = _x*Mul(*bfactors) - z*Mul(*afactors) + poly = Poly(expr, _x) + n = poly.degree() + basis = [] + M = zeros(n) + for k in range(n): + a = func.ap[0] + k + basis += [hyper([a] + list(func.ap[1:]), func.bq, z)] + if k < n - 1: + M[k, k] = -a + M[k, k + 1] = a + B = Matrix(basis) + C = Matrix([[1] + [0]*(n - 1)]) + derivs = [eye(n)] + for k in range(n): + derivs.append(M*derivs[k]) + l = poly.all_coeffs() + l.reverse() + res = [0]*n + for k, c in enumerate(l): + for r, d in enumerate(C*derivs[k]): + res[r] += c*d + for k, c in enumerate(res): + M[n - 1, k] = -c/derivs[n - 1][0, n - 1]/poly.all_coeffs()[0] + return Formula(func, z, None, [], B, C, M) + else: + # Since there are no `ap`, none of the `bq` can be non-positive + # integers. + basis = [] + bq = list(func.bq[:]) + for i in range(len(bq)): + basis += [hyper([], bq, z)] + bq[i] += 1 + basis += [hyper([], bq, z)] + B = Matrix(basis) + n = len(B) + C = Matrix([[1] + [0]*(n - 1)]) + M = zeros(n) + M[0, n - 1] = z/Mul(*func.bq) + for k in range(1, n): + M[k, k - 1] = func.bq[k - 1] + M[k, k] = -func.bq[k - 1] + return Formula(func, z, None, [], B, C, M) + + +def hyperexpand_special(ap, bq, z): + """ + Try to find a closed-form expression for hyper(ap, bq, z), where ``z`` + is supposed to be a "special" value, e.g. 1. + + This function tries various of the classical summation formulae + (Gauss, Saalschuetz, etc). + """ + # This code is very ad-hoc. There are many clever algorithms + # (notably Zeilberger's) related to this problem. + # For now we just want a few simple cases to work. + p, q = len(ap), len(bq) + z_ = z + z = unpolarify(z) + if z == 0: + return S.One + from sympy.simplify.simplify import simplify + if p == 2 and q == 1: + # 2F1 + a, b, c = ap + bq + if z == 1: + # Gauss + return gamma(c - a - b)*gamma(c)/gamma(c - a)/gamma(c - b) + if z == -1 and simplify(b - a + c) == 1: + b, a = a, b + if z == -1 and simplify(a - b + c) == 1: + # Kummer + if b.is_integer and b.is_negative: + return 2*cos(pi*b/2)*gamma(-b)*gamma(b - a + 1) \ + /gamma(-b/2)/gamma(b/2 - a + 1) + else: + return gamma(b/2 + 1)*gamma(b - a + 1) \ + /gamma(b + 1)/gamma(b/2 - a + 1) + # TODO tons of more formulae + # investigate what algorithms exist + return hyper(ap, bq, z_) + +_collection = None + + +def _hyperexpand(func, z, ops0=[], z0=Dummy('z0'), premult=1, prem=0, + rewrite='default'): + """ + Try to find an expression for the hypergeometric function ``func``. + + Explanation + =========== + + The result is expressed in terms of a dummy variable ``z0``. Then it + is multiplied by ``premult``. Then ``ops0`` is applied. + ``premult`` must be a*z**prem for some a independent of ``z``. + """ + + if z.is_zero: + return S.One + + from sympy.simplify.simplify import simplify + + z = polarify(z, subs=False) + if rewrite == 'default': + rewrite = 'nonrepsmall' + + def carryout_plan(f, ops): + C = apply_operators(f.C.subs(f.z, z0), ops, + make_derivative_operator(f.M.subs(f.z, z0), z0)) + C = apply_operators(C, ops0, + make_derivative_operator(f.M.subs(f.z, z0) + + prem*eye(f.M.shape[0]), z0)) + + if premult == 1: + C = C.applyfunc(make_simp(z0)) + r = reduce(lambda s,m: s+m[0]*m[1], zip(C, f.B.subs(f.z, z0)), S.Zero)*premult + res = r.subs(z0, z) + if rewrite: + res = res.rewrite(rewrite) + return res + + # TODO + # The following would be possible: + # *) PFD Duplication (see Kelly Roach's paper) + # *) In a similar spirit, try_lerchphi() can be generalised considerably. + + global _collection + if _collection is None: + _collection = FormulaCollection() + + debug('Trying to expand hypergeometric function ', func) + + # First reduce order as much as possible. + func, ops = reduce_order(func) + if ops: + debug(' Reduced order to ', func) + else: + debug(' Could not reduce order.') + + # Now try polynomial cases + res = try_polynomial(func, z0) + if res is not None: + debug(' Recognised polynomial.') + p = apply_operators(res, ops, lambda f: z0*f.diff(z0)) + p = apply_operators(p*premult, ops0, lambda f: z0*f.diff(z0)) + return unpolarify(simplify(p).subs(z0, z)) + + # Try to recognise a shifted sum. + p = S.Zero + res = try_shifted_sum(func, z0) + if res is not None: + func, nops, p = res + debug(' Recognised shifted sum, reduced order to ', func) + ops += nops + + # apply the plan for poly + p = apply_operators(p, ops, lambda f: z0*f.diff(z0)) + p = apply_operators(p*premult, ops0, lambda f: z0*f.diff(z0)) + p = simplify(p).subs(z0, z) + + # Try special expansions early. + if unpolarify(z) in [1, -1] and (len(func.ap), len(func.bq)) == (2, 1): + f = build_hypergeometric_formula(func) + r = carryout_plan(f, ops).replace(hyper, hyperexpand_special) + if not r.has(hyper): + return r + p + + # Try to find a formula in our collection + formula = _collection.lookup_origin(func) + + # Now try a lerch phi formula + if formula is None: + formula = try_lerchphi(func) + + if formula is None: + debug(' Could not find an origin. ', + 'Will return answer in terms of ' + 'simpler hypergeometric functions.') + formula = build_hypergeometric_formula(func) + + debug(' Found an origin: ', formula.closed_form, ' ', formula.func) + + # We need to find the operators that convert formula into func. + ops += devise_plan(func, formula.func, z0) + + # Now carry out the plan. + r = carryout_plan(formula, ops) + p + + return powdenest(r, polar=True).replace(hyper, hyperexpand_special) + + +def devise_plan_meijer(fro, to, z): + """ + Find operators to convert G-function ``fro`` into G-function ``to``. + + Explanation + =========== + + It is assumed that ``fro`` and ``to`` have the same signatures, and that in fact + any corresponding pair of parameters differs by integers, and a direct path + is possible. I.e. if there are parameters a1 b1 c1 and a2 b2 c2 it is + assumed that a1 can be shifted to a2, etc. The only thing this routine + determines is the order of shifts to apply, nothing clever will be tried. + It is also assumed that ``fro`` is suitable. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import (devise_plan_meijer, + ... G_Function) + >>> from sympy.abc import z + + Empty plan: + + >>> devise_plan_meijer(G_Function([1], [2], [3], [4]), + ... G_Function([1], [2], [3], [4]), z) + [] + + Very simple plans: + + >>> devise_plan_meijer(G_Function([0], [], [], []), + ... G_Function([1], [], [], []), z) + [] + >>> devise_plan_meijer(G_Function([0], [], [], []), + ... G_Function([-1], [], [], []), z) + [] + >>> devise_plan_meijer(G_Function([], [1], [], []), + ... G_Function([], [2], [], []), z) + [] + + Slightly more complicated plans: + + >>> devise_plan_meijer(G_Function([0], [], [], []), + ... G_Function([2], [], [], []), z) + [, + ] + >>> devise_plan_meijer(G_Function([0], [], [0], []), + ... G_Function([-1], [], [1], []), z) + [, ] + + Order matters: + + >>> devise_plan_meijer(G_Function([0], [], [0], []), + ... G_Function([1], [], [1], []), z) + [, ] + """ + # TODO for now, we use the following simple heuristic: inverse-shift + # when possible, shift otherwise. Give up if we cannot make progress. + + def try_shift(f, t, shifter, diff, counter): + """ Try to apply ``shifter`` in order to bring some element in ``f`` + nearer to its counterpart in ``to``. ``diff`` is +/- 1 and + determines the effect of ``shifter``. Counter is a list of elements + blocking the shift. + + Return an operator if change was possible, else None. + """ + for idx, (a, b) in enumerate(zip(f, t)): + if ( + (a - b).is_integer and (b - a)/diff > 0 and + all(a != x for x in counter)): + sh = shifter(idx) + f[idx] += diff + return sh + fan = list(fro.an) + fap = list(fro.ap) + fbm = list(fro.bm) + fbq = list(fro.bq) + ops = [] + change = True + while change: + change = False + op = try_shift(fan, to.an, + lambda i: MeijerUnShiftB(fan, fap, fbm, fbq, i, z), + 1, fbm + fbq) + if op is not None: + ops += [op] + change = True + continue + op = try_shift(fap, to.ap, + lambda i: MeijerUnShiftD(fan, fap, fbm, fbq, i, z), + 1, fbm + fbq) + if op is not None: + ops += [op] + change = True + continue + op = try_shift(fbm, to.bm, + lambda i: MeijerUnShiftA(fan, fap, fbm, fbq, i, z), + -1, fan + fap) + if op is not None: + ops += [op] + change = True + continue + op = try_shift(fbq, to.bq, + lambda i: MeijerUnShiftC(fan, fap, fbm, fbq, i, z), + -1, fan + fap) + if op is not None: + ops += [op] + change = True + continue + op = try_shift(fan, to.an, lambda i: MeijerShiftB(fan[i]), -1, []) + if op is not None: + ops += [op] + change = True + continue + op = try_shift(fap, to.ap, lambda i: MeijerShiftD(fap[i]), -1, []) + if op is not None: + ops += [op] + change = True + continue + op = try_shift(fbm, to.bm, lambda i: MeijerShiftA(fbm[i]), 1, []) + if op is not None: + ops += [op] + change = True + continue + op = try_shift(fbq, to.bq, lambda i: MeijerShiftC(fbq[i]), 1, []) + if op is not None: + ops += [op] + change = True + continue + if fan != list(to.an) or fap != list(to.ap) or fbm != list(to.bm) or \ + fbq != list(to.bq): + raise NotImplementedError('Could not devise plan.') + ops.reverse() + return ops + +_meijercollection = None + + +def _meijergexpand(func, z0, allow_hyper=False, rewrite='default', + place=None): + """ + Try to find an expression for the Meijer G function specified + by the G_Function ``func``. If ``allow_hyper`` is True, then returning + an expression in terms of hypergeometric functions is allowed. + + Currently this just does Slater's theorem. + If expansions exist both at zero and at infinity, ``place`` + can be set to ``0`` or ``zoo`` for the preferred choice. + """ + global _meijercollection + if _meijercollection is None: + _meijercollection = MeijerFormulaCollection() + if rewrite == 'default': + rewrite = None + + func0 = func + debug('Try to expand Meijer G function corresponding to ', func) + + # We will play games with analytic continuation - rather use a fresh symbol + z = Dummy('z') + + func, ops = reduce_order_meijer(func) + if ops: + debug(' Reduced order to ', func) + else: + debug(' Could not reduce order.') + + # Try to find a direct formula + f = _meijercollection.lookup_origin(func) + if f is not None: + debug(' Found a Meijer G formula: ', f.func) + ops += devise_plan_meijer(f.func, func, z) + + # Now carry out the plan. + C = apply_operators(f.C.subs(f.z, z), ops, + make_derivative_operator(f.M.subs(f.z, z), z)) + + C = C.applyfunc(make_simp(z)) + r = C*f.B.subs(f.z, z) + r = r[0].subs(z, z0) + return powdenest(r, polar=True) + + debug(" Could not find a direct formula. Trying Slater's theorem.") + + # TODO the following would be possible: + # *) Paired Index Theorems + # *) PFD Duplication + # (See Kelly Roach's paper for details on either.) + # + # TODO Also, we tend to create combinations of gamma functions that can be + # simplified. + + def can_do(pbm, pap): + """ Test if slater applies. """ + for i in pbm: + if len(pbm[i]) > 1: + l = 0 + if i in pap: + l = len(pap[i]) + if l + 1 < len(pbm[i]): + return False + return True + + def do_slater(an, bm, ap, bq, z, zfinal): + # zfinal is the value that will eventually be substituted for z. + # We pass it to _hyperexpand to improve performance. + func = G_Function(an, bm, ap, bq) + _, pbm, pap, _ = func.compute_buckets() + if not can_do(pbm, pap): + return S.Zero, False + + cond = len(an) + len(ap) < len(bm) + len(bq) + if len(an) + len(ap) == len(bm) + len(bq): + cond = abs(z) < 1 + if cond is False: + return S.Zero, False + + res = S.Zero + for m in pbm: + if len(pbm[m]) == 1: + bh = pbm[m][0] + fac = 1 + bo = list(bm) + bo.remove(bh) + for bj in bo: + fac *= gamma(bj - bh) + for aj in an: + fac *= gamma(1 + bh - aj) + for bj in bq: + fac /= gamma(1 + bh - bj) + for aj in ap: + fac /= gamma(aj - bh) + nap = [1 + bh - a for a in list(an) + list(ap)] + nbq = [1 + bh - b for b in list(bo) + list(bq)] + + k = polar_lift(S.NegativeOne**(len(ap) - len(bm))) + harg = k*zfinal + # NOTE even though k "is" +-1, this has to be t/k instead of + # t*k ... we are using polar numbers for consistency! + premult = (t/k)**bh + hyp = _hyperexpand(Hyper_Function(nap, nbq), harg, ops, + t, premult, bh, rewrite=None) + res += fac * hyp + else: + b_ = pbm[m][0] + ki = [bi - b_ for bi in pbm[m][1:]] + u = len(ki) + li = [ai - b_ for ai in pap[m][:u + 1]] + bo = list(bm) + for b in pbm[m]: + bo.remove(b) + ao = list(ap) + for a in pap[m][:u]: + ao.remove(a) + lu = li[-1] + di = [l - k for (l, k) in zip(li, ki)] + + # We first work out the integrand: + s = Dummy('s') + integrand = z**s + for b in bm: + if not Mod(b, 1) and b.is_Number: + b = int(round(b)) + integrand *= gamma(b - s) + for a in an: + integrand *= gamma(1 - a + s) + for b in bq: + integrand /= gamma(1 - b + s) + for a in ap: + integrand /= gamma(a - s) + + # Now sum the finitely many residues: + # XXX This speeds up some cases - is it a good idea? + integrand = expand_func(integrand) + for r in range(int(round(lu))): + resid = residue(integrand, s, b_ + r) + resid = apply_operators(resid, ops, lambda f: z*f.diff(z)) + res -= resid + + # Now the hypergeometric term. + au = b_ + lu + k = polar_lift(S.NegativeOne**(len(ao) + len(bo) + 1)) + harg = k*zfinal + premult = (t/k)**au + nap = [1 + au - a for a in list(an) + list(ap)] + [1] + nbq = [1 + au - b for b in list(bm) + list(bq)] + + hyp = _hyperexpand(Hyper_Function(nap, nbq), harg, ops, + t, premult, au, rewrite=None) + + C = S.NegativeOne**(lu)/factorial(lu) + for i in range(u): + C *= S.NegativeOne**di[i]/rf(lu - li[i] + 1, di[i]) + for a in an: + C *= gamma(1 - a + au) + for b in bo: + C *= gamma(b - au) + for a in ao: + C /= gamma(a - au) + for b in bq: + C /= gamma(1 - b + au) + + res += C*hyp + + return res, cond + + t = Dummy('t') + slater1, cond1 = do_slater(func.an, func.bm, func.ap, func.bq, z, z0) + + def tr(l): + return [1 - x for x in l] + + for op in ops: + op._poly = Poly(op._poly.subs({z: 1/t, _x: -_x}), _x) + slater2, cond2 = do_slater(tr(func.bm), tr(func.an), tr(func.bq), tr(func.ap), + t, 1/z0) + + slater1 = powdenest(slater1.subs(z, z0), polar=True) + slater2 = powdenest(slater2.subs(t, 1/z0), polar=True) + if not isinstance(cond2, bool): + cond2 = cond2.subs(t, 1/z) + + m = func(z) + if m.delta > 0 or \ + (m.delta == 0 and len(m.ap) == len(m.bq) and + (re(m.nu) < -1) is not False and polar_lift(z0) == polar_lift(1)): + # The condition delta > 0 means that the convergence region is + # connected. Any expression we find can be continued analytically + # to the entire convergence region. + # The conditions delta==0, p==q, re(nu) < -1 imply that G is continuous + # on the positive reals, so the values at z=1 agree. + if cond1 is not False: + cond1 = True + if cond2 is not False: + cond2 = True + + if cond1 is True: + slater1 = slater1.rewrite(rewrite or 'nonrep') + else: + slater1 = slater1.rewrite(rewrite or 'nonrepsmall') + if cond2 is True: + slater2 = slater2.rewrite(rewrite or 'nonrep') + else: + slater2 = slater2.rewrite(rewrite or 'nonrepsmall') + + if cond1 is not False and cond2 is not False: + # If one condition is False, there is no choice. + if place == 0: + cond2 = False + if place == zoo: + cond1 = False + + if not isinstance(cond1, bool): + cond1 = cond1.subs(z, z0) + if not isinstance(cond2, bool): + cond2 = cond2.subs(z, z0) + + def weight(expr, cond): + if cond is True: + c0 = 0 + elif cond is False: + c0 = 1 + else: + c0 = 2 + if expr.has(oo, zoo, -oo, nan): + # XXX this actually should not happen, but consider + # S('meijerg(((0, -1/2, 0, -1/2, 1/2), ()), ((0,), + # (-1/2, -1/2, -1/2, -1)), exp_polar(I*pi))/4') + c0 = 3 + return (c0, expr.count(hyper), expr.count_ops()) + + w1 = weight(slater1, cond1) + w2 = weight(slater2, cond2) + if min(w1, w2) <= (0, 1, oo): + if w1 < w2: + return slater1 + else: + return slater2 + if max(w1[0], w2[0]) <= 1 and max(w1[1], w2[1]) <= 1: + return Piecewise((slater1, cond1), (slater2, cond2), (func0(z0), True)) + + # We couldn't find an expression without hypergeometric functions. + # TODO it would be helpful to give conditions under which the integral + # is known to diverge. + r = Piecewise((slater1, cond1), (slater2, cond2), (func0(z0), True)) + if r.has(hyper) and not allow_hyper: + debug(' Could express using hypergeometric functions, ' + 'but not allowed.') + if not r.has(hyper) or allow_hyper: + return r + + return func0(z0) + + +def hyperexpand(f, allow_hyper=False, rewrite='default', place=None): + """ + Expand hypergeometric functions. If allow_hyper is True, allow partial + simplification (that is a result different from input, + but still containing hypergeometric functions). + + If a G-function has expansions both at zero and at infinity, + ``place`` can be set to ``0`` or ``zoo`` to indicate the + preferred choice. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import hyperexpand + >>> from sympy.functions import hyper + >>> from sympy.abc import z + >>> hyperexpand(hyper([], [], z)) + exp(z) + + Non-hyperegeometric parts of the expression and hypergeometric expressions + that are not recognised are left unchanged: + + >>> hyperexpand(1 + hyper([1, 1, 1], [], z)) + hyper((1, 1, 1), (), z) + 1 + """ + f = sympify(f) + + def do_replace(ap, bq, z): + r = _hyperexpand(Hyper_Function(ap, bq), z, rewrite=rewrite) + if r is None: + return hyper(ap, bq, z) + else: + return r + + def do_meijer(ap, bq, z): + r = _meijergexpand(G_Function(ap[0], ap[1], bq[0], bq[1]), z, + allow_hyper, rewrite=rewrite, place=place) + if not r.has(nan, zoo, oo, -oo): + return r + return f.replace(hyper, do_replace).replace(meijerg, do_meijer) diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/hyperexpand_doc.py b/.venv/lib/python3.13/site-packages/sympy/simplify/hyperexpand_doc.py new file mode 100644 index 0000000000000000000000000000000000000000..a18377f3aede5214036fbf628825536611001584 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/hyperexpand_doc.py @@ -0,0 +1,18 @@ +""" This module cooks up a docstring when imported. Its only purpose is to + be displayed in the sphinx documentation. """ + +from sympy.core.relational import Eq +from sympy.functions.special.hyper import hyper +from sympy.printing.latex import latex +from sympy.simplify.hyperexpand import FormulaCollection + +c = FormulaCollection() + +doc = "" + +for f in c.formulae: + obj = Eq(hyper(f.func.ap, f.func.bq, f.z), + f.closed_form.rewrite('nonrepsmall')) + doc += ".. math::\n %s\n" % latex(obj) + +__doc__ = doc diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/powsimp.py b/.venv/lib/python3.13/site-packages/sympy/simplify/powsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..f72dfeb072e0d0d4737ace310eda5c2a3a082c16 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/powsimp.py @@ -0,0 +1,718 @@ +from collections import defaultdict +from functools import reduce +from math import prod + +from sympy.core.function import expand_log, count_ops, _coeff_isneg +from sympy.core import sympify, Basic, Dummy, S, Add, Mul, Pow, expand_mul, factor_terms +from sympy.core.sorting import ordered, default_sort_key +from sympy.core.numbers import Integer, Rational, equal_valued +from sympy.core.mul import _keep_coeff +from sympy.core.rules import Transform +from sympy.functions import exp_polar, exp, log, root, polarify, unpolarify +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.polys import lcm, gcd +from sympy.ntheory.factor_ import multiplicity + + + +def powsimp(expr, deep=False, combine='all', force=False, measure=count_ops): + """ + Reduce expression by combining powers with similar bases and exponents. + + Explanation + =========== + + If ``deep`` is ``True`` then powsimp() will also simplify arguments of + functions. By default ``deep`` is set to ``False``. + + If ``force`` is ``True`` then bases will be combined without checking for + assumptions, e.g. sqrt(x)*sqrt(y) -> sqrt(x*y) which is not true + if x and y are both negative. + + You can make powsimp() only combine bases or only combine exponents by + changing combine='base' or combine='exp'. By default, combine='all', + which does both. combine='base' will only combine:: + + a a a 2x x + x * y => (x*y) as well as things like 2 => 4 + + and combine='exp' will only combine + :: + + a b (a + b) + x * x => x + + combine='exp' will strictly only combine exponents in the way that used + to be automatic. Also use deep=True if you need the old behavior. + + When combine='all', 'exp' is evaluated first. Consider the first + example below for when there could be an ambiguity relating to this. + This is done so things like the second example can be completely + combined. If you want 'base' combined first, do something like + powsimp(powsimp(expr, combine='base'), combine='exp'). + + Examples + ======== + + >>> from sympy import powsimp, exp, log, symbols + >>> from sympy.abc import x, y, z, n + >>> powsimp(x**y*x**z*y**z, combine='all') + x**(y + z)*y**z + >>> powsimp(x**y*x**z*y**z, combine='exp') + x**(y + z)*y**z + >>> powsimp(x**y*x**z*y**z, combine='base', force=True) + x**y*(x*y)**z + + >>> powsimp(x**z*x**y*n**z*n**y, combine='all', force=True) + (n*x)**(y + z) + >>> powsimp(x**z*x**y*n**z*n**y, combine='exp') + n**(y + z)*x**(y + z) + >>> powsimp(x**z*x**y*n**z*n**y, combine='base', force=True) + (n*x)**y*(n*x)**z + + >>> x, y = symbols('x y', positive=True) + >>> powsimp(log(exp(x)*exp(y))) + log(exp(x)*exp(y)) + >>> powsimp(log(exp(x)*exp(y)), deep=True) + x + y + + Radicals with Mul bases will be combined if combine='exp' + + >>> from sympy import sqrt + >>> x, y = symbols('x y') + + Two radicals are automatically joined through Mul: + + >>> a=sqrt(x*sqrt(y)) + >>> a*a**3 == a**4 + True + + But if an integer power of that radical has been + autoexpanded then Mul does not join the resulting factors: + + >>> a**4 # auto expands to a Mul, no longer a Pow + x**2*y + >>> _*a # so Mul doesn't combine them + x**2*y*sqrt(x*sqrt(y)) + >>> powsimp(_) # but powsimp will + (x*sqrt(y))**(5/2) + >>> powsimp(x*y*a) # but won't when doing so would violate assumptions + x*y*sqrt(x*sqrt(y)) + + """ + def recurse(arg, **kwargs): + _deep = kwargs.get('deep', deep) + _combine = kwargs.get('combine', combine) + _force = kwargs.get('force', force) + _measure = kwargs.get('measure', measure) + return powsimp(arg, _deep, _combine, _force, _measure) + + expr = sympify(expr) + + if (not isinstance(expr, Basic) or isinstance(expr, MatrixSymbol) or ( + expr.is_Atom or expr in (exp_polar(0), exp_polar(1)))): + return expr + + if deep or expr.is_Add or expr.is_Mul and _y not in expr.args: + expr = expr.func(*[recurse(w) for w in expr.args]) + + if expr.is_Pow: + return recurse(expr*_y, deep=False)/_y + + if not expr.is_Mul: + return expr + + # handle the Mul + if combine in ('exp', 'all'): + # Collect base/exp data, while maintaining order in the + # non-commutative parts of the product + c_powers = defaultdict(list) + nc_part = [] + newexpr = [] + coeff = S.One + for term in expr.args: + if term.is_Rational: + coeff *= term + continue + if term.is_Pow: + term = _denest_pow(term) + if term.is_commutative: + b, e = term.as_base_exp() + if deep: + b, e = [recurse(i) for i in [b, e]] + if b.is_Pow or isinstance(b, exp): + # don't let smthg like sqrt(x**a) split into x**a, 1/2 + # or else it will be joined as x**(a/2) later + b, e = b**e, S.One + c_powers[b].append(e) + else: + # This is the logic that combines exponents for equal, + # but non-commutative bases: A**x*A**y == A**(x+y). + if nc_part: + b1, e1 = nc_part[-1].as_base_exp() + b2, e2 = term.as_base_exp() + if (b1 == b2 and + e1.is_commutative and e2.is_commutative): + nc_part[-1] = Pow(b1, Add(e1, e2)) + continue + nc_part.append(term) + + # add up exponents of common bases + for b, e in ordered(iter(c_powers.items())): + # allow 2**x/4 -> 2**(x - 2); don't do this when b and e are + # Numbers since autoevaluation will undo it, e.g. + # 2**(1/3)/4 -> 2**(1/3 - 2) -> 2**(1/3)/4 + if (b and b.is_Rational and not all(ei.is_Number for ei in e) and \ + coeff is not S.One and + b not in (S.One, S.NegativeOne)): + m = multiplicity(abs(b), abs(coeff)) + if m: + e.append(m) + coeff /= b**m + c_powers[b] = Add(*e) + if coeff is not S.One: + if coeff in c_powers: + c_powers[coeff] += S.One + else: + c_powers[coeff] = S.One + + # convert to plain dictionary + c_powers = dict(c_powers) + + # check for base and inverted base pairs + be = list(c_powers.items()) + skip = set() # skip if we already saw them + for b, e in be: + if b in skip: + continue + bpos = b.is_positive or b.is_polar + if bpos: + binv = 1/b + #Special case for float 1 + if b.is_Float and equal_valued(b, 1): + c_powers[b] = S.One + continue + if b != binv and binv in c_powers: + if b.as_numer_denom()[0] is S.One: + c_powers.pop(b) + c_powers[binv] -= e + else: + skip.add(binv) + e = c_powers.pop(binv) + c_powers[b] -= e + + # check for base and negated base pairs + be = list(c_powers.items()) + _n = S.NegativeOne + for b, e in be: + if (b.is_Symbol or b.is_Add) and -b in c_powers and b in c_powers: + if (b.is_positive is not None or e.is_integer): + if e.is_integer or b.is_negative: + c_powers[-b] += c_powers.pop(b) + else: # (-b).is_positive so use its e + e = c_powers.pop(-b) + c_powers[b] += e + if _n in c_powers: + c_powers[_n] += e + else: + c_powers[_n] = e + + # filter c_powers and convert to a list + c_powers = [(b, e) for b, e in c_powers.items() if e] + + # ============================================================== + # check for Mul bases of Rational powers that can be combined with + # separated bases, e.g. x*sqrt(x*y)*sqrt(x*sqrt(x*y)) -> + # (x*sqrt(x*y))**(3/2) + # ---------------- helper functions + + def ratq(x): + '''Return Rational part of x's exponent as it appears in the bkey. + ''' + return bkey(x)[0][1] + + def bkey(b, e=None): + '''Return (b**s, c.q), c.p where e -> c*s. If e is not given then + it will be taken by using as_base_exp() on the input b. + e.g. + x**3/2 -> (x, 2), 3 + x**y -> (x**y, 1), 1 + x**(2*y/3) -> (x**y, 3), 2 + exp(x/2) -> (exp(a), 2), 1 + + ''' + if e is not None: # coming from c_powers or from below + if e.is_Integer: + return (b, S.One), e + elif e.is_Rational: + return (b, Integer(e.q)), Integer(e.p) + else: + c, m = e.as_coeff_Mul(rational=True) + if c is not S.One: + if m.is_integer: + return (b, Integer(c.q)), m*Integer(c.p) + return (b**m, Integer(c.q)), Integer(c.p) + else: + return (b**e, S.One), S.One + else: + return bkey(*b.as_base_exp()) + + def update(b): + '''Decide what to do with base, b. If its exponent is now an + integer multiple of the Rational denominator, then remove it + and put the factors of its base in the common_b dictionary or + update the existing bases if necessary. If it has been zeroed + out, simply remove the base. + ''' + newe, r = divmod(common_b[b], b[1]) + if not r: + common_b.pop(b) + if newe: + for m in Mul.make_args(b[0]**newe): + b, e = bkey(m) + if b not in common_b: + common_b[b] = 0 + common_b[b] += e + if b[1] != 1: + bases.append(b) + # ---------------- end of helper functions + + # assemble a dictionary of the factors having a Rational power + common_b = {} + done = [] + bases = [] + for b, e in c_powers: + b, e = bkey(b, e) + if b in common_b: + common_b[b] = common_b[b] + e + else: + common_b[b] = e + if b[1] != 1 and b[0].is_Mul: + bases.append(b) + bases.sort(key=default_sort_key) # this makes tie-breaking canonical + bases.sort(key=measure, reverse=True) # handle longest first + for base in bases: + if base not in common_b: # it may have been removed already + continue + b, exponent = base + last = False # True when no factor of base is a radical + qlcm = 1 # the lcm of the radical denominators + while True: + bstart = b + qstart = qlcm + + bb = [] # list of factors + ee = [] # (factor's expo. and it's current value in common_b) + for bi in Mul.make_args(b): + bib, bie = bkey(bi) + if bib not in common_b or common_b[bib] < bie: + ee = bb = [] # failed + break + ee.append([bie, common_b[bib]]) + bb.append(bib) + if ee: + # find the number of integral extractions possible + # e.g. [(1, 2), (2, 2)] -> min(2/1, 2/2) -> 1 + min1 = ee[0][1]//ee[0][0] + for i in range(1, len(ee)): + rat = ee[i][1]//ee[i][0] + if rat < 1: + break + min1 = min(min1, rat) + else: + # update base factor counts + # e.g. if ee = [(2, 5), (3, 6)] then min1 = 2 + # and the new base counts will be 5-2*2 and 6-2*3 + for i in range(len(bb)): + common_b[bb[i]] -= min1*ee[i][0] + update(bb[i]) + # update the count of the base + # e.g. x**2*y*sqrt(x*sqrt(y)) the count of x*sqrt(y) + # will increase by 4 to give bkey (x*sqrt(y), 2, 5) + common_b[base] += min1*qstart*exponent + if (last # no more radicals in base + or len(common_b) == 1 # nothing left to join with + or all(k[1] == 1 for k in common_b) # no rad's in common_b + ): + break + # see what we can exponentiate base by to remove any radicals + # so we know what to search for + # e.g. if base were x**(1/2)*y**(1/3) then we should + # exponentiate by 6 and look for powers of x and y in the ratio + # of 2 to 3 + qlcm = lcm([ratq(bi) for bi in Mul.make_args(bstart)]) + if qlcm == 1: + break # we are done + b = bstart**qlcm + qlcm *= qstart + if all(ratq(bi) == 1 for bi in Mul.make_args(b)): + last = True # we are going to be done after this next pass + # this base no longer can find anything to join with and + # since it was longer than any other we are done with it + b, q = base + done.append((b, common_b.pop(base)*Rational(1, q))) + + # update c_powers and get ready to continue with powsimp + c_powers = done + # there may be terms still in common_b that were bases that were + # identified as needing processing, so remove those, too + for (b, q), e in common_b.items(): + if (b.is_Pow or isinstance(b, exp)) and \ + q is not S.One and not b.exp.is_Rational: + b, be = b.as_base_exp() + b = b**(be/q) + else: + b = root(b, q) + c_powers.append((b, e)) + check = len(c_powers) + c_powers = dict(c_powers) + assert len(c_powers) == check # there should have been no duplicates + # ============================================================== + + # rebuild the expression + newexpr = expr.func(*(newexpr + [Pow(b, e) for b, e in c_powers.items()])) + if combine == 'exp': + return expr.func(newexpr, expr.func(*nc_part)) + else: + return recurse(expr.func(*nc_part), combine='base') * \ + recurse(newexpr, combine='base') + + elif combine == 'base': + + # Build c_powers and nc_part. These must both be lists not + # dicts because exp's are not combined. + c_powers = [] + nc_part = [] + for term in expr.args: + if term.is_commutative: + c_powers.append(list(term.as_base_exp())) + else: + nc_part.append(term) + + # Pull out numerical coefficients from exponent if assumptions allow + # e.g., 2**(2*x) => 4**x + for i in range(len(c_powers)): + b, e = c_powers[i] + if not (all(x.is_nonnegative for x in b.as_numer_denom()) or e.is_integer or force or b.is_polar): + continue + exp_c, exp_t = e.as_coeff_Mul(rational=True) + if exp_c is not S.One and exp_t is not S.One: + c_powers[i] = [Pow(b, exp_c), exp_t] + + # Combine bases whenever they have the same exponent and + # assumptions allow + # first gather the potential bases under the common exponent + c_exp = defaultdict(list) + for b, e in c_powers: + if deep: + e = recurse(e) + if e.is_Add and (b.is_positive or e.is_integer): + e = factor_terms(e) + if _coeff_isneg(e): + e = -e + b = 1/b + c_exp[e].append(b) + del c_powers + + # Merge back in the results of the above to form a new product + c_powers = defaultdict(list) + for e in c_exp: + bases = c_exp[e] + + # calculate the new base for e + + if len(bases) == 1: + new_base = bases[0] + elif e.is_integer or force: + new_base = expr.func(*bases) + else: + # see which ones can be joined + unk = [] + nonneg = [] + neg = [] + for bi in bases: + if bi.is_negative: + neg.append(bi) + elif bi.is_nonnegative: + nonneg.append(bi) + elif bi.is_polar: + nonneg.append( + bi) # polar can be treated like non-negative + else: + unk.append(bi) + if len(unk) == 1 and not neg or len(neg) == 1 and not unk: + # a single neg or a single unk can join the rest + nonneg.extend(unk + neg) + unk = neg = [] + elif neg: + # their negative signs cancel in groups of 2*q if we know + # that e = p/q else we have to treat them as unknown + israt = False + if e.is_Rational: + israt = True + else: + p, d = e.as_numer_denom() + if p.is_integer and d.is_integer: + israt = True + if israt: + neg = [-w for w in neg] + unk.extend([S.NegativeOne]*len(neg)) + else: + unk.extend(neg) + neg = [] + del israt + + # these shouldn't be joined + for b in unk: + c_powers[b].append(e) + # here is a new joined base + new_base = expr.func(*(nonneg + neg)) + # if there are positive parts they will just get separated + # again unless some change is made + + def _terms(e): + # return the number of terms of this expression + # when multiplied out -- assuming no joining of terms + if e.is_Add: + return sum(_terms(ai) for ai in e.args) + if e.is_Mul: + return prod([_terms(mi) for mi in e.args]) + return 1 + xnew_base = expand_mul(new_base, deep=False) + if len(Add.make_args(xnew_base)) < _terms(new_base): + new_base = factor_terms(xnew_base) + + c_powers[new_base].append(e) + + # break out the powers from c_powers now + c_part = [Pow(b, ei) for b, e in c_powers.items() for ei in e] + + # we're done + return expr.func(*(c_part + nc_part)) + + else: + raise ValueError("combine must be one of ('all', 'exp', 'base').") + + +def powdenest(eq, force=False, polar=False): + r""" + Collect exponents on powers as assumptions allow. + + Explanation + =========== + + Given ``(bb**be)**e``, this can be simplified as follows: + * if ``bb`` is positive, or + * ``e`` is an integer, or + * ``|be| < 1`` then this simplifies to ``bb**(be*e)`` + + Given a product of powers raised to a power, ``(bb1**be1 * + bb2**be2...)**e``, simplification can be done as follows: + + - if e is positive, the gcd of all bei can be joined with e; + - all non-negative bb can be separated from those that are negative + and their gcd can be joined with e; autosimplification already + handles this separation. + - integer factors from powers that have integers in the denominator + of the exponent can be removed from any term and the gcd of such + integers can be joined with e + + Setting ``force`` to ``True`` will make symbols that are not explicitly + negative behave as though they are positive, resulting in more + denesting. + + Setting ``polar`` to ``True`` will do simplifications on the Riemann surface of + the logarithm, also resulting in more denestings. + + When there are sums of logs in exp() then a product of powers may be + obtained e.g. ``exp(3*(log(a) + 2*log(b)))`` - > ``a**3*b**6``. + + Examples + ======== + + >>> from sympy.abc import a, b, x, y, z + >>> from sympy import Symbol, exp, log, sqrt, symbols, powdenest + + >>> powdenest((x**(2*a/3))**(3*x)) + (x**(2*a/3))**(3*x) + >>> powdenest(exp(3*x*log(2))) + 2**(3*x) + + Assumptions may prevent expansion: + + >>> powdenest(sqrt(x**2)) + sqrt(x**2) + + >>> p = symbols('p', positive=True) + >>> powdenest(sqrt(p**2)) + p + + No other expansion is done. + + >>> i, j = symbols('i,j', integer=True) + >>> powdenest((x**x)**(i + j)) # -X-> (x**x)**i*(x**x)**j + x**(x*(i + j)) + + But exp() will be denested by moving all non-log terms outside of + the function; this may result in the collapsing of the exp to a power + with a different base: + + >>> powdenest(exp(3*y*log(x))) + x**(3*y) + >>> powdenest(exp(y*(log(a) + log(b)))) + (a*b)**y + >>> powdenest(exp(3*(log(a) + log(b)))) + a**3*b**3 + + If assumptions allow, symbols can also be moved to the outermost exponent: + + >>> i = Symbol('i', integer=True) + >>> powdenest(((x**(2*i))**(3*y))**x) + ((x**(2*i))**(3*y))**x + >>> powdenest(((x**(2*i))**(3*y))**x, force=True) + x**(6*i*x*y) + + >>> powdenest(((x**(2*a/3))**(3*y/i))**x) + ((x**(2*a/3))**(3*y/i))**x + >>> powdenest((x**(2*i)*y**(4*i))**z, force=True) + (x*y**2)**(2*i*z) + + >>> n = Symbol('n', negative=True) + + >>> powdenest((x**i)**y, force=True) + x**(i*y) + >>> powdenest((n**i)**x, force=True) + (n**i)**x + + """ + from sympy.simplify.simplify import posify + + if force: + def _denest(b, e): + if not isinstance(b, (Pow, exp)): + return b.is_positive, Pow(b, e, evaluate=False) + return _denest(b.base, b.exp*e) + reps = [] + for p in eq.atoms(Pow, exp): + if isinstance(p.base, (Pow, exp)): + ok, dp = _denest(*p.args) + if ok is not False: + reps.append((p, dp)) + if reps: + eq = eq.subs(reps) + eq, reps = posify(eq) + return powdenest(eq, force=False, polar=polar).xreplace(reps) + + if polar: + eq, rep = polarify(eq) + return unpolarify(powdenest(unpolarify(eq, exponents_only=True)), rep) + + new = powsimp(eq) + return new.xreplace(Transform( + _denest_pow, filter=lambda m: m.is_Pow or isinstance(m, exp))) + +_y = Dummy('y') + + +def _denest_pow(eq): + """ + Denest powers. + + This is a helper function for powdenest that performs the actual + transformation. + """ + from sympy.simplify.simplify import logcombine + + b, e = eq.as_base_exp() + if b.is_Pow or isinstance(b, exp) and e != 1: + new = b._eval_power(e) + if new is not None: + eq = new + b, e = new.as_base_exp() + + # denest exp with log terms in exponent + if b is S.Exp1 and e.is_Mul: + logs = [] + other = [] + for ei in e.args: + if any(isinstance(ai, log) for ai in Add.make_args(ei)): + logs.append(ei) + else: + other.append(ei) + logs = logcombine(Mul(*logs)) + return Pow(exp(logs), Mul(*other)) + + _, be = b.as_base_exp() + if be is S.One and not (b.is_Mul or + b.is_Rational and b.q != 1 or + b.is_positive): + return eq + + # denest eq which is either pos**e or Pow**e or Mul**e or + # Mul(b1**e1, b2**e2) + + # handle polar numbers specially + polars, nonpolars = [], [] + for bb in Mul.make_args(b): + if bb.is_polar: + polars.append(bb.as_base_exp()) + else: + nonpolars.append(bb) + if len(polars) == 1 and not polars[0][0].is_Mul: + return Pow(polars[0][0], polars[0][1]*e)*powdenest(Mul(*nonpolars)**e) + elif polars: + return Mul(*[powdenest(bb**(ee*e)) for (bb, ee) in polars]) \ + *powdenest(Mul(*nonpolars)**e) + + if b.is_Integer: + # use log to see if there is a power here + logb = expand_log(log(b)) + if logb.is_Mul: + c, logb = logb.args + e *= c + base = logb.args[0] + return Pow(base, e) + + # if b is not a Mul or any factor is an atom then there is nothing to do + if not b.is_Mul or any(s.is_Atom for s in Mul.make_args(b)): + return eq + + # let log handle the case of the base of the argument being a Mul, e.g. + # sqrt(x**(2*i)*y**(6*i)) -> x**i*y**(3**i) if x and y are positive; we + # will take the log, expand it, and then factor out the common powers that + # now appear as coefficient. We do this manually since terms_gcd pulls out + # fractions, terms_gcd(x+x*y/2) -> x*(y + 2)/2 and we don't want the 1/2; + # gcd won't pull out numerators from a fraction: gcd(3*x, 9*x/2) -> x but + # we want 3*x. Neither work with noncommutatives. + + def nc_gcd(aa, bb): + a, b = [i.as_coeff_Mul() for i in [aa, bb]] + c = gcd(a[0], b[0]).as_numer_denom()[0] + g = Mul(*(a[1].args_cnc(cset=True)[0] & b[1].args_cnc(cset=True)[0])) + return _keep_coeff(c, g) + + glogb = expand_log(log(b)) + if glogb.is_Add: + args = glogb.args + g = reduce(nc_gcd, args) + if g != 1: + cg, rg = g.as_coeff_Mul() + glogb = _keep_coeff(cg, rg*Add(*[a/g for a in args])) + + # now put the log back together again + if isinstance(glogb, log) or not glogb.is_Mul: + if glogb.args[0].is_Pow or isinstance(glogb.args[0], exp): + glogb = _denest_pow(glogb.args[0]) + if (abs(glogb.exp) < 1) == True: + return Pow(glogb.base, glogb.exp*e) + return eq + + # the log(b) was a Mul so join any adds with logcombine + add = [] + other = [] + for a in glogb.args: + if a.is_Add: + add.append(a) + else: + other.append(a) + return Pow(exp(logcombine(Mul(*add))), e*Mul(*other)) diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/radsimp.py b/.venv/lib/python3.13/site-packages/sympy/simplify/radsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..c878168ebfbc29fc632577d6325befc120c26f56 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/radsimp.py @@ -0,0 +1,1234 @@ +from collections import defaultdict + +from sympy.core import sympify, S, Mul, Derivative, Pow +from sympy.core.add import _unevaluated_Add, Add +from sympy.core.assumptions import assumptions +from sympy.core.exprtools import Factors, gcd_terms +from sympy.core.function import _mexpand, expand_mul, expand_power_base +from sympy.core.mul import _keep_coeff, _unevaluated_Mul, _mulsort +from sympy.core.numbers import Rational, zoo, nan +from sympy.core.parameters import global_parameters +from sympy.core.sorting import ordered, default_sort_key +from sympy.core.symbol import Dummy, Wild, symbols +from sympy.functions import exp, sqrt, log +from sympy.functions.elementary.complexes import Abs +from sympy.polys import gcd +from sympy.simplify.sqrtdenest import sqrtdenest +from sympy.utilities.iterables import iterable, sift + + + + +def collect(expr, syms, func=None, evaluate=None, exact=False, distribute_order_term=True): + """ + Collect additive terms of an expression. + + Explanation + =========== + + This function collects additive terms of an expression with respect + to a list of expression up to powers with rational exponents. By the + term symbol here are meant arbitrary expressions, which can contain + powers, products, sums etc. In other words symbol is a pattern which + will be searched for in the expression's terms. + + The input expression is not expanded by :func:`collect`, so user is + expected to provide an expression in an appropriate form. This makes + :func:`collect` more predictable as there is no magic happening behind the + scenes. However, it is important to note, that powers of products are + converted to products of powers using the :func:`~.expand_power_base` + function. + + There are two possible types of output. First, if ``evaluate`` flag is + set, this function will return an expression with collected terms or + else it will return a dictionary with expressions up to rational powers + as keys and collected coefficients as values. + + Examples + ======== + + >>> from sympy import S, collect, expand, factor, Wild + >>> from sympy.abc import a, b, c, x, y + + This function can collect symbolic coefficients in polynomials or + rational expressions. It will manage to find all integer or rational + powers of collection variable:: + + >>> collect(a*x**2 + b*x**2 + a*x - b*x + c, x) + c + x**2*(a + b) + x*(a - b) + + The same result can be achieved in dictionary form:: + + >>> d = collect(a*x**2 + b*x**2 + a*x - b*x + c, x, evaluate=False) + >>> d[x**2] + a + b + >>> d[x] + a - b + >>> d[S.One] + c + + You can also work with multivariate polynomials. However, remember that + this function is greedy so it will care only about a single symbol at time, + in specification order:: + + >>> collect(x**2 + y*x**2 + x*y + y + a*y, [x, y]) + x**2*(y + 1) + x*y + y*(a + 1) + + Also more complicated expressions can be used as patterns:: + + >>> from sympy import sin, log + >>> collect(a*sin(2*x) + b*sin(2*x), sin(2*x)) + (a + b)*sin(2*x) + + >>> collect(a*x*log(x) + b*(x*log(x)), x*log(x)) + x*(a + b)*log(x) + + You can use wildcards in the pattern:: + + >>> w = Wild('w1') + >>> collect(a*x**y - b*x**y, w**y) + x**y*(a - b) + + It is also possible to work with symbolic powers, although it has more + complicated behavior, because in this case power's base and symbolic part + of the exponent are treated as a single symbol:: + + >>> collect(a*x**c + b*x**c, x) + a*x**c + b*x**c + >>> collect(a*x**c + b*x**c, x**c) + x**c*(a + b) + + However if you incorporate rationals to the exponents, then you will get + well known behavior:: + + >>> collect(a*x**(2*c) + b*x**(2*c), x**c) + x**(2*c)*(a + b) + + Note also that all previously stated facts about :func:`collect` function + apply to the exponential function, so you can get:: + + >>> from sympy import exp + >>> collect(a*exp(2*x) + b*exp(2*x), exp(x)) + (a + b)*exp(2*x) + + If you are interested only in collecting specific powers of some symbols + then set ``exact`` flag to True:: + + >>> collect(a*x**7 + b*x**7, x, exact=True) + a*x**7 + b*x**7 + >>> collect(a*x**7 + b*x**7, x**7, exact=True) + x**7*(a + b) + + If you want to collect on any object containing symbols, set + ``exact`` to None: + + >>> collect(x*exp(x) + sin(x)*y + sin(x)*2 + 3*x, x, exact=None) + x*exp(x) + 3*x + (y + 2)*sin(x) + >>> collect(a*x*y + x*y + b*x + x, [x, y], exact=None) + x*y*(a + 1) + x*(b + 1) + + You can also apply this function to differential equations, where + derivatives of arbitrary order can be collected. Note that if you + collect with respect to a function or a derivative of a function, all + derivatives of that function will also be collected. Use + ``exact=True`` to prevent this from happening:: + + >>> from sympy import Derivative as D, collect, Function + >>> f = Function('f') (x) + + >>> collect(a*D(f,x) + b*D(f,x), D(f,x)) + (a + b)*Derivative(f(x), x) + + >>> collect(a*D(D(f,x),x) + b*D(D(f,x),x), f) + (a + b)*Derivative(f(x), (x, 2)) + + >>> collect(a*D(D(f,x),x) + b*D(D(f,x),x), D(f,x), exact=True) + a*Derivative(f(x), (x, 2)) + b*Derivative(f(x), (x, 2)) + + >>> collect(a*D(f,x) + b*D(f,x) + a*f + b*f, f) + (a + b)*f(x) + (a + b)*Derivative(f(x), x) + + Or you can even match both derivative order and exponent at the same time:: + + >>> collect(a*D(D(f,x),x)**2 + b*D(D(f,x),x)**2, D(f,x)) + (a + b)*Derivative(f(x), (x, 2))**2 + + Finally, you can apply a function to each of the collected coefficients. + For example you can factorize symbolic coefficients of polynomial:: + + >>> f = expand((x + a + 1)**3) + + >>> collect(f, x, factor) + x**3 + 3*x**2*(a + 1) + 3*x*(a + 1)**2 + (a + 1)**3 + + .. note:: Arguments are expected to be in expanded form, so you might have + to call :func:`~.expand` prior to calling this function. + + See Also + ======== + + collect_const, collect_sqrt, rcollect + """ + expr = sympify(expr) + syms = [sympify(i) for i in (syms if iterable(syms) else [syms])] + + # replace syms[i] if it is not x, -x or has Wild symbols + cond = lambda x: x.is_Symbol or (-x).is_Symbol or bool( + x.atoms(Wild)) + _, nonsyms = sift(syms, cond, binary=True) + if nonsyms: + reps = dict(zip(nonsyms, [Dummy(**assumptions(i)) for i in nonsyms])) + syms = [reps.get(s, s) for s in syms] + rv = collect(expr.subs(reps), syms, + func=func, evaluate=evaluate, exact=exact, + distribute_order_term=distribute_order_term) + urep = {v: k for k, v in reps.items()} + if not isinstance(rv, dict): + return rv.xreplace(urep) + else: + return {urep.get(k, k).xreplace(urep): v.xreplace(urep) + for k, v in rv.items()} + + # see if other expressions should be considered + if exact is None: + _syms = set() + for i in Add.make_args(expr): + if not i.has_free(*syms) or i in syms: + continue + if not i.is_Mul and i not in syms: + _syms.add(i) + else: + # identify compound generators + g = i._new_rawargs(*i.as_coeff_mul(*syms)[1]) + if g not in syms: + _syms.add(g) + simple = all(i.is_Pow and i.base in syms for i in _syms) + syms = syms + list(ordered(_syms)) + if not simple: + return collect(expr, syms, + func=func, evaluate=evaluate, exact=False, + distribute_order_term=distribute_order_term) + + if evaluate is None: + evaluate = global_parameters.evaluate + + def make_expression(terms): + product = [] + + for term, rat, sym, deriv in terms: + if deriv is not None: + var, order = deriv + for _ in range(order): + term = Derivative(term, var) + + if sym is None: + if rat is S.One: + product.append(term) + else: + product.append(Pow(term, rat)) + else: + product.append(Pow(term, rat*sym)) + + return Mul(*product) + + def parse_derivative(deriv): + # scan derivatives tower in the input expression and return + # underlying function and maximal differentiation order + expr, sym, order = deriv.expr, deriv.variables[0], 1 + + for s in deriv.variables[1:]: + if s == sym: + order += 1 + else: + raise NotImplementedError( + 'Improve MV Derivative support in collect') + + while isinstance(expr, Derivative): + s0 = expr.variables[0] + + if any(s != s0 for s in expr.variables): + raise NotImplementedError( + 'Improve MV Derivative support in collect') + + if s0 == sym: + expr, order = expr.expr, order + len(expr.variables) + else: + break + + return expr, (sym, Rational(order)) + + def parse_term(expr): + """Parses expression expr and outputs tuple (sexpr, rat_expo, + sym_expo, deriv) + where: + - sexpr is the base expression + - rat_expo is the rational exponent that sexpr is raised to + - sym_expo is the symbolic exponent that sexpr is raised to + - deriv contains the derivatives of the expression + + For example, the output of x would be (x, 1, None, None) + the output of 2**x would be (2, 1, x, None). + """ + rat_expo, sym_expo = S.One, None + sexpr, deriv = expr, None + + if expr.is_Pow: + if isinstance(expr.base, Derivative): + sexpr, deriv = parse_derivative(expr.base) + else: + sexpr = expr.base + + if expr.base == S.Exp1: + arg = expr.exp + if arg.is_Rational: + sexpr, rat_expo = S.Exp1, arg + elif arg.is_Mul: + coeff, tail = arg.as_coeff_Mul(rational=True) + sexpr, rat_expo = exp(tail), coeff + + elif expr.exp.is_Number: + rat_expo = expr.exp + else: + coeff, tail = expr.exp.as_coeff_Mul() + + if coeff.is_Number: + rat_expo, sym_expo = coeff, tail + else: + sym_expo = expr.exp + elif isinstance(expr, exp): + arg = expr.exp + if arg.is_Rational: + sexpr, rat_expo = S.Exp1, arg + elif arg.is_Mul: + coeff, tail = arg.as_coeff_Mul(rational=True) + sexpr, rat_expo = exp(tail), coeff + elif isinstance(expr, Derivative): + sexpr, deriv = parse_derivative(expr) + + return sexpr, rat_expo, sym_expo, deriv + + def parse_expression(terms, pattern): + """Parse terms searching for a pattern. + Terms is a list of tuples as returned by parse_terms; + Pattern is an expression treated as a product of factors. + """ + pattern = Mul.make_args(pattern) + + if len(terms) < len(pattern): + # pattern is longer than matched product + # so no chance for positive parsing result + return None + else: + pattern = [parse_term(elem) for elem in pattern] + + terms = terms[:] # need a copy + elems, common_expo, has_deriv = [], None, False + + for elem, e_rat, e_sym, e_ord in pattern: + + if elem.is_Number and e_rat == 1 and e_sym is None: + # a constant is a match for everything + continue + + for j in range(len(terms)): + if terms[j] is None: + continue + + term, t_rat, t_sym, t_ord = terms[j] + + # keeping track of whether one of the terms had + # a derivative or not as this will require rebuilding + # the expression later + if t_ord is not None: + has_deriv = True + + if (term.match(elem) is not None and + (t_sym == e_sym or t_sym is not None and + e_sym is not None and + t_sym.match(e_sym) is not None)): + if exact is False: + # we don't have to be exact so find common exponent + # for both expression's term and pattern's element + expo = t_rat / e_rat + + if common_expo is None: + # first time + common_expo = expo + else: + # common exponent was negotiated before so + # there is no chance for a pattern match unless + # common and current exponents are equal + if common_expo != expo: + common_expo = 1 + else: + # we ought to be exact so all fields of + # interest must match in every details + if e_rat != t_rat or e_ord != t_ord: + continue + + # found common term so remove it from the expression + # and try to match next element in the pattern + elems.append(terms[j]) + terms[j] = None + + break + + else: + # pattern element not found + return None + + return [_f for _f in terms if _f], elems, common_expo, has_deriv + + if evaluate: + if expr.is_Add: + o = expr.getO() or 0 + expr = expr.func(*[ + collect(a, syms, func, True, exact, distribute_order_term) + for a in expr.args if a != o]) + o + elif expr.is_Mul: + return expr.func(*[ + collect(term, syms, func, True, exact, distribute_order_term) + for term in expr.args]) + elif expr.is_Pow: + b = collect( + expr.base, syms, func, True, exact, distribute_order_term) + return Pow(b, expr.exp) + + syms = [expand_power_base(i, deep=False) for i in syms] + + order_term = None + + if distribute_order_term: + order_term = expr.getO() + + if order_term is not None: + if order_term.has(*syms): + order_term = None + else: + expr = expr.removeO() + + summa = [expand_power_base(i, deep=False) for i in Add.make_args(expr)] + + collected, disliked = defaultdict(list), S.Zero + for product in summa: + c, nc = product.args_cnc(split_1=False) + args = list(ordered(c)) + nc + terms = [parse_term(i) for i in args] + small_first = True + + for symbol in syms: + if isinstance(symbol, Derivative) and small_first: + terms = list(reversed(terms)) + small_first = not small_first + result = parse_expression(terms, symbol) + + if result is not None: + if not symbol.is_commutative: + raise AttributeError("Can not collect noncommutative symbol") + + terms, elems, common_expo, has_deriv = result + + # when there was derivative in current pattern we + # will need to rebuild its expression from scratch + if not has_deriv: + margs = [] + for elem in elems: + if elem[2] is None: + e = elem[1] + else: + e = elem[1]*elem[2] + margs.append(Pow(elem[0], e)) + index = Mul(*margs) + else: + index = make_expression(elems) + terms = expand_power_base(make_expression(terms), deep=False) + index = expand_power_base(index, deep=False) + collected[index].append(terms) + break + else: + # none of the patterns matched + disliked += product + # add terms now for each key + collected = {k: Add(*v) for k, v in collected.items()} + + if disliked is not S.Zero: + collected[S.One] = disliked + + if order_term is not None: + for key, val in collected.items(): + collected[key] = val + order_term + + if func is not None: + collected = { + key: func(val) for key, val in collected.items()} + + if evaluate: + return Add(*[key*val for key, val in collected.items()]) + else: + return collected + + +def rcollect(expr, *vars): + """ + Recursively collect sums in an expression. + + Examples + ======== + + >>> from sympy.simplify import rcollect + >>> from sympy.abc import x, y + + >>> expr = (x**2*y + x*y + x + y)/(x + y) + + >>> rcollect(expr, y) + (x + y*(x**2 + x + 1))/(x + y) + + See Also + ======== + + collect, collect_const, collect_sqrt + """ + if expr.is_Atom or not expr.has(*vars): + return expr + else: + expr = expr.__class__(*[rcollect(arg, *vars) for arg in expr.args]) + + if expr.is_Add: + return collect(expr, vars) + else: + return expr + + +def collect_sqrt(expr, evaluate=None): + """Return expr with terms having common square roots collected together. + If ``evaluate`` is False a count indicating the number of sqrt-containing + terms will be returned and, if non-zero, the terms of the Add will be + returned, else the expression itself will be returned as a single term. + If ``evaluate`` is True, the expression with any collected terms will be + returned. + + Note: since I = sqrt(-1), it is collected, too. + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.simplify.radsimp import collect_sqrt + >>> from sympy.abc import a, b + + >>> r2, r3, r5 = [sqrt(i) for i in [2, 3, 5]] + >>> collect_sqrt(a*r2 + b*r2) + sqrt(2)*(a + b) + >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r3) + sqrt(2)*(a + b) + sqrt(3)*(a + b) + >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r5) + sqrt(3)*a + sqrt(5)*b + sqrt(2)*(a + b) + + If evaluate is False then the arguments will be sorted and + returned as a list and a count of the number of sqrt-containing + terms will be returned: + + >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r5, evaluate=False) + ((sqrt(3)*a, sqrt(5)*b, sqrt(2)*(a + b)), 3) + >>> collect_sqrt(a*sqrt(2) + b, evaluate=False) + ((b, sqrt(2)*a), 1) + >>> collect_sqrt(a + b, evaluate=False) + ((a + b,), 0) + + See Also + ======== + + collect, collect_const, rcollect + """ + if evaluate is None: + evaluate = global_parameters.evaluate + # this step will help to standardize any complex arguments + # of sqrts + coeff, expr = expr.as_content_primitive() + vars = set() + for a in Add.make_args(expr): + for m in a.args_cnc()[0]: + if m.is_number and ( + m.is_Pow and m.exp.is_Rational and m.exp.q == 2 or + m is S.ImaginaryUnit): + vars.add(m) + + # we only want radicals, so exclude Number handling; in this case + # d will be evaluated + d = collect_const(expr, *vars, Numbers=False) + hit = expr != d + + if not evaluate: + nrad = 0 + # make the evaluated args canonical + args = list(ordered(Add.make_args(d))) + for i, m in enumerate(args): + c, nc = m.args_cnc() + for ci in c: + # XXX should this be restricted to ci.is_number as above? + if ci.is_Pow and ci.exp.is_Rational and ci.exp.q == 2 or \ + ci is S.ImaginaryUnit: + nrad += 1 + break + args[i] *= coeff + if not (hit or nrad): + args = [Add(*args)] + return tuple(args), nrad + + return coeff*d + + +def collect_abs(expr): + """Return ``expr`` with arguments of multiple Abs in a term collected + under a single instance. + + Examples + ======== + + >>> from sympy.simplify.radsimp import collect_abs + >>> from sympy.abc import x + >>> collect_abs(abs(x + 1)/abs(x**2 - 1)) + Abs((x + 1)/(x**2 - 1)) + >>> collect_abs(abs(1/x)) + Abs(1/x) + """ + def _abs(mul): + c, nc = mul.args_cnc() + a = [] + o = [] + for i in c: + if isinstance(i, Abs): + a.append(i.args[0]) + elif isinstance(i, Pow) and isinstance(i.base, Abs) and i.exp.is_real: + a.append(i.base.args[0]**i.exp) + else: + o.append(i) + if len(a) < 2 and not any(i.exp.is_negative for i in a if isinstance(i, Pow)): + return mul + absarg = Mul(*a) + A = Abs(absarg) + args = [A] + args.extend(o) + if not A.has(Abs): + args.extend(nc) + return Mul(*args) + if not isinstance(A, Abs): + # reevaluate and make it unevaluated + A = Abs(absarg, evaluate=False) + args[0] = A + _mulsort(args) + args.extend(nc) # nc always go last + return Mul._from_args(args, is_commutative=not nc) + + return expr.replace( + lambda x: isinstance(x, Mul), + lambda x: _abs(x)).replace( + lambda x: isinstance(x, Pow), + lambda x: _abs(x)) + + +def collect_const(expr, *vars, Numbers=True): + """A non-greedy collection of terms with similar number coefficients in + an Add expr. If ``vars`` is given then only those constants will be + targeted. Although any Number can also be targeted, if this is not + desired set ``Numbers=False`` and no Float or Rational will be collected. + + Parameters + ========== + + expr : SymPy expression + This parameter defines the expression the expression from which + terms with similar coefficients are to be collected. A non-Add + expression is returned as it is. + + vars : variable length collection of Numbers, optional + Specifies the constants to target for collection. Can be multiple in + number. + + Numbers : bool + Specifies to target all instance of + :class:`sympy.core.numbers.Number` class. If ``Numbers=False``, then + no Float or Rational will be collected. + + Returns + ======= + + expr : Expr + Returns an expression with similar coefficient terms collected. + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.abc import s, x, y, z + >>> from sympy.simplify.radsimp import collect_const + >>> collect_const(sqrt(3) + sqrt(3)*(1 + sqrt(2))) + sqrt(3)*(sqrt(2) + 2) + >>> collect_const(sqrt(3)*s + sqrt(7)*s + sqrt(3) + sqrt(7)) + (sqrt(3) + sqrt(7))*(s + 1) + >>> s = sqrt(2) + 2 + >>> collect_const(sqrt(3)*s + sqrt(3) + sqrt(7)*s + sqrt(7)) + (sqrt(2) + 3)*(sqrt(3) + sqrt(7)) + >>> collect_const(sqrt(3)*s + sqrt(3) + sqrt(7)*s + sqrt(7), sqrt(3)) + sqrt(7) + sqrt(3)*(sqrt(2) + 3) + sqrt(7)*(sqrt(2) + 2) + + The collection is sign-sensitive, giving higher precedence to the + unsigned values: + + >>> collect_const(x - y - z) + x - (y + z) + >>> collect_const(-y - z) + -(y + z) + >>> collect_const(2*x - 2*y - 2*z, 2) + 2*(x - y - z) + >>> collect_const(2*x - 2*y - 2*z, -2) + 2*x - 2*(y + z) + + See Also + ======== + + collect, collect_sqrt, rcollect + """ + if not expr.is_Add: + return expr + + recurse = False + + if not vars: + recurse = True + vars = set() + for a in expr.args: + for m in Mul.make_args(a): + if m.is_number: + vars.add(m) + else: + vars = sympify(vars) + if not Numbers: + vars = [v for v in vars if not v.is_Number] + + vars = list(ordered(vars)) + for v in vars: + terms = defaultdict(list) + Fv = Factors(v) + for m in Add.make_args(expr): + f = Factors(m) + q, r = f.div(Fv) + if r.is_one: + # only accept this as a true factor if + # it didn't change an exponent from an Integer + # to a non-Integer, e.g. 2/sqrt(2) -> sqrt(2) + # -- we aren't looking for this sort of change + fwas = f.factors.copy() + fnow = q.factors + if not any(k in fwas and fwas[k].is_Integer and not + fnow[k].is_Integer for k in fnow): + terms[v].append(q.as_expr()) + continue + terms[S.One].append(m) + + args = [] + hit = False + uneval = False + for k in ordered(terms): + v = terms[k] + if k is S.One: + args.extend(v) + continue + + if len(v) > 1: + v = Add(*v) + hit = True + if recurse and v != expr: + vars.append(v) + else: + v = v[0] + + # be careful not to let uneval become True unless + # it must be because it's going to be more expensive + # to rebuild the expression as an unevaluated one + if Numbers and k.is_Number and v.is_Add: + args.append(_keep_coeff(k, v, sign=True)) + uneval = True + else: + args.append(k*v) + + if hit: + if uneval: + expr = _unevaluated_Add(*args) + else: + expr = Add(*args) + if not expr.is_Add: + break + + return expr + + +def radsimp(expr, symbolic=True, max_terms=4): + r""" + Rationalize the denominator by removing square roots. + + Explanation + =========== + + The expression returned from radsimp must be used with caution + since if the denominator contains symbols, it will be possible to make + substitutions that violate the assumptions of the simplification process: + that for a denominator matching a + b*sqrt(c), a != +/-b*sqrt(c). (If + there are no symbols, this assumptions is made valid by collecting terms + of sqrt(c) so the match variable ``a`` does not contain ``sqrt(c)``.) If + you do not want the simplification to occur for symbolic denominators, set + ``symbolic`` to False. + + If there are more than ``max_terms`` radical terms then the expression is + returned unchanged. + + Examples + ======== + + >>> from sympy import radsimp, sqrt, Symbol, pprint + >>> from sympy import factor_terms, fraction, signsimp + >>> from sympy.simplify.radsimp import collect_sqrt + >>> from sympy.abc import a, b, c + + >>> radsimp(1/(2 + sqrt(2))) + (2 - sqrt(2))/2 + >>> x,y = map(Symbol, 'xy') + >>> e = ((2 + 2*sqrt(2))*x + (2 + sqrt(8))*y)/(2 + sqrt(2)) + >>> radsimp(e) + sqrt(2)*(x + y) + + No simplification beyond removal of the gcd is done. One might + want to polish the result a little, however, by collecting + square root terms: + + >>> r2 = sqrt(2) + >>> r5 = sqrt(5) + >>> ans = radsimp(1/(y*r2 + x*r2 + a*r5 + b*r5)); pprint(ans) + ___ ___ ___ ___ + \/ 5 *a + \/ 5 *b - \/ 2 *x - \/ 2 *y + ------------------------------------------ + 2 2 2 2 + 5*a + 10*a*b + 5*b - 2*x - 4*x*y - 2*y + + >>> n, d = fraction(ans) + >>> pprint(factor_terms(signsimp(collect_sqrt(n))/d, radical=True)) + ___ ___ + \/ 5 *(a + b) - \/ 2 *(x + y) + ------------------------------------------ + 2 2 2 2 + 5*a + 10*a*b + 5*b - 2*x - 4*x*y - 2*y + + If radicals in the denominator cannot be removed or there is no denominator, + the original expression will be returned. + + >>> radsimp(sqrt(2)*x + sqrt(2)) + sqrt(2)*x + sqrt(2) + + Results with symbols will not always be valid for all substitutions: + + >>> eq = 1/(a + b*sqrt(c)) + >>> eq.subs(a, b*sqrt(c)) + 1/(2*b*sqrt(c)) + >>> radsimp(eq).subs(a, b*sqrt(c)) + nan + + If ``symbolic=False``, symbolic denominators will not be transformed (but + numeric denominators will still be processed): + + >>> radsimp(eq, symbolic=False) + 1/(a + b*sqrt(c)) + + """ + from sympy.core.expr import Expr + from sympy.simplify.simplify import signsimp + + syms = symbols("a:d A:D") + def _num(rterms): + # return the multiplier that will simplify the expression described + # by rterms [(sqrt arg, coeff), ... ] + a, b, c, d, A, B, C, D = syms + if len(rterms) == 2: + reps = dict(list(zip([A, a, B, b], [j for i in rterms for j in i]))) + return ( + sqrt(A)*a - sqrt(B)*b).xreplace(reps) + if len(rterms) == 3: + reps = dict(list(zip([A, a, B, b, C, c], [j for i in rterms for j in i]))) + return ( + (sqrt(A)*a + sqrt(B)*b - sqrt(C)*c)*(2*sqrt(A)*sqrt(B)*a*b - A*a**2 - + B*b**2 + C*c**2)).xreplace(reps) + elif len(rterms) == 4: + reps = dict(list(zip([A, a, B, b, C, c, D, d], [j for i in rterms for j in i]))) + return ((sqrt(A)*a + sqrt(B)*b - sqrt(C)*c - sqrt(D)*d)*(2*sqrt(A)*sqrt(B)*a*b + - A*a**2 - B*b**2 - 2*sqrt(C)*sqrt(D)*c*d + C*c**2 + + D*d**2)*(-8*sqrt(A)*sqrt(B)*sqrt(C)*sqrt(D)*a*b*c*d + A**2*a**4 - + 2*A*B*a**2*b**2 - 2*A*C*a**2*c**2 - 2*A*D*a**2*d**2 + B**2*b**4 - + 2*B*C*b**2*c**2 - 2*B*D*b**2*d**2 + C**2*c**4 - 2*C*D*c**2*d**2 + + D**2*d**4)).xreplace(reps) + elif len(rterms) == 1: + return sqrt(rterms[0][0]) + else: + raise NotImplementedError + + def ispow2(d, log2=False): + if not d.is_Pow: + return False + e = d.exp + if e.is_Rational and e.q == 2 or symbolic and denom(e) == 2: + return True + if log2: + q = 1 + if e.is_Rational: + q = e.q + elif symbolic: + d = denom(e) + if d.is_Integer: + q = d + if q != 1 and log(q, 2).is_Integer: + return True + return False + + def handle(expr): + # Handle first reduces to the case + # expr = 1/d, where d is an add, or d is base**p/2. + # We do this by recursively calling handle on each piece. + from sympy.simplify.simplify import nsimplify + + if expr.is_Atom: + return expr + elif not isinstance(expr, Expr): + return expr.func(*[handle(a) for a in expr.args]) + + n, d = fraction(expr) + + if d.is_Atom and n.is_Atom: + return expr + elif not n.is_Atom: + n = n.func(*[handle(a) for a in n.args]) + return _unevaluated_Mul(n, handle(1/d)) + elif n is not S.One: + return _unevaluated_Mul(n, handle(1/d)) + elif d.is_Mul: + return _unevaluated_Mul(*[handle(1/d) for d in d.args]) + + # By this step, expr is 1/d, and d is not a mul. + if not symbolic and d.free_symbols: + return expr + + if ispow2(d): + d2 = sqrtdenest(sqrt(d.base))**numer(d.exp) + if d2 != d: + return handle(1/d2) + elif d.is_Pow and (d.exp.is_integer or d.base.is_positive): + # (1/d**i) = (1/d)**i + return handle(1/d.base)**d.exp + + if not (d.is_Add or ispow2(d)): + return 1/d.func(*[handle(a) for a in d.args]) + + # handle 1/d treating d as an Add (though it may not be) + + keep = True # keep changes that are made + + # flatten it and collect radicals after checking for special + # conditions + d = _mexpand(d) + + # did it change? + if d.is_Atom: + return 1/d + + # is it a number that might be handled easily? + if d.is_number: + _d = nsimplify(d) + if _d.is_Number and _d.equals(d): + return 1/_d + + while True: + # collect similar terms + collected = defaultdict(list) + for m in Add.make_args(d): # d might have become non-Add + p2 = [] + other = [] + for i in Mul.make_args(m): + if ispow2(i, log2=True): + p2.append(i.base if i.exp is S.Half else i.base**(2*i.exp)) + elif i is S.ImaginaryUnit: + p2.append(S.NegativeOne) + else: + other.append(i) + collected[tuple(ordered(p2))].append(Mul(*other)) + rterms = list(ordered(list(collected.items()))) + rterms = [(Mul(*i), Add(*j)) for i, j in rterms] + nrad = len(rterms) - (1 if rterms[0][0] is S.One else 0) + if nrad < 1: + break + elif nrad > max_terms: + # there may have been invalid operations leading to this point + # so don't keep changes, e.g. this expression is troublesome + # in collecting terms so as not to raise the issue of 2834: + # r = sqrt(sqrt(5) + 5) + # eq = 1/(sqrt(5)*r + 2*sqrt(5)*sqrt(-sqrt(5) + 5) + 5*r) + keep = False + break + if len(rterms) > 4: + # in general, only 4 terms can be removed with repeated squaring + # but other considerations can guide selection of radical terms + # so that radicals are removed + if all(x.is_Integer and (y**2).is_Rational for x, y in rterms): + nd, d = rad_rationalize(S.One, Add._from_args( + [sqrt(x)*y for x, y in rterms])) + n *= nd + else: + # is there anything else that might be attempted? + keep = False + break + from sympy.simplify.powsimp import powsimp, powdenest + + num = powsimp(_num(rterms)) + n *= num + d *= num + d = powdenest(_mexpand(d), force=symbolic) + if d.has(S.Zero, nan, zoo): + return expr + if d.is_Atom: + break + + if not keep: + return expr + return _unevaluated_Mul(n, 1/d) + + if not isinstance(expr, Expr): + return expr.func(*[radsimp(a, symbolic=symbolic, max_terms=max_terms) for a in expr.args]) + + coeff, expr = expr.as_coeff_Add() + expr = expr.normal() + old = fraction(expr) + n, d = fraction(handle(expr)) + if old != (n, d): + if not d.is_Atom: + was = (n, d) + n = signsimp(n, evaluate=False) + d = signsimp(d, evaluate=False) + u = Factors(_unevaluated_Mul(n, 1/d)) + u = _unevaluated_Mul(*[k**v for k, v in u.factors.items()]) + n, d = fraction(u) + if old == (n, d): + n, d = was + n = expand_mul(n) + if d.is_Number or d.is_Add: + n2, d2 = fraction(gcd_terms(_unevaluated_Mul(n, 1/d))) + if d2.is_Number or (d2.count_ops() <= d.count_ops()): + n, d = [signsimp(i) for i in (n2, d2)] + if n.is_Mul and n.args[0].is_Number: + n = n.func(*n.args) + + return coeff + _unevaluated_Mul(n, 1/d) + + +def rad_rationalize(num, den): + """ + Rationalize ``num/den`` by removing square roots in the denominator; + num and den are sum of terms whose squares are positive rationals. + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.simplify.radsimp import rad_rationalize + >>> rad_rationalize(sqrt(3), 1 + sqrt(2)/3) + (-sqrt(3) + sqrt(6)/3, -7/9) + """ + if not den.is_Add: + return num, den + g, a, b = split_surds(den) + a = a*sqrt(g) + num = _mexpand((a - b)*num) + den = _mexpand(a**2 - b**2) + return rad_rationalize(num, den) + + +def fraction(expr, exact=False): + """Returns a pair with expression's numerator and denominator. + If the given expression is not a fraction then this function + will return the tuple (expr, 1). + + This function will not make any attempt to simplify nested + fractions or to do any term rewriting at all. + + If only one of the numerator/denominator pair is needed then + use numer(expr) or denom(expr) functions respectively. + + >>> from sympy import fraction, Rational, Symbol + >>> from sympy.abc import x, y + + >>> fraction(x/y) + (x, y) + >>> fraction(x) + (x, 1) + + >>> fraction(1/y**2) + (1, y**2) + + >>> fraction(x*y/2) + (x*y, 2) + >>> fraction(Rational(1, 2)) + (1, 2) + + This function will also work fine with assumptions: + + >>> k = Symbol('k', negative=True) + >>> fraction(x * y**k) + (x, y**(-k)) + + If we know nothing about sign of some exponent and ``exact`` + flag is unset, then the exponent's structure will + be analyzed and pretty fraction will be returned: + + >>> from sympy import exp, Mul + >>> fraction(2*x**(-y)) + (2, x**y) + + >>> fraction(exp(-x)) + (1, exp(x)) + + >>> fraction(exp(-x), exact=True) + (exp(-x), 1) + + The ``exact`` flag will also keep any unevaluated Muls from + being evaluated: + + >>> u = Mul(2, x + 1, evaluate=False) + >>> fraction(u) + (2*x + 2, 1) + >>> fraction(u, exact=True) + (2*(x + 1), 1) + """ + expr = sympify(expr) + + numer, denom = [], [] + + for term in Mul.make_args(expr): + if term.is_commutative and (term.is_Pow or isinstance(term, exp)): + b, ex = term.as_base_exp() + if ex.is_negative: + if ex is S.NegativeOne: + denom.append(b) + elif exact: + if ex.is_constant(): + denom.append(Pow(b, -ex)) + else: + numer.append(term) + else: + denom.append(Pow(b, -ex)) + elif ex.is_positive: + numer.append(term) + elif not exact and ex.is_Mul: + n, d = term.as_numer_denom() # this will cause evaluation + if n != 1: + numer.append(n) + denom.append(d) + else: + numer.append(term) + elif term.is_Rational and not term.is_Integer: + if term.p != 1: + numer.append(term.p) + denom.append(term.q) + else: + numer.append(term) + return Mul(*numer, evaluate=not exact), Mul(*denom, evaluate=not exact) + + +def numer(expr, exact=False): # default matches fraction's default + return fraction(expr, exact=exact)[0] + + +def denom(expr, exact=False): # default matches fraction's default + return fraction(expr, exact=exact)[1] + + +def fraction_expand(expr, **hints): + return expr.expand(frac=True, **hints) + + +def numer_expand(expr, **hints): + # default matches fraction's default + a, b = fraction(expr, exact=hints.get('exact', False)) + return a.expand(numer=True, **hints) / b + + +def denom_expand(expr, **hints): + # default matches fraction's default + a, b = fraction(expr, exact=hints.get('exact', False)) + return a / b.expand(denom=True, **hints) + + +expand_numer = numer_expand +expand_denom = denom_expand +expand_fraction = fraction_expand + + +def split_surds(expr): + """ + Split an expression with terms whose squares are positive rationals + into a sum of terms whose surds squared have gcd equal to g + and a sum of terms with surds squared prime with g. + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.simplify.radsimp import split_surds + >>> split_surds(3*sqrt(3) + sqrt(5)/7 + sqrt(6) + sqrt(10) + sqrt(15)) + (3, sqrt(2) + sqrt(5) + 3, sqrt(5)/7 + sqrt(10)) + """ + args = sorted(expr.args, key=default_sort_key) + coeff_muls = [x.as_coeff_Mul() for x in args] + surds = [x[1]**2 for x in coeff_muls if x[1].is_Pow] + surds.sort(key=default_sort_key) + g, b1, b2 = _split_gcd(*surds) + g2 = g + if not b2 and len(b1) >= 2: + b1n = [x/g for x in b1] + b1n = [x for x in b1n if x != 1] + # only a common factor has been factored; split again + g1, b1n, b2 = _split_gcd(*b1n) + g2 = g*g1 + a1v, a2v = [], [] + for c, s in coeff_muls: + if s.is_Pow and s.exp == S.Half: + s1 = s.base + if s1 in b1: + a1v.append(c*sqrt(s1/g2)) + else: + a2v.append(c*s) + else: + a2v.append(c*s) + a = Add(*a1v) + b = Add(*a2v) + return g2, a, b + + +def _split_gcd(*a): + """ + Split the list of integers ``a`` into a list of integers, ``a1`` having + ``g = gcd(a1)``, and a list ``a2`` whose elements are not divisible by + ``g``. Returns ``g, a1, a2``. + + Examples + ======== + + >>> from sympy.simplify.radsimp import _split_gcd + >>> _split_gcd(55, 35, 22, 14, 77, 10) + (5, [55, 35, 10], [22, 14, 77]) + """ + g = a[0] + b1 = [g] + b2 = [] + for x in a[1:]: + g1 = gcd(g, x) + if g1 == 1: + b2.append(x) + else: + g = g1 + b1.append(x) + return g, b1, b2 diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/ratsimp.py b/.venv/lib/python3.13/site-packages/sympy/simplify/ratsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..95751fab47f585d3ae2e1289f014fba0f2708224 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/ratsimp.py @@ -0,0 +1,222 @@ +from itertools import combinations_with_replacement +from sympy.core import symbols, Add, Dummy +from sympy.core.numbers import Rational +from sympy.polys import cancel, ComputationFailed, parallel_poly_from_expr, reduced, Poly +from sympy.polys.monomials import Monomial, monomial_div +from sympy.polys.polyerrors import DomainError, PolificationFailed +from sympy.utilities.misc import debug, debugf + +def ratsimp(expr): + """ + Put an expression over a common denominator, cancel and reduce. + + Examples + ======== + + >>> from sympy import ratsimp + >>> from sympy.abc import x, y + >>> ratsimp(1/x + 1/y) + (x + y)/(x*y) + """ + + f, g = cancel(expr).as_numer_denom() + try: + Q, r = reduced(f, [g], field=True, expand=False) + except ComputationFailed: + return f/g + + return Add(*Q) + cancel(r/g) + + +def ratsimpmodprime(expr, G, *gens, quick=True, polynomial=False, **args): + """ + Simplifies a rational expression ``expr`` modulo the prime ideal + generated by ``G``. ``G`` should be a Groebner basis of the + ideal. + + Examples + ======== + + >>> from sympy.simplify.ratsimp import ratsimpmodprime + >>> from sympy.abc import x, y + >>> eq = (x + y**5 + y)/(x - y) + >>> ratsimpmodprime(eq, [x*y**5 - x - y], x, y, order='lex') + (-x**2 - x*y - x - y)/(-x**2 + x*y) + + If ``polynomial`` is ``False``, the algorithm computes a rational + simplification which minimizes the sum of the total degrees of + the numerator and the denominator. + + If ``polynomial`` is ``True``, this function just brings numerator and + denominator into a canonical form. This is much faster, but has + potentially worse results. + + References + ========== + + .. [1] M. Monagan, R. Pearce, Rational Simplification Modulo a Polynomial + Ideal, https://dl.acm.org/doi/pdf/10.1145/1145768.1145809 + (specifically, the second algorithm) + """ + from sympy.solvers.solvers import solve + + debug('ratsimpmodprime', expr) + + # usual preparation of polynomials: + + num, denom = cancel(expr).as_numer_denom() + + try: + polys, opt = parallel_poly_from_expr([num, denom] + G, *gens, **args) + except PolificationFailed: + return expr + + domain = opt.domain + + if domain.has_assoc_Field: + opt.domain = domain.get_field() + else: + raise DomainError( + "Cannot compute rational simplification over %s" % domain) + + # compute only once + leading_monomials = [g.LM(opt.order) for g in polys[2:]] + tested = set() + + def staircase(n): + """ + Compute all monomials with degree less than ``n`` that are + not divisible by any element of ``leading_monomials``. + """ + if n == 0: + return [1] + S = [] + for mi in combinations_with_replacement(range(len(opt.gens)), n): + m = [0]*len(opt.gens) + for i in mi: + m[i] += 1 + if all(monomial_div(m, lmg) is None for lmg in + leading_monomials): + S.append(m) + + return [Monomial(s).as_expr(*opt.gens) for s in S] + staircase(n - 1) + + def _ratsimpmodprime(a, b, allsol, N=0, D=0): + r""" + Computes a rational simplification of ``a/b`` which minimizes + the sum of the total degrees of the numerator and the denominator. + + Explanation + =========== + + The algorithm proceeds by looking at ``a * d - b * c`` modulo + the ideal generated by ``G`` for some ``c`` and ``d`` with degree + less than ``a`` and ``b`` respectively. + The coefficients of ``c`` and ``d`` are indeterminates and thus + the coefficients of the normalform of ``a * d - b * c`` are + linear polynomials in these indeterminates. + If these linear polynomials, considered as system of + equations, have a nontrivial solution, then `\frac{a}{b} + \equiv \frac{c}{d}` modulo the ideal generated by ``G``. So, + by construction, the degree of ``c`` and ``d`` is less than + the degree of ``a`` and ``b``, so a simpler representation + has been found. + After a simpler representation has been found, the algorithm + tries to reduce the degree of the numerator and denominator + and returns the result afterwards. + + As an extension, if quick=False, we look at all possible degrees such + that the total degree is less than *or equal to* the best current + solution. We retain a list of all solutions of minimal degree, and try + to find the best one at the end. + """ + c, d = a, b + steps = 0 + + maxdeg = a.total_degree() + b.total_degree() + if quick: + bound = maxdeg - 1 + else: + bound = maxdeg + while N + D <= bound: + if (N, D) in tested: + break + tested.add((N, D)) + + M1 = staircase(N) + M2 = staircase(D) + debugf('%s / %s: %s, %s', (N, D, M1, M2)) + + Cs = symbols("c:%d" % len(M1), cls=Dummy) + Ds = symbols("d:%d" % len(M2), cls=Dummy) + ng = Cs + Ds + + c_hat = Poly( + sum(Cs[i] * M1[i] for i in range(len(M1))), opt.gens + ng) + d_hat = Poly( + sum(Ds[i] * M2[i] for i in range(len(M2))), opt.gens + ng) + + r = reduced(a * d_hat - b * c_hat, G, opt.gens + ng, + order=opt.order, polys=True)[1] + + S = Poly(r, gens=opt.gens).coeffs() + sol = solve(S, Cs + Ds, particular=True, quick=True) + + if sol and not all(s == 0 for s in sol.values()): + c = c_hat.subs(sol) + d = d_hat.subs(sol) + + # The "free" variables occurring before as parameters + # might still be in the substituted c, d, so set them + # to the value chosen before: + c = c.subs(dict(list(zip(Cs + Ds, [1] * (len(Cs) + len(Ds)))))) + d = d.subs(dict(list(zip(Cs + Ds, [1] * (len(Cs) + len(Ds)))))) + + c = Poly(c, opt.gens) + d = Poly(d, opt.gens) + if d == 0: + raise ValueError('Ideal not prime?') + + allsol.append((c_hat, d_hat, S, Cs + Ds)) + if N + D != maxdeg: + allsol = [allsol[-1]] + + break + + steps += 1 + N += 1 + D += 1 + + if steps > 0: + c, d, allsol = _ratsimpmodprime(c, d, allsol, N, D - steps) + c, d, allsol = _ratsimpmodprime(c, d, allsol, N - steps, D) + + return c, d, allsol + + # preprocessing. this improves performance a bit when deg(num) + # and deg(denom) are large: + num = reduced(num, G, opt.gens, order=opt.order)[1] + denom = reduced(denom, G, opt.gens, order=opt.order)[1] + + if polynomial: + return (num/denom).cancel() + + c, d, allsol = _ratsimpmodprime( + Poly(num, opt.gens, domain=opt.domain), Poly(denom, opt.gens, domain=opt.domain), []) + if not quick and allsol: + debugf('Looking for best minimal solution. Got: %s', len(allsol)) + newsol = [] + for c_hat, d_hat, S, ng in allsol: + sol = solve(S, ng, particular=True, quick=False) + # all values of sol should be numbers; if not, solve is broken + newsol.append((c_hat.subs(sol), d_hat.subs(sol))) + c, d = min(newsol, key=lambda x: len(x[0].terms()) + len(x[1].terms())) + + if not domain.is_Field: + cn, c = c.clear_denoms(convert=True) + dn, d = d.clear_denoms(convert=True) + r = Rational(cn, dn) + else: + r = Rational(1) + + return (c*r.q)/(d*r.p) diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/simplify.py b/.venv/lib/python3.13/site-packages/sympy/simplify/simplify.py new file mode 100644 index 0000000000000000000000000000000000000000..8b315cc20c19fc10c37b903d16129a7f5579ecd3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/simplify.py @@ -0,0 +1,2164 @@ +from __future__ import annotations + +from typing import overload + +from collections import defaultdict + +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core import (Basic, S, Add, Mul, Pow, Symbol, sympify, + expand_func, Function, Dummy, Expr, factor_terms, + expand_power_exp, Eq) +from sympy.core.exprtools import factor_nc +from sympy.core.parameters import global_parameters +from sympy.core.function import (expand_log, count_ops, _mexpand, + nfloat, expand_mul, expand) +from sympy.core.numbers import Float, I, pi, Rational, equal_valued +from sympy.core.relational import Relational +from sympy.core.rules import Transform +from sympy.core.sorting import ordered +from sympy.core.sympify import _sympify +from sympy.core.traversal import bottom_up as _bottom_up, walk as _walk +from sympy.functions import gamma, exp, sqrt, log, exp_polar, re +from sympy.functions.combinatorial.factorials import CombinatorialFunction +from sympy.functions.elementary.complexes import unpolarify, Abs, sign +from sympy.functions.elementary.exponential import ExpBase +from sympy.functions.elementary.hyperbolic import HyperbolicFunction +from sympy.functions.elementary.integers import ceiling +from sympy.functions.elementary.piecewise import (Piecewise, piecewise_fold, + piecewise_simplify) +from sympy.functions.elementary.trigonometric import TrigonometricFunction +from sympy.functions.special.bessel import (BesselBase, besselj, besseli, + besselk, bessely, jn) +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import Boolean +from sympy.matrices.expressions import (MatrixExpr, MatAdd, MatMul, + MatPow, MatrixSymbol) +from sympy.polys import together, cancel, factor +from sympy.polys.numberfields.minpoly import _is_sum_surds, _minimal_polynomial_sq +from sympy.sets.sets import Set +from sympy.simplify.combsimp import combsimp +from sympy.simplify.cse_opts import sub_pre, sub_post +from sympy.simplify.hyperexpand import hyperexpand +from sympy.simplify.powsimp import powsimp +from sympy.simplify.radsimp import radsimp, fraction, collect_abs +from sympy.simplify.sqrtdenest import sqrtdenest +from sympy.simplify.trigsimp import trigsimp, exptrigsimp +from sympy.utilities.decorator import deprecated +from sympy.utilities.iterables import has_variety, sift, subsets, iterable +from sympy.utilities.misc import as_int + +import mpmath + + +def separatevars(expr, symbols=[], dict=False, force=False): + """ + Separates variables in an expression, if possible. By + default, it separates with respect to all symbols in an + expression and collects constant coefficients that are + independent of symbols. + + Explanation + =========== + + If ``dict=True`` then the separated terms will be returned + in a dictionary keyed to their corresponding symbols. + By default, all symbols in the expression will appear as + keys; if symbols are provided, then all those symbols will + be used as keys, and any terms in the expression containing + other symbols or non-symbols will be returned keyed to the + string 'coeff'. (Passing None for symbols will return the + expression in a dictionary keyed to 'coeff'.) + + If ``force=True``, then bases of powers will be separated regardless + of assumptions on the symbols involved. + + Notes + ===== + + The order of the factors is determined by Mul, so that the + separated expressions may not necessarily be grouped together. + + Although factoring is necessary to separate variables in some + expressions, it is not necessary in all cases, so one should not + count on the returned factors being factored. + + Examples + ======== + + >>> from sympy.abc import x, y, z, alpha + >>> from sympy import separatevars, sin + >>> separatevars((x*y)**y) + (x*y)**y + >>> separatevars((x*y)**y, force=True) + x**y*y**y + + >>> e = 2*x**2*z*sin(y)+2*z*x**2 + >>> separatevars(e) + 2*x**2*z*(sin(y) + 1) + >>> separatevars(e, symbols=(x, y), dict=True) + {'coeff': 2*z, x: x**2, y: sin(y) + 1} + >>> separatevars(e, [x, y, alpha], dict=True) + {'coeff': 2*z, alpha: 1, x: x**2, y: sin(y) + 1} + + If the expression is not really separable, or is only partially + separable, separatevars will do the best it can to separate it + by using factoring. + + >>> separatevars(x + x*y - 3*x**2) + -x*(3*x - y - 1) + + If the expression is not separable then expr is returned unchanged + or (if dict=True) then None is returned. + + >>> eq = 2*x + y*sin(x) + >>> separatevars(eq) == eq + True + >>> separatevars(2*x + y*sin(x), symbols=(x, y), dict=True) is None + True + + """ + expr = sympify(expr) + if dict: + return _separatevars_dict(_separatevars(expr, force), symbols) + else: + return _separatevars(expr, force) + + +def _separatevars(expr, force): + if isinstance(expr, Abs): + arg = expr.args[0] + if arg.is_Mul and not arg.is_number: + s = separatevars(arg, dict=True, force=force) + if s is not None: + return Mul(*map(expr.func, s.values())) + else: + return expr + + if len(expr.free_symbols) < 2: + return expr + + # don't destroy a Mul since much of the work may already be done + if expr.is_Mul: + args = list(expr.args) + changed = False + for i, a in enumerate(args): + args[i] = separatevars(a, force) + changed = changed or args[i] != a + if changed: + expr = expr.func(*args) + return expr + + # get a Pow ready for expansion + if expr.is_Pow and expr.base != S.Exp1: + expr = Pow(separatevars(expr.base, force=force), expr.exp) + + # First try other expansion methods + expr = expr.expand(mul=False, multinomial=False, force=force) + + _expr, reps = posify(expr) if force else (expr, {}) + expr = factor(_expr).subs(reps) + + if not expr.is_Add: + return expr + + # Find any common coefficients to pull out + args = list(expr.args) + commonc = args[0].args_cnc(cset=True, warn=False)[0] + for i in args[1:]: + commonc &= i.args_cnc(cset=True, warn=False)[0] + commonc = Mul(*commonc) + commonc = commonc.as_coeff_Mul()[1] # ignore constants + commonc_set = commonc.args_cnc(cset=True, warn=False)[0] + + # remove them + for i, a in enumerate(args): + c, nc = a.args_cnc(cset=True, warn=False) + c = c - commonc_set + args[i] = Mul(*c)*Mul(*nc) + nonsepar = Add(*args) + + if len(nonsepar.free_symbols) > 1: + _expr = nonsepar + _expr, reps = posify(_expr) if force else (_expr, {}) + _expr = (factor(_expr)).subs(reps) + + if not _expr.is_Add: + nonsepar = _expr + + return commonc*nonsepar + + +def _separatevars_dict(expr, symbols): + if symbols: + if not all(t.is_Atom for t in symbols): + raise ValueError("symbols must be Atoms.") + symbols = list(symbols) + elif symbols is None: + return {'coeff': expr} + else: + symbols = list(expr.free_symbols) + if not symbols: + return None + + ret = {i: [] for i in symbols + ['coeff']} + + for i in Mul.make_args(expr): + expsym = i.free_symbols + intersection = set(symbols).intersection(expsym) + if len(intersection) > 1: + return None + if len(intersection) == 0: + # There are no symbols, so it is part of the coefficient + ret['coeff'].append(i) + else: + ret[intersection.pop()].append(i) + + # rebuild + for k, v in ret.items(): + ret[k] = Mul(*v) + + return ret + + +def posify(eq): + """Return ``eq`` (with generic symbols made positive) and a + dictionary containing the mapping between the old and new + symbols. + + Explanation + =========== + + Any symbol that has positive=None will be replaced with a positive dummy + symbol having the same name. This replacement will allow more symbolic + processing of expressions, especially those involving powers and + logarithms. + + A dictionary that can be sent to subs to restore ``eq`` to its original + symbols is also returned. + + >>> from sympy import posify, Symbol, log, solve + >>> from sympy.abc import x + >>> posify(x + Symbol('p', positive=True) + Symbol('n', negative=True)) + (_x + n + p, {_x: x}) + + >>> eq = 1/x + >>> log(eq).expand() + log(1/x) + >>> log(posify(eq)[0]).expand() + -log(_x) + >>> p, rep = posify(eq) + >>> log(p).expand().subs(rep) + -log(x) + + It is possible to apply the same transformations to an iterable + of expressions: + + >>> eq = x**2 - 4 + >>> solve(eq, x) + [-2, 2] + >>> eq_x, reps = posify([eq, x]); eq_x + [_x**2 - 4, _x] + >>> solve(*eq_x) + [2] + """ + eq = sympify(eq) + if not isinstance(eq, Basic) and iterable(eq): + f = type(eq) + eq = list(eq) + syms = set() + for e in eq: + syms = syms.union(e.atoms(Symbol)) + reps = {} + for s in syms: + reps.update({v: k for k, v in posify(s)[1].items()}) + for i, e in enumerate(eq): + eq[i] = e.subs(reps) + return f(eq), {r: s for s, r in reps.items()} + + reps = {s: Dummy(s.name, positive=True, **s.assumptions0) + for s in eq.free_symbols if s.is_positive is None} + eq = eq.subs(reps) + return eq, {r: s for s, r in reps.items()} + + +def hypersimp(f, k): + """Given combinatorial term f(k) simplify its consecutive term ratio + i.e. f(k+1)/f(k). The input term can be composed of functions and + integer sequences which have equivalent representation in terms + of gamma special function. + + Explanation + =========== + + The algorithm performs three basic steps: + + 1. Rewrite all functions in terms of gamma, if possible. + + 2. Rewrite all occurrences of gamma in terms of products + of gamma and rising factorial with integer, absolute + constant exponent. + + 3. Perform simplification of nested fractions, powers + and if the resulting expression is a quotient of + polynomials, reduce their total degree. + + If f(k) is hypergeometric then as result we arrive with a + quotient of polynomials of minimal degree. Otherwise None + is returned. + + For more information on the implemented algorithm refer to: + + 1. W. Koepf, Algorithms for m-fold Hypergeometric Summation, + Journal of Symbolic Computation (1995) 20, 399-417 + """ + f = sympify(f) + + g = f.subs(k, k + 1) / f + + g = g.rewrite(gamma) + if g.has(Piecewise): + g = piecewise_fold(g) + g = g.args[-1][0] + g = expand_func(g) + g = powsimp(g, deep=True, combine='exp') + + if g.is_rational_function(k): + return simplify(g, ratio=S.Infinity) + else: + return None + + +def hypersimilar(f, g, k): + """ + Returns True if ``f`` and ``g`` are hyper-similar. + + Explanation + =========== + + Similarity in hypergeometric sense means that a quotient of + f(k) and g(k) is a rational function in ``k``. This procedure + is useful in solving recurrence relations. + + For more information see hypersimp(). + + """ + f, g = list(map(sympify, (f, g))) + + h = (f/g).rewrite(gamma) + h = h.expand(func=True, basic=False) + + return h.is_rational_function(k) + + +def signsimp(expr, evaluate=None): + """Make all Add sub-expressions canonical wrt sign. + + Explanation + =========== + + If an Add subexpression, ``a``, can have a sign extracted, + as determined by could_extract_minus_sign, it is replaced + with Mul(-1, a, evaluate=False). This allows signs to be + extracted from powers and products. + + Examples + ======== + + >>> from sympy import signsimp, exp, symbols + >>> from sympy.abc import x, y + >>> i = symbols('i', odd=True) + >>> n = -1 + 1/x + >>> n/x/(-n)**2 - 1/n/x + (-1 + 1/x)/(x*(1 - 1/x)**2) - 1/(x*(-1 + 1/x)) + >>> signsimp(_) + 0 + >>> x*n + x*-n + x*(-1 + 1/x) + x*(1 - 1/x) + >>> signsimp(_) + 0 + + Since powers automatically handle leading signs + + >>> (-2)**i + -2**i + + signsimp can be used to put the base of a power with an integer + exponent into canonical form: + + >>> n**i + (-1 + 1/x)**i + + By default, signsimp does not leave behind any hollow simplification: + if making an Add canonical wrt sign didn't change the expression, the + original Add is restored. If this is not desired then the keyword + ``evaluate`` can be set to False: + + >>> e = exp(y - x) + >>> signsimp(e) == e + True + >>> signsimp(e, evaluate=False) + exp(-(x - y)) + + """ + if evaluate is None: + evaluate = global_parameters.evaluate + expr = sympify(expr) + if not isinstance(expr, (Expr, Relational)) or expr.is_Atom: + return expr + # get rid of an pre-existing unevaluation regarding sign + e = expr.replace(lambda x: x.is_Mul and -(-x) != x, lambda x: -(-x)) + e = sub_post(sub_pre(e)) + if not isinstance(e, (Expr, Relational)) or e.is_Atom: + return e + if e.is_Add: + rv = e.func(*[signsimp(a) for a in e.args]) + if not evaluate and isinstance(rv, Add + ) and rv.could_extract_minus_sign(): + return Mul(S.NegativeOne, -rv, evaluate=False) + return rv + if evaluate: + e = e.replace(lambda x: x.is_Mul and -(-x) != x, lambda x: -(-x)) + return e + + +@overload +def simplify(expr: Expr, **kwargs) -> Expr: ... +@overload +def simplify(expr: Boolean, **kwargs) -> Boolean: ... +@overload +def simplify(expr: Set, **kwargs) -> Set: ... +@overload +def simplify(expr: Basic, **kwargs) -> Basic: ... + +def simplify(expr, ratio=1.7, measure=count_ops, rational=False, inverse=False, doit=True, **kwargs): + """Simplifies the given expression. + + Explanation + =========== + + Simplification is not a well defined term and the exact strategies + this function tries can change in the future versions of SymPy. If + your algorithm relies on "simplification" (whatever it is), try to + determine what you need exactly - is it powsimp()?, radsimp()?, + together()?, logcombine()?, or something else? And use this particular + function directly, because those are well defined and thus your algorithm + will be robust. + + Nonetheless, especially for interactive use, or when you do not know + anything about the structure of the expression, simplify() tries to apply + intelligent heuristics to make the input expression "simpler". For + example: + + >>> from sympy import simplify, cos, sin + >>> from sympy.abc import x, y + >>> a = (x + x**2)/(x*sin(y)**2 + x*cos(y)**2) + >>> a + (x**2 + x)/(x*sin(y)**2 + x*cos(y)**2) + >>> simplify(a) + x + 1 + + Note that we could have obtained the same result by using specific + simplification functions: + + >>> from sympy import trigsimp, cancel + >>> trigsimp(a) + (x**2 + x)/x + >>> cancel(_) + x + 1 + + In some cases, applying :func:`simplify` may actually result in some more + complicated expression. The default ``ratio=1.7`` prevents more extreme + cases: if (result length)/(input length) > ratio, then input is returned + unmodified. The ``measure`` parameter lets you specify the function used + to determine how complex an expression is. The function should take a + single argument as an expression and return a number such that if + expression ``a`` is more complex than expression ``b``, then + ``measure(a) > measure(b)``. The default measure function is + :func:`~.count_ops`, which returns the total number of operations in the + expression. + + For example, if ``ratio=1``, ``simplify`` output cannot be longer + than input. + + :: + + >>> from sympy import sqrt, simplify, count_ops, oo + >>> root = 1/(sqrt(2)+3) + + Since ``simplify(root)`` would result in a slightly longer expression, + root is returned unchanged instead:: + + >>> simplify(root, ratio=1) == root + True + + If ``ratio=oo``, simplify will be applied anyway:: + + >>> count_ops(simplify(root, ratio=oo)) > count_ops(root) + True + + Note that the shortest expression is not necessary the simplest, so + setting ``ratio`` to 1 may not be a good idea. + Heuristically, the default value ``ratio=1.7`` seems like a reasonable + choice. + + You can easily define your own measure function based on what you feel + should represent the "size" or "complexity" of the input expression. Note + that some choices, such as ``lambda expr: len(str(expr))`` may appear to be + good metrics, but have other problems (in this case, the measure function + may slow down simplify too much for very large expressions). If you do not + know what a good metric would be, the default, ``count_ops``, is a good + one. + + For example: + + >>> from sympy import symbols, log + >>> a, b = symbols('a b', positive=True) + >>> g = log(a) + log(b) + log(a)*log(1/b) + >>> h = simplify(g) + >>> h + log(a*b**(1 - log(a))) + >>> count_ops(g) + 8 + >>> count_ops(h) + 5 + + So you can see that ``h`` is simpler than ``g`` using the count_ops metric. + However, we may not like how ``simplify`` (in this case, using + ``logcombine``) has created the ``b**(log(1/a) + 1)`` term. A simple way + to reduce this would be to give more weight to powers as operations in + ``count_ops``. We can do this by using the ``visual=True`` option: + + >>> print(count_ops(g, visual=True)) + 2*ADD + DIV + 4*LOG + MUL + >>> print(count_ops(h, visual=True)) + 2*LOG + MUL + POW + SUB + + >>> from sympy import Symbol, S + >>> def my_measure(expr): + ... POW = Symbol('POW') + ... # Discourage powers by giving POW a weight of 10 + ... count = count_ops(expr, visual=True).subs(POW, 10) + ... # Every other operation gets a weight of 1 (the default) + ... count = count.replace(Symbol, type(S.One)) + ... return count + >>> my_measure(g) + 8 + >>> my_measure(h) + 14 + >>> 15./8 > 1.7 # 1.7 is the default ratio + True + >>> simplify(g, measure=my_measure) + -log(a)*log(b) + log(a) + log(b) + + Note that because ``simplify()`` internally tries many different + simplification strategies and then compares them using the measure + function, we get a completely different result that is still different + from the input expression by doing this. + + If ``rational=True``, Floats will be recast as Rationals before simplification. + If ``rational=None``, Floats will be recast as Rationals but the result will + be recast as Floats. If rational=False(default) then nothing will be done + to the Floats. + + If ``inverse=True``, it will be assumed that a composition of inverse + functions, such as sin and asin, can be cancelled in any order. + For example, ``asin(sin(x))`` will yield ``x`` without checking whether + x belongs to the set where this relation is true. The default is + False. + + Note that ``simplify()`` automatically calls ``doit()`` on the final + expression. You can avoid this behavior by passing ``doit=False`` as + an argument. + + Also, it should be noted that simplifying a boolean expression is not + well defined. If the expression prefers automatic evaluation (such as + :obj:`~.Eq()` or :obj:`~.Or()`), simplification will return ``True`` or + ``False`` if truth value can be determined. If the expression is not + evaluated by default (such as :obj:`~.Predicate()`), simplification will + not reduce it and you should use :func:`~.refine` or :func:`~.ask` + function. This inconsistency will be resolved in future version. + + See Also + ======== + + sympy.assumptions.refine.refine : Simplification using assumptions. + sympy.assumptions.ask.ask : Query for boolean expressions using assumptions. + """ + + def shorter(*choices): + """ + Return the choice that has the fewest ops. In case of a tie, + the expression listed first is selected. + """ + if not has_variety(choices): + return choices[0] + return min(choices, key=measure) + + def done(e): + rv = e.doit() if doit else e + return shorter(rv, collect_abs(rv)) + + expr = sympify(expr, rational=rational) + kwargs = { + "ratio": kwargs.get('ratio', ratio), + "measure": kwargs.get('measure', measure), + "rational": kwargs.get('rational', rational), + "inverse": kwargs.get('inverse', inverse), + "doit": kwargs.get('doit', doit)} + # no routine for Expr needs to check for is_zero + if isinstance(expr, Expr) and expr.is_zero: + return S.Zero if not expr.is_Number else expr + + _eval_simplify = getattr(expr, '_eval_simplify', None) + if _eval_simplify is not None: + return _eval_simplify(**kwargs) + + original_expr = expr = collect_abs(signsimp(expr)) + + if not isinstance(expr, Basic) or not expr.args: # XXX: temporary hack + return expr + + if inverse and expr.has(Function): + expr = inversecombine(expr) + if not expr.args: # simplified to atomic + return expr + + # do deep simplification + handled = Add, Mul, Pow, ExpBase + expr = expr.replace( + # here, checking for x.args is not enough because Basic has + # args but Basic does not always play well with replace, e.g. + # when simultaneous is True found expressions will be masked + # off with a Dummy but not all Basic objects in an expression + # can be replaced with a Dummy + lambda x: isinstance(x, Expr) and x.args and not isinstance( + x, handled), + lambda x: x.func(*[simplify(i, **kwargs) for i in x.args]), + simultaneous=False) + if not isinstance(expr, handled): + return done(expr) + + if not expr.is_commutative: + expr = nc_simplify(expr) + + # TODO: Apply different strategies, considering expression pattern: + # is it a purely rational function? Is there any trigonometric function?... + # See also https://github.com/sympy/sympy/pull/185. + + # rationalize Floats + floats = False + if rational is not False and expr.has(Float): + floats = True + expr = nsimplify(expr, rational=True) + + expr = _bottom_up(expr, lambda w: getattr(w, 'normal', lambda: w)()) + expr = Mul(*powsimp(expr).as_content_primitive()) + _e = cancel(expr) + expr1 = shorter(_e, _mexpand(_e).cancel()) # issue 6829 + expr2 = shorter(together(expr, deep=True), together(expr1, deep=True)) + + if ratio is S.Infinity: + expr = expr2 + else: + expr = shorter(expr2, expr1, expr) + if not isinstance(expr, Basic): # XXX: temporary hack + return expr + + expr = factor_terms(expr, sign=False) + + # must come before `Piecewise` since this introduces more `Piecewise` terms + if expr.has(sign): + expr = expr.rewrite(Abs) + + # Deal with Piecewise separately to avoid recursive growth of expressions + if expr.has(Piecewise): + # Fold into a single Piecewise + expr = piecewise_fold(expr) + # Apply doit, if doit=True + expr = done(expr) + # Still a Piecewise? + if expr.has(Piecewise): + # Fold into a single Piecewise, in case doit lead to some + # expressions being Piecewise + expr = piecewise_fold(expr) + # kroneckersimp also affects Piecewise + if expr.has(KroneckerDelta): + expr = kroneckersimp(expr) + # Still a Piecewise? + if expr.has(Piecewise): + # Do not apply doit on the segments as it has already + # been done above, but simplify + expr = piecewise_simplify(expr, deep=True, doit=False) + # Still a Piecewise? + if expr.has(Piecewise): + # Try factor common terms + expr = shorter(expr, factor_terms(expr)) + # As all expressions have been simplified above with the + # complete simplify, nothing more needs to be done here + return expr + + # hyperexpand automatically only works on hypergeometric terms + # Do this after the Piecewise part to avoid recursive expansion + expr = hyperexpand(expr) + + if expr.has(KroneckerDelta): + expr = kroneckersimp(expr) + + if expr.has(BesselBase): + expr = besselsimp(expr) + + if expr.has(TrigonometricFunction, HyperbolicFunction): + expr = trigsimp(expr, deep=True) + + if expr.has(log): + expr = shorter(expand_log(expr, deep=True), logcombine(expr)) + + if expr.has(CombinatorialFunction, gamma): + # expression with gamma functions or non-integer arguments is + # automatically passed to gammasimp + expr = combsimp(expr) + + if expr.has(Sum): + expr = sum_simplify(expr, **kwargs) + + if expr.has(Integral): + expr = expr.xreplace({ + i: factor_terms(i) for i in expr.atoms(Integral)}) + + if expr.has(Product): + expr = product_simplify(expr, **kwargs) + + from sympy.physics.units import Quantity + + if expr.has(Quantity): + from sympy.physics.units.util import quantity_simplify + expr = quantity_simplify(expr) + + short = shorter(powsimp(expr, combine='exp', deep=True), powsimp(expr), expr) + short = shorter(short, cancel(short)) + short = shorter(short, factor_terms(short), expand_power_exp(expand_mul(short))) + if short.has(TrigonometricFunction, HyperbolicFunction, ExpBase, exp): + short = exptrigsimp(short) + + # get rid of hollow 2-arg Mul factorization + hollow_mul = Transform( + lambda x: Mul(*x.args), + lambda x: + x.is_Mul and + len(x.args) == 2 and + x.args[0].is_Number and + x.args[1].is_Add and + x.is_commutative) + expr = short.xreplace(hollow_mul) + + numer, denom = expr.as_numer_denom() + if denom.is_Add: + n, d = fraction(radsimp(1/denom, symbolic=False, max_terms=1)) + if n is not S.One: + expr = (numer*n).expand()/d + + if expr.could_extract_minus_sign(): + n, d = fraction(expr) + if d != 0: + expr = signsimp(-n/(-d)) + + if measure(expr) > ratio*measure(original_expr): + expr = original_expr + + # restore floats + if floats and rational is None: + expr = nfloat(expr, exponent=False) + + return done(expr) + + +def sum_simplify(s, **kwargs): + """Main function for Sum simplification""" + if not isinstance(s, Add): + s = s.xreplace({a: sum_simplify(a, **kwargs) + for a in s.atoms(Add) if a.has(Sum)}) + s = expand(s) + if not isinstance(s, Add): + return s + + terms = s.args + s_t = [] # Sum Terms + o_t = [] # Other Terms + + for term in terms: + sum_terms, other = sift(Mul.make_args(term), + lambda i: isinstance(i, Sum), binary=True) + if not sum_terms: + o_t.append(term) + continue + other = [Mul(*other)] + s_t.append(Mul(*(other + [s._eval_simplify(**kwargs) for s in sum_terms]))) + + result = Add(sum_combine(s_t), *o_t) + + return result + + +def sum_combine(s_t): + """Helper function for Sum simplification + + Attempts to simplify a list of sums, by combining limits / sum function's + returns the simplified sum + """ + used = [False] * len(s_t) + + for method in range(2): + for i, s_term1 in enumerate(s_t): + if not used[i]: + for j, s_term2 in enumerate(s_t): + if not used[j] and i != j: + temp = sum_add(s_term1, s_term2, method) + if isinstance(temp, (Sum, Mul)): + s_t[i] = temp + s_term1 = s_t[i] + used[j] = True + + result = S.Zero + for i, s_term in enumerate(s_t): + if not used[i]: + result = Add(result, s_term) + + return result + +def factor_sum(self, limits=None, radical=False, clear=False, fraction=False, sign=True): + """Return Sum with constant factors extracted. + + If ``limits`` is specified then ``self`` is the summand; the other + keywords are passed to ``factor_terms``. + + Examples + ======== + + >>> from sympy import Sum + >>> from sympy.abc import x, y + >>> from sympy.simplify.simplify import factor_sum + >>> s = Sum(x*y, (x, 1, 3)) + >>> factor_sum(s) + y*Sum(x, (x, 1, 3)) + >>> factor_sum(s.function, s.limits) + y*Sum(x, (x, 1, 3)) + """ + + # XXX deprecate in favor of direct call to factor_terms + kwargs = {"radical": radical, "clear": clear, + "fraction": fraction, "sign": sign} + expr = Sum(self, *limits) if limits else self + return factor_terms(expr, **kwargs) + + +def sum_add(self, other, method=0): + """Helper function for Sum simplification""" + #we know this is something in terms of a constant * a sum + #so we temporarily put the constants inside for simplification + #then simplify the result + def __refactor(val): + args = Mul.make_args(val) + sumv = next(x for x in args if isinstance(x, Sum)) + constant = Mul(*[x for x in args if x != sumv]) + return Sum(constant * sumv.function, *sumv.limits) + + if isinstance(self, Mul): + rself = __refactor(self) + else: + rself = self + + if isinstance(other, Mul): + rother = __refactor(other) + else: + rother = other + + if type(rself) is type(rother): + if method == 0: + if rself.limits == rother.limits: + return factor_sum(Sum(rself.function + rother.function, *rself.limits)) + elif method == 1: + if simplify(rself.function - rother.function) == 0: + if len(rself.limits) == len(rother.limits) == 1: + i = rself.limits[0][0] + x1 = rself.limits[0][1] + y1 = rself.limits[0][2] + j = rother.limits[0][0] + x2 = rother.limits[0][1] + y2 = rother.limits[0][2] + + if i == j: + if x2 == y1 + 1: + return factor_sum(Sum(rself.function, (i, x1, y2))) + elif x1 == y2 + 1: + return factor_sum(Sum(rself.function, (i, x2, y1))) + + return Add(self, other) + + +def product_simplify(s, **kwargs): + """Main function for Product simplification""" + terms = Mul.make_args(s) + p_t = [] # Product Terms + o_t = [] # Other Terms + + deep = kwargs.get('deep', True) + for term in terms: + if isinstance(term, Product): + if deep: + p_t.append(Product(term.function.simplify(**kwargs), + *term.limits)) + else: + p_t.append(term) + else: + o_t.append(term) + + used = [False] * len(p_t) + + for method in range(2): + for i, p_term1 in enumerate(p_t): + if not used[i]: + for j, p_term2 in enumerate(p_t): + if not used[j] and i != j: + tmp_prod = product_mul(p_term1, p_term2, method) + if isinstance(tmp_prod, Product): + p_t[i] = tmp_prod + used[j] = True + + result = Mul(*o_t) + + for i, p_term in enumerate(p_t): + if not used[i]: + result = Mul(result, p_term) + + return result + + +def product_mul(self, other, method=0): + """Helper function for Product simplification""" + if type(self) is type(other): + if method == 0: + if self.limits == other.limits: + return Product(self.function * other.function, *self.limits) + elif method == 1: + if simplify(self.function - other.function) == 0: + if len(self.limits) == len(other.limits) == 1: + i = self.limits[0][0] + x1 = self.limits[0][1] + y1 = self.limits[0][2] + j = other.limits[0][0] + x2 = other.limits[0][1] + y2 = other.limits[0][2] + + if i == j: + if x2 == y1 + 1: + return Product(self.function, (i, x1, y2)) + elif x1 == y2 + 1: + return Product(self.function, (i, x2, y1)) + + return Mul(self, other) + + +def _nthroot_solve(p, n, prec): + """ + helper function for ``nthroot`` + It denests ``p**Rational(1, n)`` using its minimal polynomial + """ + from sympy.solvers import solve + while n % 2 == 0: + p = sqrtdenest(sqrt(p)) + n = n // 2 + if n == 1: + return p + pn = p**Rational(1, n) + x = Symbol('x') + f = _minimal_polynomial_sq(p, n, x) + if f is None: + return None + sols = solve(f, x) + for sol in sols: + if abs(sol - pn).n() < 1./10**prec: + sol = sqrtdenest(sol) + if _mexpand(sol**n) == p: + return sol + + +def logcombine(expr, force=False): + """ + Takes logarithms and combines them using the following rules: + + - log(x) + log(y) == log(x*y) if both are positive + - a*log(x) == log(x**a) if x is positive and a is real + + If ``force`` is ``True`` then the assumptions above will be assumed to hold if + there is no assumption already in place on a quantity. For example, if + ``a`` is imaginary or the argument negative, force will not perform a + combination but if ``a`` is a symbol with no assumptions the change will + take place. + + Examples + ======== + + >>> from sympy import Symbol, symbols, log, logcombine, I + >>> from sympy.abc import a, x, y, z + >>> logcombine(a*log(x) + log(y) - log(z)) + a*log(x) + log(y) - log(z) + >>> logcombine(a*log(x) + log(y) - log(z), force=True) + log(x**a*y/z) + >>> x,y,z = symbols('x,y,z', positive=True) + >>> a = Symbol('a', real=True) + >>> logcombine(a*log(x) + log(y) - log(z)) + log(x**a*y/z) + + The transformation is limited to factors and/or terms that + contain logs, so the result depends on the initial state of + expansion: + + >>> eq = (2 + 3*I)*log(x) + >>> logcombine(eq, force=True) == eq + True + >>> logcombine(eq.expand(), force=True) + log(x**2) + I*log(x**3) + + See Also + ======== + + posify: replace all symbols with symbols having positive assumptions + sympy.core.function.expand_log: expand the logarithms of products + and powers; the opposite of logcombine + + """ + + def f(rv): + if not (rv.is_Add or rv.is_Mul): + return rv + + def gooda(a): + # bool to tell whether the leading ``a`` in ``a*log(x)`` + # could appear as log(x**a) + return (a is not S.NegativeOne and # -1 *could* go, but we disallow + (a.is_extended_real or force and a.is_extended_real is not False)) + + def goodlog(l): + # bool to tell whether log ``l``'s argument can combine with others + a = l.args[0] + return a.is_positive or force and a.is_nonpositive is not False + + other = [] + logs = [] + log1 = defaultdict(list) + for a in Add.make_args(rv): + if isinstance(a, log) and goodlog(a): + log1[()].append(([], a)) + elif not a.is_Mul: + other.append(a) + else: + ot = [] + co = [] + lo = [] + for ai in a.args: + if ai.is_Rational and ai < 0: + ot.append(S.NegativeOne) + co.append(-ai) + elif isinstance(ai, log) and goodlog(ai): + lo.append(ai) + elif gooda(ai): + co.append(ai) + else: + ot.append(ai) + if len(lo) > 1: + logs.append((ot, co, lo)) + elif lo: + log1[tuple(ot)].append((co, lo[0])) + else: + other.append(a) + + # if there is only one log in other, put it with the + # good logs + if len(other) == 1 and isinstance(other[0], log): + log1[()].append(([], other.pop())) + # if there is only one log at each coefficient and none have + # an exponent to place inside the log then there is nothing to do + if not logs and all(len(log1[k]) == 1 and log1[k][0] == [] for k in log1): + return rv + + # collapse multi-logs as far as possible in a canonical way + # TODO: see if x*log(a)+x*log(a)*log(b) -> x*log(a)*(1+log(b))? + # -- in this case, it's unambiguous, but if it were were a log(c) in + # each term then it's arbitrary whether they are grouped by log(a) or + # by log(c). So for now, just leave this alone; it's probably better to + # let the user decide + for o, e, l in logs: + l = list(ordered(l)) + e = log(l.pop(0).args[0]**Mul(*e)) + while l: + li = l.pop(0) + e = log(li.args[0]**e) + c, l = Mul(*o), e + if isinstance(l, log): # it should be, but check to be sure + log1[(c,)].append(([], l)) + else: + other.append(c*l) + + # logs that have the same coefficient can multiply + for k in list(log1.keys()): + log1[Mul(*k)] = log(logcombine(Mul(*[ + l.args[0]**Mul(*c) for c, l in log1.pop(k)]), + force=force), evaluate=False) + + # logs that have oppositely signed coefficients can divide + for k in ordered(list(log1.keys())): + if k not in log1: # already popped as -k + continue + if -k in log1: + # figure out which has the minus sign; the one with + # more op counts should be the one + num, den = k, -k + if num.count_ops() > den.count_ops(): + num, den = den, num + other.append( + num*log(log1.pop(num).args[0]/log1.pop(den).args[0], + evaluate=False)) + else: + other.append(k*log1.pop(k)) + + return Add(*other) + + return _bottom_up(expr, f) + + +def inversecombine(expr): + """Simplify the composition of a function and its inverse. + + Explanation + =========== + + No attention is paid to whether the inverse is a left inverse or a + right inverse; thus, the result will in general not be equivalent + to the original expression. + + Examples + ======== + + >>> from sympy.simplify.simplify import inversecombine + >>> from sympy import asin, sin, log, exp + >>> from sympy.abc import x + >>> inversecombine(asin(sin(x))) + x + >>> inversecombine(2*log(exp(3*x))) + 6*x + """ + + def f(rv): + if isinstance(rv, log): + if isinstance(rv.args[0], exp) or (rv.args[0].is_Pow and rv.args[0].base == S.Exp1): + rv = rv.args[0].exp + elif rv.is_Function and hasattr(rv, "inverse"): + if (len(rv.args) == 1 and len(rv.args[0].args) == 1 and + isinstance(rv.args[0], rv.inverse(argindex=1))): + rv = rv.args[0].args[0] + if rv.is_Pow and rv.base == S.Exp1: + if isinstance(rv.exp, log): + rv = rv.exp.args[0] + return rv + + return _bottom_up(expr, f) + + +def kroneckersimp(expr): + """ + Simplify expressions with KroneckerDelta. + + The only simplification currently attempted is to identify multiplicative cancellation: + + Examples + ======== + + >>> from sympy import KroneckerDelta, kroneckersimp + >>> from sympy.abc import i + >>> kroneckersimp(1 + KroneckerDelta(0, i) * KroneckerDelta(1, i)) + 1 + """ + def args_cancel(args1, args2): + for i1 in range(2): + for i2 in range(2): + a1 = args1[i1] + a2 = args2[i2] + a3 = args1[(i1 + 1) % 2] + a4 = args2[(i2 + 1) % 2] + if Eq(a1, a2) is S.true and Eq(a3, a4) is S.false: + return True + return False + + def cancel_kronecker_mul(m): + args = m.args + deltas = [a for a in args if isinstance(a, KroneckerDelta)] + for delta1, delta2 in subsets(deltas, 2): + args1 = delta1.args + args2 = delta2.args + if args_cancel(args1, args2): + return S.Zero * m # In case of oo etc + return m + + if not expr.has(KroneckerDelta): + return expr + + if expr.has(Piecewise): + expr = expr.rewrite(KroneckerDelta) + + newexpr = expr + expr = None + + while newexpr != expr: + expr = newexpr + newexpr = expr.replace(lambda e: isinstance(e, Mul), cancel_kronecker_mul) + + return expr + + +def besselsimp(expr): + """ + Simplify bessel-type functions. + + Explanation + =========== + + This routine tries to simplify bessel-type functions. Currently it only + works on the Bessel J and I functions, however. It works by looking at all + such functions in turn, and eliminating factors of "I" and "-1" (actually + their polar equivalents) in front of the argument. Then, functions of + half-integer order are rewritten using trigonometric functions and + functions of integer order (> 1) are rewritten using functions + of low order. Finally, if the expression was changed, compute + factorization of the result with factor(). + + >>> from sympy import besselj, besseli, besselsimp, polar_lift, I, S + >>> from sympy.abc import z, nu + >>> besselsimp(besselj(nu, z*polar_lift(-1))) + exp(I*pi*nu)*besselj(nu, z) + >>> besselsimp(besseli(nu, z*polar_lift(-I))) + exp(-I*pi*nu/2)*besselj(nu, z) + >>> besselsimp(besseli(S(-1)/2, z)) + sqrt(2)*cosh(z)/(sqrt(pi)*sqrt(z)) + >>> besselsimp(z*besseli(0, z) + z*(besseli(2, z))/2 + besseli(1, z)) + 3*z*besseli(0, z)/2 + """ + # TODO + # - better algorithm? + # - simplify (cos(pi*b)*besselj(b,z) - besselj(-b,z))/sin(pi*b) ... + # - use contiguity relations? + + def replacer(fro, to, factors): + factors = set(factors) + + def repl(nu, z): + if factors.intersection(Mul.make_args(z)): + return to(nu, z) + return fro(nu, z) + return repl + + def torewrite(fro, to): + def tofunc(nu, z): + return fro(nu, z).rewrite(to) + return tofunc + + def tominus(fro): + def tofunc(nu, z): + return exp(I*pi*nu)*fro(nu, exp_polar(-I*pi)*z) + return tofunc + + orig_expr = expr + + ifactors = [I, exp_polar(I*pi/2), exp_polar(-I*pi/2)] + expr = expr.replace( + besselj, replacer(besselj, + torewrite(besselj, besseli), ifactors)) + expr = expr.replace( + besseli, replacer(besseli, + torewrite(besseli, besselj), ifactors)) + + minusfactors = [-1, exp_polar(I*pi)] + expr = expr.replace( + besselj, replacer(besselj, tominus(besselj), minusfactors)) + expr = expr.replace( + besseli, replacer(besseli, tominus(besseli), minusfactors)) + + z0 = Dummy('z') + + def expander(fro): + def repl(nu, z): + if (nu % 1) == S.Half: + return simplify(trigsimp(unpolarify( + fro(nu, z0).rewrite(besselj).rewrite(jn).expand( + func=True)).subs(z0, z))) + elif nu.is_Integer and nu > 1: + return fro(nu, z).expand(func=True) + return fro(nu, z) + return repl + + expr = expr.replace(besselj, expander(besselj)) + expr = expr.replace(bessely, expander(bessely)) + expr = expr.replace(besseli, expander(besseli)) + expr = expr.replace(besselk, expander(besselk)) + + def _bessel_simp_recursion(expr): + + def _use_recursion(bessel, expr): + while True: + bessels = expr.find(lambda x: isinstance(x, bessel)) + try: + for ba in sorted(bessels, key=lambda x: re(x.args[0])): + a, x = ba.args + bap1 = bessel(a+1, x) + bap2 = bessel(a+2, x) + if expr.has(bap1) and expr.has(bap2): + expr = expr.subs(ba, 2*(a+1)/x*bap1 - bap2) + break + else: + return expr + except (ValueError, TypeError): + return expr + if expr.has(besselj): + expr = _use_recursion(besselj, expr) + if expr.has(bessely): + expr = _use_recursion(bessely, expr) + return expr + + expr = _bessel_simp_recursion(expr) + if expr != orig_expr: + expr = expr.factor() + + return expr + + +def nthroot(expr, n, max_len=4, prec=15): + """ + Compute a real nth-root of a sum of surds. + + Parameters + ========== + + expr : sum of surds + n : integer + max_len : maximum number of surds passed as constants to ``nsimplify`` + + Algorithm + ========= + + First ``nsimplify`` is used to get a candidate root; if it is not a + root the minimal polynomial is computed; the answer is one of its + roots. + + Examples + ======== + + >>> from sympy.simplify.simplify import nthroot + >>> from sympy import sqrt + >>> nthroot(90 + 34*sqrt(7), 3) + sqrt(7) + 3 + + """ + expr = sympify(expr) + n = sympify(n) + p = expr**Rational(1, n) + if not n.is_integer: + return p + if not _is_sum_surds(expr): + return p + surds = [] + coeff_muls = [x.as_coeff_Mul() for x in expr.args] + for x, y in coeff_muls: + if not x.is_rational: + return p + if y is S.One: + continue + if not (y.is_Pow and y.exp == S.Half and y.base.is_integer): + return p + surds.append(y) + surds.sort() + surds = surds[:max_len] + if expr < 0 and n % 2 == 1: + p = (-expr)**Rational(1, n) + a = nsimplify(p, constants=surds) + res = a if _mexpand(a**n) == _mexpand(-expr) else p + return -res + a = nsimplify(p, constants=surds) + if _mexpand(a) is not _mexpand(p) and _mexpand(a**n) == _mexpand(expr): + return _mexpand(a) + expr = _nthroot_solve(expr, n, prec) + if expr is None: + return p + return expr + + +def nsimplify(expr, constants=(), tolerance=None, full=False, rational=None, + rational_conversion='base10'): + """ + Find a simple representation for a number or, if there are free symbols or + if ``rational=True``, then replace Floats with their Rational equivalents. If + no change is made and rational is not False then Floats will at least be + converted to Rationals. + + Explanation + =========== + + For numerical expressions, a simple formula that numerically matches the + given numerical expression is sought (and the input should be possible + to evalf to a precision of at least 30 digits). + + Optionally, a list of (rationally independent) constants to + include in the formula may be given. + + A lower tolerance may be set to find less exact matches. If no tolerance + is given then the least precise value will set the tolerance (e.g. Floats + default to 15 digits of precision, so would be tolerance=10**-15). + + With ``full=True``, a more extensive search is performed + (this is useful to find simpler numbers when the tolerance + is set low). + + When converting to rational, if rational_conversion='base10' (the default), then + convert floats to rationals using their base-10 (string) representation. + When rational_conversion='exact' it uses the exact, base-2 representation. + + Examples + ======== + + >>> from sympy import nsimplify, sqrt, GoldenRatio, exp, I, pi + >>> nsimplify(4/(1+sqrt(5)), [GoldenRatio]) + -2 + 2*GoldenRatio + >>> nsimplify((1/(exp(3*pi*I/5)+1))) + 1/2 - I*sqrt(sqrt(5)/10 + 1/4) + >>> nsimplify(I**I, [pi]) + exp(-pi/2) + >>> nsimplify(pi, tolerance=0.01) + 22/7 + + >>> nsimplify(0.333333333333333, rational=True, rational_conversion='exact') + 6004799503160655/18014398509481984 + >>> nsimplify(0.333333333333333, rational=True) + 1/3 + + See Also + ======== + + sympy.core.function.nfloat + + """ + try: + return sympify(as_int(expr)) + except (TypeError, ValueError): + pass + expr = sympify(expr).xreplace({ + Float('inf'): S.Infinity, + Float('-inf'): S.NegativeInfinity, + }) + if expr is S.Infinity or expr is S.NegativeInfinity: + return expr + if rational or expr.free_symbols: + return _real_to_rational(expr, tolerance, rational_conversion) + + # SymPy's default tolerance for Rationals is 15; other numbers may have + # lower tolerances set, so use them to pick the largest tolerance if None + # was given + if tolerance is None: + tolerance = 10**-min([15] + + [mpmath.libmp.libmpf.prec_to_dps(n._prec) + for n in expr.atoms(Float)]) + # XXX should prec be set independent of tolerance or should it be computed + # from tolerance? + prec = 30 + bprec = int(prec*3.33) + + constants_dict = {} + for constant in constants: + constant = sympify(constant) + v = constant.evalf(prec) + if not v.is_Float: + raise ValueError("constants must be real-valued") + constants_dict[str(constant)] = v._to_mpmath(bprec) + + exprval = expr.evalf(prec, chop=True) + re, im = exprval.as_real_imag() + + # safety check to make sure that this evaluated to a number + if not (re.is_Number and im.is_Number): + return expr + + def nsimplify_real(x): + orig = mpmath.mp.dps + xv = x._to_mpmath(bprec) + try: + # We'll be happy with low precision if a simple fraction + if not (tolerance or full): + mpmath.mp.dps = 15 + rat = mpmath.pslq([xv, 1]) + if rat is not None: + return Rational(-int(rat[1]), int(rat[0])) + mpmath.mp.dps = prec + newexpr = mpmath.identify(xv, constants=constants_dict, + tol=tolerance, full=full) + if not newexpr: + raise ValueError + if full: + newexpr = newexpr[0] + expr = sympify(newexpr) + if x and not expr: # don't let x become 0 + raise ValueError + if expr.is_finite is False and xv not in [mpmath.inf, mpmath.ninf]: + raise ValueError + return expr + finally: + # even though there are returns above, this is executed + # before leaving + mpmath.mp.dps = orig + try: + if re: + re = nsimplify_real(re) + if im: + im = nsimplify_real(im) + except ValueError: + if rational is None: + return _real_to_rational(expr, rational_conversion=rational_conversion) + return expr + + rv = re + im*S.ImaginaryUnit + # if there was a change or rational is explicitly not wanted + # return the value, else return the Rational representation + if rv != expr or rational is False: + return rv + return _real_to_rational(expr, rational_conversion=rational_conversion) + + +def _real_to_rational(expr, tolerance=None, rational_conversion='base10'): + """ + Replace all reals in expr with rationals. + + Examples + ======== + + >>> from sympy.simplify.simplify import _real_to_rational + >>> from sympy.abc import x + + >>> _real_to_rational(.76 + .1*x**.5) + sqrt(x)/10 + 19/25 + + If rational_conversion='base10', this uses the base-10 string. If + rational_conversion='exact', the exact, base-2 representation is used. + + >>> _real_to_rational(0.333333333333333, rational_conversion='exact') + 6004799503160655/18014398509481984 + >>> _real_to_rational(0.333333333333333) + 1/3 + + """ + expr = _sympify(expr) + inf = Float('inf') + p = expr + reps = {} + reduce_num = None + if tolerance is not None and tolerance < 1: + reduce_num = ceiling(1/tolerance) + for fl in p.atoms(Float): + key = fl + if reduce_num is not None: + r = Rational(fl).limit_denominator(reduce_num) + elif (tolerance is not None and tolerance >= 1 and + fl.is_Integer is False): + r = Rational(tolerance*round(fl/tolerance) + ).limit_denominator(int(tolerance)) + else: + if rational_conversion == 'exact': + r = Rational(fl) + reps[key] = r + continue + elif rational_conversion != 'base10': + raise ValueError("rational_conversion must be 'base10' or 'exact'") + + r = nsimplify(fl, rational=False) + # e.g. log(3).n() -> log(3) instead of a Rational + if fl and not r: + r = Rational(fl) + elif not r.is_Rational: + if fl in (inf, -inf): + r = S.ComplexInfinity + elif fl < 0: + fl = -fl + d = Pow(10, int(mpmath.log(fl)/mpmath.log(10))) + r = -Rational(str(fl/d))*d + elif fl > 0: + d = Pow(10, int(mpmath.log(fl)/mpmath.log(10))) + r = Rational(str(fl/d))*d + else: + r = S.Zero + reps[key] = r + return p.subs(reps, simultaneous=True) + + +def clear_coefficients(expr, rhs=S.Zero): + """Return `p, r` where `p` is the expression obtained when Rational + additive and multiplicative coefficients of `expr` have been stripped + away in a naive fashion (i.e. without simplification). The operations + needed to remove the coefficients will be applied to `rhs` and returned + as `r`. + + Examples + ======== + + >>> from sympy.simplify.simplify import clear_coefficients + >>> from sympy.abc import x, y + >>> from sympy import Dummy + >>> expr = 4*y*(6*x + 3) + >>> clear_coefficients(expr - 2) + (y*(2*x + 1), 1/6) + + When solving 2 or more expressions like `expr = a`, + `expr = b`, etc..., it is advantageous to provide a Dummy symbol + for `rhs` and simply replace it with `a`, `b`, etc... in `r`. + + >>> rhs = Dummy('rhs') + >>> clear_coefficients(expr, rhs) + (y*(2*x + 1), _rhs/12) + >>> _[1].subs(rhs, 2) + 1/6 + """ + was = None + free = expr.free_symbols + if expr.is_Rational: + return (S.Zero, rhs - expr) + while expr and was != expr: + was = expr + m, expr = ( + expr.as_content_primitive() + if free else + factor_terms(expr).as_coeff_Mul(rational=True)) + rhs /= m + c, expr = expr.as_coeff_Add(rational=True) + rhs -= c + expr = signsimp(expr, evaluate = False) + if expr.could_extract_minus_sign(): + expr = -expr + rhs = -rhs + return expr, rhs + +def nc_simplify(expr, deep=True): + ''' + Simplify a non-commutative expression composed of multiplication + and raising to a power by grouping repeated subterms into one power. + Priority is given to simplifications that give the fewest number + of arguments in the end (for example, in a*b*a*b*c*a*b*c simplifying + to (a*b)**2*c*a*b*c gives 5 arguments while a*b*(a*b*c)**2 has 3). + If ``expr`` is a sum of such terms, the sum of the simplified terms + is returned. + + Keyword argument ``deep`` controls whether or not subexpressions + nested deeper inside the main expression are simplified. See examples + below. Setting `deep` to `False` can save time on nested expressions + that do not need simplifying on all levels. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.simplify.simplify import nc_simplify + >>> a, b, c = symbols("a b c", commutative=False) + >>> nc_simplify(a*b*a*b*c*a*b*c) + a*b*(a*b*c)**2 + >>> expr = a**2*b*a**4*b*a**4 + >>> nc_simplify(expr) + a**2*(b*a**4)**2 + >>> nc_simplify(a*b*a*b*c**2*(a*b)**2*c**2) + ((a*b)**2*c**2)**2 + >>> nc_simplify(a*b*a*b + 2*a*c*a**2*c*a**2*c*a) + (a*b)**2 + 2*(a*c*a)**3 + >>> nc_simplify(b**-1*a**-1*(a*b)**2) + a*b + >>> nc_simplify(a**-1*b**-1*c*a) + (b*a)**(-1)*c*a + >>> expr = (a*b*a*b)**2*a*c*a*c + >>> nc_simplify(expr) + (a*b)**4*(a*c)**2 + >>> nc_simplify(expr, deep=False) + (a*b*a*b)**2*(a*c)**2 + + ''' + if isinstance(expr, MatrixExpr): + expr = expr.doit(inv_expand=False) + _Add, _Mul, _Pow, _Symbol = MatAdd, MatMul, MatPow, MatrixSymbol + else: + _Add, _Mul, _Pow, _Symbol = Add, Mul, Pow, Symbol + + # =========== Auxiliary functions ======================== + def _overlaps(args): + # Calculate a list of lists m such that m[i][j] contains the lengths + # of all possible overlaps between args[:i+1] and args[i+1+j:]. + # An overlap is a suffix of the prefix that matches a prefix + # of the suffix. + # For example, let expr=c*a*b*a*b*a*b*a*b. Then m[3][0] contains + # the lengths of overlaps of c*a*b*a*b with a*b*a*b. The overlaps + # are a*b*a*b, a*b and the empty word so that m[3][0]=[4,2,0]. + # All overlaps rather than only the longest one are recorded + # because this information helps calculate other overlap lengths. + m = [[([1, 0] if a == args[0] else [0]) for a in args[1:]]] + for i in range(1, len(args)): + overlaps = [] + j = 0 + for j in range(len(args) - i - 1): + overlap = [] + for v in m[i-1][j+1]: + if j + i + 1 + v < len(args) and args[i] == args[j+i+1+v]: + overlap.append(v + 1) + overlap += [0] + overlaps.append(overlap) + m.append(overlaps) + return m + + def _reduce_inverses(_args): + # replace consecutive negative powers by an inverse + # of a product of positive powers, e.g. a**-1*b**-1*c + # will simplify to (a*b)**-1*c; + # return that new args list and the number of negative + # powers in it (inv_tot) + inv_tot = 0 # total number of inverses + inverses = [] + args = [] + for arg in _args: + if isinstance(arg, _Pow) and arg.args[1].is_extended_negative: + inverses = [arg**-1] + inverses + inv_tot += 1 + else: + if len(inverses) == 1: + args.append(inverses[0]**-1) + elif len(inverses) > 1: + args.append(_Pow(_Mul(*inverses), -1)) + inv_tot -= len(inverses) - 1 + inverses = [] + args.append(arg) + if inverses: + args.append(_Pow(_Mul(*inverses), -1)) + inv_tot -= len(inverses) - 1 + return inv_tot, tuple(args) + + def get_score(s): + # compute the number of arguments of s + # (including in nested expressions) overall + # but ignore exponents + if isinstance(s, _Pow): + return get_score(s.args[0]) + elif isinstance(s, (_Add, _Mul)): + return sum(get_score(a) for a in s.args) + return 1 + + def compare(s, alt_s): + # compare two possible simplifications and return a + # "better" one + if s != alt_s and get_score(alt_s) < get_score(s): + return alt_s + return s + # ======================================================== + + if not isinstance(expr, (_Add, _Mul, _Pow)) or expr.is_commutative: + return expr + args = expr.args[:] + if isinstance(expr, _Pow): + if deep: + return _Pow(nc_simplify(args[0]), args[1]).doit() + else: + return expr + elif isinstance(expr, _Add): + return _Add(*[nc_simplify(a, deep=deep) for a in args]).doit() + else: + # get the non-commutative part + c_args, args = expr.args_cnc() + com_coeff = Mul(*c_args) + if not equal_valued(com_coeff, 1): + return com_coeff*nc_simplify(expr/com_coeff, deep=deep) + + inv_tot, args = _reduce_inverses(args) + # if most arguments are negative, work with the inverse + # of the expression, e.g. a**-1*b*a**-1*c**-1 will become + # (c*a*b**-1*a)**-1 at the end so can work with c*a*b**-1*a + invert = False + if inv_tot > len(args)/2: + invert = True + args = [a**-1 for a in args[::-1]] + + if deep: + args = tuple(nc_simplify(a) for a in args) + + m = _overlaps(args) + + # simps will be {subterm: end} where `end` is the ending + # index of a sequence of repetitions of subterm; + # this is for not wasting time with subterms that are part + # of longer, already considered sequences + simps = {} + + post = 1 + pre = 1 + + # the simplification coefficient is the number of + # arguments by which contracting a given sequence + # would reduce the word; e.g. in a*b*a*b*c*a*b*c, + # contracting a*b*a*b to (a*b)**2 removes 3 arguments + # while a*b*c*a*b*c to (a*b*c)**2 removes 6. It's + # better to contract the latter so simplification + # with a maximum simplification coefficient will be chosen + max_simp_coeff = 0 + simp = None # information about future simplification + + for i in range(1, len(args)): + simp_coeff = 0 + l = 0 # length of a subterm + p = 0 # the power of a subterm + if i < len(args) - 1: + rep = m[i][0] + start = i # starting index of the repeated sequence + end = i+1 # ending index of the repeated sequence + if i == len(args)-1 or rep == [0]: + # no subterm is repeated at this stage, at least as + # far as the arguments are concerned - there may be + # a repetition if powers are taken into account + if (isinstance(args[i], _Pow) and + not isinstance(args[i].args[0], _Symbol)): + subterm = args[i].args[0].args + l = len(subterm) + if args[i-l:i] == subterm: + # e.g. a*b in a*b*(a*b)**2 is not repeated + # in args (= [a, b, (a*b)**2]) but it + # can be matched here + p += 1 + start -= l + if args[i+1:i+1+l] == subterm: + # e.g. a*b in (a*b)**2*a*b + p += 1 + end += l + if p: + p += args[i].args[1] + else: + continue + else: + l = rep[0] # length of the longest repeated subterm at this point + start -= l - 1 + subterm = args[start:end] + p = 2 + end += l + + if subterm in simps and simps[subterm] >= start: + # the subterm is part of a sequence that + # has already been considered + continue + + # count how many times it's repeated + while end < len(args): + if l in m[end-1][0]: + p += 1 + end += l + elif isinstance(args[end], _Pow) and args[end].args[0].args == subterm: + # for cases like a*b*a*b*(a*b)**2*a*b + p += args[end].args[1] + end += 1 + else: + break + + # see if another match can be made, e.g. + # for b*a**2 in b*a**2*b*a**3 or a*b in + # a**2*b*a*b + + pre_exp = 0 + pre_arg = 1 + if start - l >= 0 and args[start-l+1:start] == subterm[1:]: + if isinstance(subterm[0], _Pow): + pre_arg = subterm[0].args[0] + exp = subterm[0].args[1] + else: + pre_arg = subterm[0] + exp = 1 + if isinstance(args[start-l], _Pow) and args[start-l].args[0] == pre_arg: + pre_exp = args[start-l].args[1] - exp + start -= l + p += 1 + elif args[start-l] == pre_arg: + pre_exp = 1 - exp + start -= l + p += 1 + + post_exp = 0 + post_arg = 1 + if end + l - 1 < len(args) and args[end:end+l-1] == subterm[:-1]: + if isinstance(subterm[-1], _Pow): + post_arg = subterm[-1].args[0] + exp = subterm[-1].args[1] + else: + post_arg = subterm[-1] + exp = 1 + if isinstance(args[end+l-1], _Pow) and args[end+l-1].args[0] == post_arg: + post_exp = args[end+l-1].args[1] - exp + end += l + p += 1 + elif args[end+l-1] == post_arg: + post_exp = 1 - exp + end += l + p += 1 + + # Consider a*b*a**2*b*a**2*b*a: + # b*a**2 is explicitly repeated, but note + # that in this case a*b*a is also repeated + # so there are two possible simplifications: + # a*(b*a**2)**3*a**-1 or (a*b*a)**3 + # The latter is obviously simpler. + # But in a*b*a**2*b**2*a**2 the simplifications are + # a*(b*a**2)**2 and (a*b*a)**3*a in which case + # it's better to stick with the shorter subterm + if post_exp and exp % 2 == 0 and start > 0: + exp = exp/2 + _pre_exp = 1 + _post_exp = 1 + if isinstance(args[start-1], _Pow) and args[start-1].args[0] == post_arg: + _post_exp = post_exp + exp + _pre_exp = args[start-1].args[1] - exp + elif args[start-1] == post_arg: + _post_exp = post_exp + exp + _pre_exp = 1 - exp + if _pre_exp == 0 or _post_exp == 0: + if not pre_exp: + start -= 1 + post_exp = _post_exp + pre_exp = _pre_exp + pre_arg = post_arg + subterm = (post_arg**exp,) + subterm[:-1] + (post_arg**exp,) + + simp_coeff += end-start + + if post_exp: + simp_coeff -= 1 + if pre_exp: + simp_coeff -= 1 + + simps[subterm] = end + + if simp_coeff > max_simp_coeff: + max_simp_coeff = simp_coeff + simp = (start, _Mul(*subterm), p, end, l) + pre = pre_arg**pre_exp + post = post_arg**post_exp + + if simp: + subterm = _Pow(nc_simplify(simp[1], deep=deep), simp[2]) + pre = nc_simplify(_Mul(*args[:simp[0]])*pre, deep=deep) + post = post*nc_simplify(_Mul(*args[simp[3]:]), deep=deep) + simp = pre*subterm*post + if pre != 1 or post != 1: + # new simplifications may be possible but no need + # to recurse over arguments + simp = nc_simplify(simp, deep=False) + else: + simp = _Mul(*args) + + if invert: + simp = _Pow(simp, -1) + + # see if factor_nc(expr) is simplified better + if not isinstance(expr, MatrixExpr): + f_expr = factor_nc(expr) + if f_expr != expr: + alt_simp = nc_simplify(f_expr, deep=deep) + simp = compare(simp, alt_simp) + else: + simp = simp.doit(inv_expand=False) + return simp + + +def dotprodsimp(expr, withsimp=False): + """Simplification for a sum of products targeted at the kind of blowup that + occurs during summation of products. Intended to reduce expression blowup + during matrix multiplication or other similar operations. Only works with + algebraic expressions and does not recurse into non. + + Parameters + ========== + + withsimp : bool, optional + Specifies whether a flag should be returned along with the expression + to indicate roughly whether simplification was successful. It is used + in ``MatrixArithmetic._eval_pow_by_recursion`` to avoid attempting to + simplify an expression repetitively which does not simplify. + """ + + def count_ops_alg(expr): + """Optimized count algebraic operations with no recursion into + non-algebraic args that ``core.function.count_ops`` does. Also returns + whether rational functions may be present according to negative + exponents of powers or non-number fractions. + + Returns + ======= + + ops, ratfunc : int, bool + ``ops`` is the number of algebraic operations starting at the top + level expression (not recursing into non-alg children). ``ratfunc`` + specifies whether the expression MAY contain rational functions + which ``cancel`` MIGHT optimize. + """ + + ops = 0 + args = [expr] + ratfunc = False + + while args: + a = args.pop() + + if not isinstance(a, Basic): + continue + + if a.is_Rational: + if a is not S.One: # -1/3 = NEG + DIV + ops += bool (a.p < 0) + bool (a.q != 1) + + elif a.is_Mul: + if a.could_extract_minus_sign(): + ops += 1 + if a.args[0] is S.NegativeOne: + a = a.as_two_terms()[1] + else: + a = -a + + n, d = fraction(a) + + if n.is_Integer: + ops += 1 + bool (n < 0) + args.append(d) # won't be -Mul but could be Add + + elif d is not S.One: + if not d.is_Integer: + args.append(d) + ratfunc=True + + ops += 1 + args.append(n) # could be -Mul + + else: + ops += len(a.args) - 1 + args.extend(a.args) + + elif a.is_Add: + laargs = len(a.args) + negs = 0 + + for ai in a.args: + if ai.could_extract_minus_sign(): + negs += 1 + ai = -ai + args.append(ai) + + ops += laargs - (negs != laargs) # -x - y = NEG + SUB + + elif a.is_Pow: + ops += 1 + args.append(a.base) + + if not ratfunc: + ratfunc = a.exp.is_negative is not False + + return ops, ratfunc + + def nonalg_subs_dummies(expr, dummies): + """Substitute dummy variables for non-algebraic expressions to avoid + evaluation of non-algebraic terms that ``polys.polytools.cancel`` does. + """ + + if not expr.args: + return expr + + if expr.is_Add or expr.is_Mul or expr.is_Pow: + args = None + + for i, a in enumerate(expr.args): + c = nonalg_subs_dummies(a, dummies) + + if c is a: + continue + + if args is None: + args = list(expr.args) + + args[i] = c + + if args is None: + return expr + + return expr.func(*args) + + return dummies.setdefault(expr, Dummy()) + + simplified = False # doesn't really mean simplified, rather "can simplify again" + + if isinstance(expr, Basic) and (expr.is_Add or expr.is_Mul or expr.is_Pow): + expr2 = expr.expand(deep=True, modulus=None, power_base=False, + power_exp=False, mul=True, log=False, multinomial=True, basic=False) + + if expr2 != expr: + expr = expr2 + simplified = True + + exprops, ratfunc = count_ops_alg(expr) + + if exprops >= 6: # empirically tested cutoff for expensive simplification + if ratfunc: + dummies = {} + expr2 = nonalg_subs_dummies(expr, dummies) + + if expr2 is expr or count_ops_alg(expr2)[0] >= 6: # check again after substitution + expr3 = cancel(expr2) + + if expr3 != expr2: + expr = expr3.subs([(d, e) for e, d in dummies.items()]) + simplified = True + + # very special case: x/(x-1) - 1/(x-1) -> 1 + elif (exprops == 5 and expr.is_Add and expr.args [0].is_Mul and + expr.args [1].is_Mul and expr.args [0].args [-1].is_Pow and + expr.args [1].args [-1].is_Pow and + expr.args [0].args [-1].exp is S.NegativeOne and + expr.args [1].args [-1].exp is S.NegativeOne): + + expr2 = together (expr) + expr2ops = count_ops_alg(expr2)[0] + + if expr2ops < exprops: + expr = expr2 + simplified = True + + else: + simplified = True + + return (expr, simplified) if withsimp else expr + + +bottom_up = deprecated( + """ + Using bottom_up from the sympy.simplify.simplify submodule is + deprecated. + + Instead, use bottom_up from the top-level sympy namespace, like + + sympy.bottom_up + """, + deprecated_since_version="1.10", + active_deprecations_target="deprecated-traversal-functions-moved", +)(_bottom_up) + + +# XXX: This function really should either be private API or exported in the +# top-level sympy/__init__.py +walk = deprecated( + """ + Using walk from the sympy.simplify.simplify submodule is + deprecated. + + Instead, use walk from sympy.core.traversal.walk + """, + deprecated_since_version="1.10", + active_deprecations_target="deprecated-traversal-functions-moved", +)(_walk) diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/sqrtdenest.py b/.venv/lib/python3.13/site-packages/sympy/simplify/sqrtdenest.py new file mode 100644 index 0000000000000000000000000000000000000000..d266de7e62a4b7d37a2109f7091ff91e4df7c79d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/sqrtdenest.py @@ -0,0 +1,678 @@ +from sympy.core import Add, Expr, Mul, S, sympify +from sympy.core.function import _mexpand, count_ops, expand_mul +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import Dummy +from sympy.functions import root, sign, sqrt +from sympy.polys import Poly, PolynomialError + + +def is_sqrt(expr): + """Return True if expr is a sqrt, otherwise False.""" + + return expr.is_Pow and expr.exp.is_Rational and abs(expr.exp) is S.Half + + +def sqrt_depth(p) -> int: + """Return the maximum depth of any square root argument of p. + + >>> from sympy.functions.elementary.miscellaneous import sqrt + >>> from sympy.simplify.sqrtdenest import sqrt_depth + + Neither of these square roots contains any other square roots + so the depth is 1: + + >>> sqrt_depth(1 + sqrt(2)*(1 + sqrt(3))) + 1 + + The sqrt(3) is contained within a square root so the depth is + 2: + + >>> sqrt_depth(1 + sqrt(2)*sqrt(1 + sqrt(3))) + 2 + """ + if p is S.ImaginaryUnit: + return 1 + if p.is_Atom: + return 0 + if p.is_Add or p.is_Mul: + return max(sqrt_depth(x) for x in p.args) + if is_sqrt(p): + return sqrt_depth(p.base) + 1 + return 0 + + +def is_algebraic(p): + """Return True if p is comprised of only Rationals or square roots + of Rationals and algebraic operations. + + Examples + ======== + + >>> from sympy.functions.elementary.miscellaneous import sqrt + >>> from sympy.simplify.sqrtdenest import is_algebraic + >>> from sympy import cos + >>> is_algebraic(sqrt(2)*(3/(sqrt(7) + sqrt(5)*sqrt(2)))) + True + >>> is_algebraic(sqrt(2)*(3/(sqrt(7) + sqrt(5)*cos(2)))) + False + """ + + if p.is_Rational: + return True + elif p.is_Atom: + return False + elif is_sqrt(p) or p.is_Pow and p.exp.is_Integer: + return is_algebraic(p.base) + elif p.is_Add or p.is_Mul: + return all(is_algebraic(x) for x in p.args) + else: + return False + + +def _subsets(n): + """ + Returns all possible subsets of the set (0, 1, ..., n-1) except the + empty set, listed in reversed lexicographical order according to binary + representation, so that the case of the fourth root is treated last. + + Examples + ======== + + >>> from sympy.simplify.sqrtdenest import _subsets + >>> _subsets(2) + [[1, 0], [0, 1], [1, 1]] + + """ + if n == 1: + a = [[1]] + elif n == 2: + a = [[1, 0], [0, 1], [1, 1]] + elif n == 3: + a = [[1, 0, 0], [0, 1, 0], [1, 1, 0], + [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]] + else: + b = _subsets(n - 1) + a0 = [x + [0] for x in b] + a1 = [x + [1] for x in b] + a = a0 + [[0]*(n - 1) + [1]] + a1 + return a + + +def sqrtdenest(expr, max_iter=3): + """Denests sqrts in an expression that contain other square roots + if possible, otherwise returns the expr unchanged. This is based on the + algorithms of [1]. + + Examples + ======== + + >>> from sympy.simplify.sqrtdenest import sqrtdenest + >>> from sympy import sqrt + >>> sqrtdenest(sqrt(5 + 2 * sqrt(6))) + sqrt(2) + sqrt(3) + + See Also + ======== + + sympy.solvers.solvers.unrad + + References + ========== + + .. [1] https://web.archive.org/web/20210806201615/https://researcher.watson.ibm.com/researcher/files/us-fagin/symb85.pdf + + .. [2] D. J. Jeffrey and A. D. Rich, 'Symplifying Square Roots of Square Roots + by Denesting' (available at https://www.cybertester.com/data/denest.pdf) + + """ + expr = expand_mul(expr) + for i in range(max_iter): + z = _sqrtdenest0(expr) + if expr == z: + return expr + expr = z + return expr + + +def _sqrt_match(p): + """Return [a, b, r] for p.match(a + b*sqrt(r)) where, in addition to + matching, sqrt(r) also has then maximal sqrt_depth among addends of p. + + Examples + ======== + + >>> from sympy.functions.elementary.miscellaneous import sqrt + >>> from sympy.simplify.sqrtdenest import _sqrt_match + >>> _sqrt_match(1 + sqrt(2) + sqrt(2)*sqrt(3) + 2*sqrt(1+sqrt(5))) + [1 + sqrt(2) + sqrt(6), 2, 1 + sqrt(5)] + """ + from sympy.simplify.radsimp import split_surds + + p = _mexpand(p) + if p.is_Number: + res = (p, S.Zero, S.Zero) + elif p.is_Add: + pargs = sorted(p.args, key=default_sort_key) + sqargs = [x**2 for x in pargs] + if all(sq.is_Rational and sq.is_positive for sq in sqargs): + r, b, a = split_surds(p) + res = a, b, r + return list(res) + # to make the process canonical, the argument is included in the tuple + # so when the max is selected, it will be the largest arg having a + # given depth + v = [(sqrt_depth(x), x, i) for i, x in enumerate(pargs)] + nmax = max(v, key=default_sort_key) + if nmax[0] == 0: + res = [] + else: + # select r + depth, _, i = nmax + r = pargs.pop(i) + v.pop(i) + b = S.One + if r.is_Mul: + bv = [] + rv = [] + for x in r.args: + if sqrt_depth(x) < depth: + bv.append(x) + else: + rv.append(x) + b = Mul._from_args(bv) + r = Mul._from_args(rv) + # collect terms containing r + a1 = [] + b1 = [b] + for x in v: + if x[0] < depth: + a1.append(x[1]) + else: + x1 = x[1] + if x1 == r: + b1.append(1) + else: + if x1.is_Mul: + x1args = list(x1.args) + if r in x1args: + x1args.remove(r) + b1.append(Mul(*x1args)) + else: + a1.append(x[1]) + else: + a1.append(x[1]) + a = Add(*a1) + b = Add(*b1) + res = (a, b, r**2) + else: + b, r = p.as_coeff_Mul() + if is_sqrt(r): + res = (S.Zero, b, r**2) + else: + res = [] + return list(res) + + +class SqrtdenestStopIteration(StopIteration): + pass + + +def _sqrtdenest0(expr): + """Returns expr after denesting its arguments.""" + + if is_sqrt(expr): + n, d = expr.as_numer_denom() + if d is S.One: # n is a square root + if n.base.is_Add: + args = sorted(n.base.args, key=default_sort_key) + if len(args) > 2 and all((x**2).is_Integer for x in args): + try: + return _sqrtdenest_rec(n) + except SqrtdenestStopIteration: + pass + expr = sqrt(_mexpand(Add(*[_sqrtdenest0(x) for x in args]))) + return _sqrtdenest1(expr) + else: + n, d = [_sqrtdenest0(i) for i in (n, d)] + return n/d + + if isinstance(expr, Add): + cs = [] + args = [] + for arg in expr.args: + c, a = arg.as_coeff_Mul() + cs.append(c) + args.append(a) + + if all(c.is_Rational for c in cs) and all(is_sqrt(arg) for arg in args): + return _sqrt_ratcomb(cs, args) + + if isinstance(expr, Expr): + args = expr.args + if args: + return expr.func(*[_sqrtdenest0(a) for a in args]) + return expr + + +def _sqrtdenest_rec(expr): + """Helper that denests the square root of three or more surds. + + Explanation + =========== + + It returns the denested expression; if it cannot be denested it + throws SqrtdenestStopIteration + + Algorithm: expr.base is in the extension Q_m = Q(sqrt(r_1),..,sqrt(r_k)); + split expr.base = a + b*sqrt(r_k), where `a` and `b` are on + Q_(m-1) = Q(sqrt(r_1),..,sqrt(r_(k-1))); then a**2 - b**2*r_k is + on Q_(m-1); denest sqrt(a**2 - b**2*r_k) and so on. + See [1], section 6. + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.simplify.sqrtdenest import _sqrtdenest_rec + >>> _sqrtdenest_rec(sqrt(-72*sqrt(2) + 158*sqrt(5) + 498)) + -sqrt(10) + sqrt(2) + 9 + 9*sqrt(5) + >>> w=-6*sqrt(55)-6*sqrt(35)-2*sqrt(22)-2*sqrt(14)+2*sqrt(77)+6*sqrt(10)+65 + >>> _sqrtdenest_rec(sqrt(w)) + -sqrt(11) - sqrt(7) + sqrt(2) + 3*sqrt(5) + """ + from sympy.simplify.radsimp import radsimp, rad_rationalize, split_surds + if not expr.is_Pow: + return sqrtdenest(expr) + if expr.base < 0: + return sqrt(-1)*_sqrtdenest_rec(sqrt(-expr.base)) + g, a, b = split_surds(expr.base) + a = a*sqrt(g) + if a < b: + a, b = b, a + c2 = _mexpand(a**2 - b**2) + if len(c2.args) > 2: + g, a1, b1 = split_surds(c2) + a1 = a1*sqrt(g) + if a1 < b1: + a1, b1 = b1, a1 + c2_1 = _mexpand(a1**2 - b1**2) + c_1 = _sqrtdenest_rec(sqrt(c2_1)) + d_1 = _sqrtdenest_rec(sqrt(a1 + c_1)) + num, den = rad_rationalize(b1, d_1) + c = _mexpand(d_1/sqrt(2) + num/(den*sqrt(2))) + else: + c = _sqrtdenest1(sqrt(c2)) + + if sqrt_depth(c) > 1: + raise SqrtdenestStopIteration + ac = a + c + if len(ac.args) >= len(expr.args): + if count_ops(ac) >= count_ops(expr.base): + raise SqrtdenestStopIteration + d = sqrtdenest(sqrt(ac)) + if sqrt_depth(d) > 1: + raise SqrtdenestStopIteration + num, den = rad_rationalize(b, d) + r = d/sqrt(2) + num/(den*sqrt(2)) + r = radsimp(r) + return _mexpand(r) + + +def _sqrtdenest1(expr, denester=True): + """Return denested expr after denesting with simpler methods or, that + failing, using the denester.""" + + from sympy.simplify.simplify import radsimp + + if not is_sqrt(expr): + return expr + + a = expr.base + if a.is_Atom: + return expr + val = _sqrt_match(a) + if not val: + return expr + + a, b, r = val + # try a quick numeric denesting + d2 = _mexpand(a**2 - b**2*r) + if d2.is_Rational: + if d2.is_positive: + z = _sqrt_numeric_denest(a, b, r, d2) + if z is not None: + return z + else: + # fourth root case + # sqrtdenest(sqrt(3 + 2*sqrt(3))) = + # sqrt(2)*3**(1/4)/2 + sqrt(2)*3**(3/4)/2 + dr2 = _mexpand(-d2*r) + dr = sqrt(dr2) + if dr.is_Rational: + z = _sqrt_numeric_denest(_mexpand(b*r), a, r, dr2) + if z is not None: + return z/root(r, 4) + + else: + z = _sqrt_symbolic_denest(a, b, r) + if z is not None: + return z + + if not denester or not is_algebraic(expr): + return expr + + res = sqrt_biquadratic_denest(expr, a, b, r, d2) + if res: + return res + + # now call to the denester + av0 = [a, b, r, d2] + z = _denester([radsimp(expr**2)], av0, 0, sqrt_depth(expr))[0] + if av0[1] is None: + return expr + if z is not None: + if sqrt_depth(z) == sqrt_depth(expr) and count_ops(z) > count_ops(expr): + return expr + return z + return expr + + +def _sqrt_symbolic_denest(a, b, r): + """Given an expression, sqrt(a + b*sqrt(b)), return the denested + expression or None. + + Explanation + =========== + + If r = ra + rb*sqrt(rr), try replacing sqrt(rr) in ``a`` with + (y**2 - ra)/rb, and if the result is a quadratic, ca*y**2 + cb*y + cc, and + (cb + b)**2 - 4*ca*cc is 0, then sqrt(a + b*sqrt(r)) can be rewritten as + sqrt(ca*(sqrt(r) + (cb + b)/(2*ca))**2). + + Examples + ======== + + >>> from sympy.simplify.sqrtdenest import _sqrt_symbolic_denest, sqrtdenest + >>> from sympy import sqrt, Symbol + >>> from sympy.abc import x + + >>> a, b, r = 16 - 2*sqrt(29), 2, -10*sqrt(29) + 55 + >>> _sqrt_symbolic_denest(a, b, r) + sqrt(11 - 2*sqrt(29)) + sqrt(5) + + If the expression is numeric, it will be simplified: + + >>> w = sqrt(sqrt(sqrt(3) + 1) + 1) + 1 + sqrt(2) + >>> sqrtdenest(sqrt((w**2).expand())) + 1 + sqrt(2) + sqrt(1 + sqrt(1 + sqrt(3))) + + Otherwise, it will only be simplified if assumptions allow: + + >>> w = w.subs(sqrt(3), sqrt(x + 3)) + >>> sqrtdenest(sqrt((w**2).expand())) + sqrt((sqrt(sqrt(sqrt(x + 3) + 1) + 1) + 1 + sqrt(2))**2) + + Notice that the argument of the sqrt is a square. If x is made positive + then the sqrt of the square is resolved: + + >>> _.subs(x, Symbol('x', positive=True)) + sqrt(sqrt(sqrt(x + 3) + 1) + 1) + 1 + sqrt(2) + """ + + a, b, r = map(sympify, (a, b, r)) + rval = _sqrt_match(r) + if not rval: + return None + ra, rb, rr = rval + if rb: + y = Dummy('y', positive=True) + try: + newa = Poly(a.subs(sqrt(rr), (y**2 - ra)/rb), y) + except PolynomialError: + return None + if newa.degree() == 2: + ca, cb, cc = newa.all_coeffs() + cb += b + if _mexpand(cb**2 - 4*ca*cc).equals(0): + z = sqrt(ca*(sqrt(r) + cb/(2*ca))**2) + if z.is_number: + z = _mexpand(Mul._from_args(z.as_content_primitive())) + return z + + +def _sqrt_numeric_denest(a, b, r, d2): + r"""Helper that denest + $\sqrt{a + b \sqrt{r}}, d^2 = a^2 - b^2 r > 0$ + + If it cannot be denested, it returns ``None``. + """ + d = sqrt(d2) + s = a + d + # sqrt_depth(res) <= sqrt_depth(s) + 1 + # sqrt_depth(expr) = sqrt_depth(r) + 2 + # there is denesting if sqrt_depth(s) + 1 < sqrt_depth(r) + 2 + # if s**2 is Number there is a fourth root + if sqrt_depth(s) < sqrt_depth(r) + 1 or (s**2).is_Rational: + s1, s2 = sign(s), sign(b) + if s1 == s2 == -1: + s1 = s2 = 1 + res = (s1 * sqrt(a + d) + s2 * sqrt(a - d)) * sqrt(2) / 2 + return res.expand() + + +def sqrt_biquadratic_denest(expr, a, b, r, d2): + """denest expr = sqrt(a + b*sqrt(r)) + where a, b, r are linear combinations of square roots of + positive rationals on the rationals (SQRR) and r > 0, b != 0, + d2 = a**2 - b**2*r > 0 + + If it cannot denest it returns None. + + Explanation + =========== + + Search for a solution A of type SQRR of the biquadratic equation + 4*A**4 - 4*a*A**2 + b**2*r = 0 (1) + sqd = sqrt(a**2 - b**2*r) + Choosing the sqrt to be positive, the possible solutions are + A = sqrt(a/2 +/- sqd/2) + Since a, b, r are SQRR, then a**2 - b**2*r is a SQRR, + so if sqd can be denested, it is done by + _sqrtdenest_rec, and the result is a SQRR. + Similarly for A. + Examples of solutions (in both cases a and sqd are positive): + + Example of expr with solution sqrt(a/2 + sqd/2) but not + solution sqrt(a/2 - sqd/2): + expr = sqrt(-sqrt(15) - sqrt(2)*sqrt(-sqrt(5) + 5) - sqrt(3) + 8) + a = -sqrt(15) - sqrt(3) + 8; sqd = -2*sqrt(5) - 2 + 4*sqrt(3) + + Example of expr with solution sqrt(a/2 - sqd/2) but not + solution sqrt(a/2 + sqd/2): + w = 2 + r2 + r3 + (1 + r3)*sqrt(2 + r2 + 5*r3) + expr = sqrt((w**2).expand()) + a = 4*sqrt(6) + 8*sqrt(2) + 47 + 28*sqrt(3) + sqd = 29 + 20*sqrt(3) + + Define B = b/2*A; eq.(1) implies a = A**2 + B**2*r; then + expr**2 = a + b*sqrt(r) = (A + B*sqrt(r))**2 + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.simplify.sqrtdenest import _sqrt_match, sqrt_biquadratic_denest + >>> z = sqrt((2*sqrt(2) + 4)*sqrt(2 + sqrt(2)) + 5*sqrt(2) + 8) + >>> a, b, r = _sqrt_match(z**2) + >>> d2 = a**2 - b**2*r + >>> sqrt_biquadratic_denest(z, a, b, r, d2) + sqrt(2) + sqrt(sqrt(2) + 2) + 2 + """ + from sympy.simplify.radsimp import radsimp, rad_rationalize + if r <= 0 or d2 < 0 or not b or sqrt_depth(expr.base) < 2: + return None + for x in (a, b, r): + for y in x.args: + y2 = y**2 + if not y2.is_Integer or not y2.is_positive: + return None + sqd = _mexpand(sqrtdenest(sqrt(radsimp(d2)))) + if sqrt_depth(sqd) > 1: + return None + x1, x2 = [a/2 + sqd/2, a/2 - sqd/2] + # look for a solution A with depth 1 + for x in (x1, x2): + A = sqrtdenest(sqrt(x)) + if sqrt_depth(A) > 1: + continue + Bn, Bd = rad_rationalize(b, _mexpand(2*A)) + B = Bn/Bd + z = A + B*sqrt(r) + if z < 0: + z = -z + return _mexpand(z) + return None + + +def _denester(nested, av0, h, max_depth_level): + """Denests a list of expressions that contain nested square roots. + + Explanation + =========== + + Algorithm based on . + + It is assumed that all of the elements of 'nested' share the same + bottom-level radicand. (This is stated in the paper, on page 177, in + the paragraph immediately preceding the algorithm.) + + When evaluating all of the arguments in parallel, the bottom-level + radicand only needs to be denested once. This means that calling + _denester with x arguments results in a recursive invocation with x+1 + arguments; hence _denester has polynomial complexity. + + However, if the arguments were evaluated separately, each call would + result in two recursive invocations, and the algorithm would have + exponential complexity. + + This is discussed in the paper in the middle paragraph of page 179. + """ + from sympy.simplify.simplify import radsimp + if h > max_depth_level: + return None, None + if av0[1] is None: + return None, None + if (av0[0] is None and + all(n.is_Number for n in nested)): # no arguments are nested + for f in _subsets(len(nested)): # test subset 'f' of nested + p = _mexpand(Mul(*[nested[i] for i in range(len(f)) if f[i]])) + if f.count(1) > 1 and f[-1]: + p = -p + sqp = sqrt(p) + if sqp.is_Rational: + return sqp, f # got a perfect square so return its square root. + # Otherwise, return the radicand from the previous invocation. + return sqrt(nested[-1]), [0]*len(nested) + else: + R = None + if av0[0] is not None: + values = [av0[:2]] + R = av0[2] + nested2 = [av0[3], R] + av0[0] = None + else: + values = list(filter(None, [_sqrt_match(expr) for expr in nested])) + for v in values: + if v[2]: # Since if b=0, r is not defined + if R is not None: + if R != v[2]: + av0[1] = None + return None, None + else: + R = v[2] + if R is None: + # return the radicand from the previous invocation + return sqrt(nested[-1]), [0]*len(nested) + nested2 = [_mexpand(v[0]**2) - + _mexpand(R*v[1]**2) for v in values] + [R] + d, f = _denester(nested2, av0, h + 1, max_depth_level) + if not f: + return None, None + if not any(f[i] for i in range(len(nested))): + v = values[-1] + return sqrt(v[0] + _mexpand(v[1]*d)), f + else: + p = Mul(*[nested[i] for i in range(len(nested)) if f[i]]) + v = _sqrt_match(p) + if 1 in f and f.index(1) < len(nested) - 1 and f[len(nested) - 1]: + v[0] = -v[0] + v[1] = -v[1] + if not f[len(nested)]: # Solution denests with square roots + vad = _mexpand(v[0] + d) + if vad <= 0: + # return the radicand from the previous invocation. + return sqrt(nested[-1]), [0]*len(nested) + if not(sqrt_depth(vad) <= sqrt_depth(R) + 1 or + (vad**2).is_Number): + av0[1] = None + return None, None + + sqvad = _sqrtdenest1(sqrt(vad), denester=False) + if not (sqrt_depth(sqvad) <= sqrt_depth(R) + 1): + av0[1] = None + return None, None + sqvad1 = radsimp(1/sqvad) + res = _mexpand(sqvad/sqrt(2) + (v[1]*sqrt(R)*sqvad1/sqrt(2))) + return res, f + + # sign(v[1])*sqrt(_mexpand(v[1]**2*R*vad1/2))), f + else: # Solution requires a fourth root + s2 = _mexpand(v[1]*R) + d + if s2 <= 0: + return sqrt(nested[-1]), [0]*len(nested) + FR, s = root(_mexpand(R), 4), sqrt(s2) + return _mexpand(s/(sqrt(2)*FR) + v[0]*FR/(sqrt(2)*s)), f + + +def _sqrt_ratcomb(cs, args): + """Denest rational combinations of radicals. + + Based on section 5 of [1]. + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.simplify.sqrtdenest import sqrtdenest + >>> z = sqrt(1+sqrt(3)) + sqrt(3+3*sqrt(3)) - sqrt(10+6*sqrt(3)) + >>> sqrtdenest(z) + 0 + """ + from sympy.simplify.radsimp import radsimp + + # check if there exists a pair of sqrt that can be denested + def find(a): + n = len(a) + for i in range(n - 1): + for j in range(i + 1, n): + s1 = a[i].base + s2 = a[j].base + p = _mexpand(s1 * s2) + s = sqrtdenest(sqrt(p)) + if s != sqrt(p): + return s, i, j + + indices = find(args) + if indices is None: + return Add(*[c * arg for c, arg in zip(cs, args)]) + + s, i1, i2 = indices + + c2 = cs.pop(i2) + args.pop(i2) + a1 = args[i1] + + # replace a2 by s/a1 + cs[i1] += radsimp(c2 * s / a1.base) + + return _sqrt_ratcomb(cs, args) diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_combsimp.py b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_combsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..e56758a005fbb013c2b6ea4121b16c3434a54b03 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_combsimp.py @@ -0,0 +1,75 @@ +from sympy.core.numbers import Rational +from sympy.core.symbol import symbols +from sympy.functions.combinatorial.factorials import (FallingFactorial, RisingFactorial, binomial, factorial) +from sympy.functions.special.gamma_functions import gamma +from sympy.simplify.combsimp import combsimp +from sympy.abc import x + + +def test_combsimp(): + k, m, n = symbols('k m n', integer = True) + + assert combsimp(factorial(n)) == factorial(n) + assert combsimp(binomial(n, k)) == binomial(n, k) + + assert combsimp(factorial(n)/factorial(n - 3)) == n*(-1 + n)*(-2 + n) + assert combsimp(binomial(n + 1, k + 1)/binomial(n, k)) == (1 + n)/(1 + k) + + assert combsimp(binomial(3*n + 4, n + 1)/binomial(3*n + 1, n)) == \ + Rational(3, 2)*((3*n + 2)*(3*n + 4)/((n + 1)*(2*n + 3))) + + assert combsimp(factorial(n)**2/factorial(n - 3)) == \ + factorial(n)*n*(-1 + n)*(-2 + n) + assert combsimp(factorial(n)*binomial(n + 1, k + 1)/binomial(n, k)) == \ + factorial(n + 1)/(1 + k) + + assert combsimp(gamma(n + 3)) == factorial(n + 2) + + assert combsimp(factorial(x)) == gamma(x + 1) + + # issue 9699 + assert combsimp((n + 1)*factorial(n)) == factorial(n + 1) + assert combsimp(factorial(n)/n) == factorial(n-1) + + # issue 6658 + assert combsimp(binomial(n, n - k)) == binomial(n, k) + + # issue 6341, 7135 + assert combsimp(factorial(n)/(factorial(k)*factorial(n - k))) == \ + binomial(n, k) + assert combsimp(factorial(k)*factorial(n - k)/factorial(n)) == \ + 1/binomial(n, k) + assert combsimp(factorial(2*n)/factorial(n)**2) == binomial(2*n, n) + assert combsimp(factorial(2*n)*factorial(k)*factorial(n - k)/ + factorial(n)**3) == binomial(2*n, n)/binomial(n, k) + + assert combsimp(factorial(n*(1 + n) - n**2 - n)) == 1 + + assert combsimp(6*FallingFactorial(-4, n)/factorial(n)) == \ + (-1)**n*(n + 1)*(n + 2)*(n + 3) + assert combsimp(6*FallingFactorial(-4, n - 1)/factorial(n - 1)) == \ + (-1)**(n - 1)*n*(n + 1)*(n + 2) + assert combsimp(6*FallingFactorial(-4, n - 3)/factorial(n - 3)) == \ + (-1)**(n - 3)*n*(n - 1)*(n - 2) + assert combsimp(6*FallingFactorial(-4, -n - 1)/factorial(-n - 1)) == \ + -(-1)**(-n - 1)*n*(n - 1)*(n - 2) + + assert combsimp(6*RisingFactorial(4, n)/factorial(n)) == \ + (n + 1)*(n + 2)*(n + 3) + assert combsimp(6*RisingFactorial(4, n - 1)/factorial(n - 1)) == \ + n*(n + 1)*(n + 2) + assert combsimp(6*RisingFactorial(4, n - 3)/factorial(n - 3)) == \ + n*(n - 1)*(n - 2) + assert combsimp(6*RisingFactorial(4, -n - 1)/factorial(-n - 1)) == \ + -n*(n - 1)*(n - 2) + + +def test_issue_6878(): + n = symbols('n', integer=True) + assert combsimp(RisingFactorial(-10, n)) == 3628800*(-1)**n/factorial(10 - n) + + +def test_issue_14528(): + p = symbols("p", integer=True, positive=True) + assert combsimp(binomial(1,p)) == 1/(factorial(p)*factorial(1-p)) + assert combsimp(factorial(2-p)) == factorial(2-p) diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_cse.py b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_cse.py new file mode 100644 index 0000000000000000000000000000000000000000..c2a34dfb0e227547bd41bed2491284fd7150d0b6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_cse.py @@ -0,0 +1,761 @@ +from functools import reduce +import itertools +from operator import add + +from sympy.codegen.matrix_nodes import MatrixSolve +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.expr import UnevaluatedExpr +from sympy.core.function import Function +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions import Inverse, MatAdd, MatMul, Transpose +from sympy.polys.rootoftools import CRootOf +from sympy.series.order import O +from sympy.simplify.cse_main import cse +from sympy.simplify.simplify import signsimp +from sympy.tensor.indexed import (Idx, IndexedBase) + +from sympy.core.function import count_ops +from sympy.simplify.cse_opts import sub_pre, sub_post +from sympy.functions.special.hyper import meijerg +from sympy.simplify import cse_main, cse_opts +from sympy.utilities.iterables import subsets +from sympy.testing.pytest import XFAIL, raises +from sympy.matrices import (MutableDenseMatrix, MutableSparseMatrix, + ImmutableDenseMatrix, ImmutableSparseMatrix) +from sympy.matrices.expressions import MatrixSymbol + + +w, x, y, z = symbols('w,x,y,z') +x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12 = symbols('x:13') + + +def test_numbered_symbols(): + ns = cse_main.numbered_symbols(prefix='y') + assert list(itertools.islice( + ns, 0, 10)) == [Symbol('y%s' % i) for i in range(0, 10)] + ns = cse_main.numbered_symbols(prefix='y') + assert list(itertools.islice( + ns, 10, 20)) == [Symbol('y%s' % i) for i in range(10, 20)] + ns = cse_main.numbered_symbols() + assert list(itertools.islice( + ns, 0, 10)) == [Symbol('x%s' % i) for i in range(0, 10)] + +# Dummy "optimization" functions for testing. + + +def opt1(expr): + return expr + y + + +def opt2(expr): + return expr*z + + +def test_preprocess_for_cse(): + assert cse_main.preprocess_for_cse(x, [(opt1, None)]) == x + y + assert cse_main.preprocess_for_cse(x, [(None, opt1)]) == x + assert cse_main.preprocess_for_cse(x, [(None, None)]) == x + assert cse_main.preprocess_for_cse(x, [(opt1, opt2)]) == x + y + assert cse_main.preprocess_for_cse( + x, [(opt1, None), (opt2, None)]) == (x + y)*z + + +def test_postprocess_for_cse(): + assert cse_main.postprocess_for_cse(x, [(opt1, None)]) == x + assert cse_main.postprocess_for_cse(x, [(None, opt1)]) == x + y + assert cse_main.postprocess_for_cse(x, [(None, None)]) == x + assert cse_main.postprocess_for_cse(x, [(opt1, opt2)]) == x*z + # Note the reverse order of application. + assert cse_main.postprocess_for_cse( + x, [(None, opt1), (None, opt2)]) == x*z + y + + +def test_cse_single(): + # Simple substitution. + e = Add(Pow(x + y, 2), sqrt(x + y)) + substs, reduced = cse([e]) + assert substs == [(x0, x + y)] + assert reduced == [sqrt(x0) + x0**2] + + subst42, (red42,) = cse([42]) # issue_15082 + assert len(subst42) == 0 and red42 == 42 + subst_half, (red_half,) = cse([0.5]) + assert len(subst_half) == 0 and red_half == 0.5 + + +def test_cse_single2(): + # Simple substitution, test for being able to pass the expression directly + e = Add(Pow(x + y, 2), sqrt(x + y)) + substs, reduced = cse(e) + assert substs == [(x0, x + y)] + assert reduced == [sqrt(x0) + x0**2] + substs, reduced = cse(Matrix([[1]])) + assert isinstance(reduced[0], Matrix) + + subst42, (red42,) = cse(42) # issue 15082 + assert len(subst42) == 0 and red42 == 42 + subst_half, (red_half,) = cse(0.5) # issue 15082 + assert len(subst_half) == 0 and red_half == 0.5 + + +def test_cse_not_possible(): + # No substitution possible. + e = Add(x, y) + substs, reduced = cse([e]) + assert substs == [] + assert reduced == [x + y] + # issue 6329 + eq = (meijerg((1, 2), (y, 4), (5,), [], x) + + meijerg((1, 3), (y, 4), (5,), [], x)) + assert cse(eq) == ([], [eq]) + + +def test_nested_substitution(): + # Substitution within a substitution. + e = Add(Pow(w*x + y, 2), sqrt(w*x + y)) + substs, reduced = cse([e]) + assert substs == [(x0, w*x + y)] + assert reduced == [sqrt(x0) + x0**2] + + +def test_subtraction_opt(): + # Make sure subtraction is optimized. + e = (x - y)*(z - y) + exp((x - y)*(z - y)) + substs, reduced = cse( + [e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) + assert substs == [(x0, (x - y)*(y - z))] + assert reduced == [-x0 + exp(-x0)] + e = -(x - y)*(z - y) + exp(-(x - y)*(z - y)) + substs, reduced = cse( + [e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) + assert substs == [(x0, (x - y)*(y - z))] + assert reduced == [x0 + exp(x0)] + # issue 4077 + n = -1 + 1/x + e = n/x/(-n)**2 - 1/n/x + assert cse(e, optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) == \ + ([], [0]) + assert cse(((w + x + y + z)*(w - y - z))/(w + x)**3) == \ + ([(x0, w + x), (x1, y + z)], [(w - x1)*(x0 + x1)/x0**3]) + + +def test_multiple_expressions(): + e1 = (x + y)*z + e2 = (x + y)*w + substs, reduced = cse([e1, e2]) + assert substs == [(x0, x + y)] + assert reduced == [x0*z, x0*w] + l = [w*x*y + z, w*y] + substs, reduced = cse(l) + rsubsts, _ = cse(reversed(l)) + assert substs == rsubsts + assert reduced == [z + x*x0, x0] + l = [w*x*y, w*x*y + z, w*y] + substs, reduced = cse(l) + rsubsts, _ = cse(reversed(l)) + assert substs == rsubsts + assert reduced == [x1, x1 + z, x0] + l = [(x - z)*(y - z), x - z, y - z] + substs, reduced = cse(l) + rsubsts, _ = cse(reversed(l)) + assert substs == [(x0, -z), (x1, x + x0), (x2, x0 + y)] + assert rsubsts == [(x0, -z), (x1, x0 + y), (x2, x + x0)] + assert reduced == [x1*x2, x1, x2] + l = [w*y + w + x + y + z, w*x*y] + assert cse(l) == ([(x0, w*y)], [w + x + x0 + y + z, x*x0]) + assert cse([x + y, x + y + z]) == ([(x0, x + y)], [x0, z + x0]) + assert cse([x + y, x + z]) == ([], [x + y, x + z]) + assert cse([x*y, z + x*y, x*y*z + 3]) == \ + ([(x0, x*y)], [x0, z + x0, 3 + x0*z]) + + +@XFAIL # CSE of non-commutative Mul terms is disabled +def test_non_commutative_cse(): + A, B, C = symbols('A B C', commutative=False) + l = [A*B*C, A*C] + assert cse(l) == ([], l) + l = [A*B*C, A*B] + assert cse(l) == ([(x0, A*B)], [x0*C, x0]) + + +# Test if CSE of non-commutative Mul terms is disabled +def test_bypass_non_commutatives(): + A, B, C = symbols('A B C', commutative=False) + l = [A*B*C, A*C] + assert cse(l) == ([], l) + l = [A*B*C, A*B] + assert cse(l) == ([], l) + l = [B*C, A*B*C] + assert cse(l) == ([], l) + + +@XFAIL # CSE fails when replacing non-commutative sub-expressions +def test_non_commutative_order(): + A, B, C = symbols('A B C', commutative=False) + x0 = symbols('x0', commutative=False) + l = [B+C, A*(B+C)] + assert cse(l) == ([(x0, B+C)], [x0, A*x0]) + + +@XFAIL # Worked in gh-11232, but was reverted due to performance considerations +def test_issue_10228(): + assert cse([x*y**2 + x*y]) == ([(x0, x*y)], [x0*y + x0]) + assert cse([x + y, 2*x + y]) == ([(x0, x + y)], [x0, x + x0]) + assert cse((w + 2*x + y + z, w + x + 1)) == ( + [(x0, w + x)], [x0 + x + y + z, x0 + 1]) + assert cse(((w + x + y + z)*(w - x))/(w + x)) == ( + [(x0, w + x)], [(x0 + y + z)*(w - x)/x0]) + a, b, c, d, f, g, j, m = symbols('a, b, c, d, f, g, j, m') + exprs = (d*g**2*j*m, 4*a*f*g*m, a*b*c*f**2) + assert cse(exprs) == ( + [(x0, g*m), (x1, a*f)], [d*g*j*x0, 4*x0*x1, b*c*f*x1] +) + +@XFAIL +def test_powers(): + assert cse(x*y**2 + x*y) == ([(x0, x*y)], [x0*y + x0]) + + +def test_issue_4498(): + assert cse(w/(x - y) + z/(y - x), optimizations='basic') == \ + ([], [(w - z)/(x - y)]) + + +def test_issue_4020(): + assert cse(x**5 + x**4 + x**3 + x**2, optimizations='basic') \ + == ([(x0, x**2)], [x0*(x**3 + x + x0 + 1)]) + + +def test_issue_4203(): + assert cse(sin(x**x)/x**x) == ([(x0, x**x)], [sin(x0)/x0]) + + +def test_issue_6263(): + e = Eq(x*(-x + 1) + x*(x - 1), 0) + assert cse(e, optimizations='basic') == ([], [True]) + + +def test_issue_25043(): + c = symbols("c") + x = symbols("x0", real=True) + cse_expr = cse(c*x**2 + c*(x**4 - x**2))[-1][-1] + free = cse_expr.free_symbols + assert len(free) == len({i.name for i in free}) + + +def test_dont_cse_tuples(): + from sympy.core.function import Subs + f = Function("f") + g = Function("g") + + name_val, (expr,) = cse( + Subs(f(x, y), (x, y), (0, 1)) + + Subs(g(x, y), (x, y), (0, 1))) + + assert name_val == [] + assert expr == (Subs(f(x, y), (x, y), (0, 1)) + + Subs(g(x, y), (x, y), (0, 1))) + + name_val, (expr,) = cse( + Subs(f(x, y), (x, y), (0, x + y)) + + Subs(g(x, y), (x, y), (0, x + y))) + + assert name_val == [(x0, x + y)] + assert expr == Subs(f(x, y), (x, y), (0, x0)) + \ + Subs(g(x, y), (x, y), (0, x0)) + + +def test_pow_invpow(): + assert cse(1/x**2 + x**2) == \ + ([(x0, x**2)], [x0 + 1/x0]) + assert cse(x**2 + (1 + 1/x**2)/x**2) == \ + ([(x0, x**2), (x1, 1/x0)], [x0 + x1*(x1 + 1)]) + assert cse(1/x**2 + (1 + 1/x**2)*x**2) == \ + ([(x0, x**2), (x1, 1/x0)], [x0*(x1 + 1) + x1]) + assert cse(cos(1/x**2) + sin(1/x**2)) == \ + ([(x0, x**(-2))], [sin(x0) + cos(x0)]) + assert cse(cos(x**2) + sin(x**2)) == \ + ([(x0, x**2)], [sin(x0) + cos(x0)]) + assert cse(y/(2 + x**2) + z/x**2/y) == \ + ([(x0, x**2)], [y/(x0 + 2) + z/(x0*y)]) + assert cse(exp(x**2) + x**2*cos(1/x**2)) == \ + ([(x0, x**2)], [x0*cos(1/x0) + exp(x0)]) + assert cse((1 + 1/x**2)/x**2) == \ + ([(x0, x**(-2))], [x0*(x0 + 1)]) + assert cse(x**(2*y) + x**(-2*y)) == \ + ([(x0, x**(2*y))], [x0 + 1/x0]) + + +def test_postprocess(): + eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1)) + assert cse([eq, Eq(x, z + 1), z - 2, (z + 1)*(x + 1)], + postprocess=cse_main.cse_separate) == \ + [[(x0, y + 1), (x2, z + 1), (x, x2), (x1, x + 1)], + [x1 + exp(x1/x0) + cos(x0), z - 2, x1*x2]] + + +def test_issue_4499(): + # previously, this gave 16 constants + from sympy.abc import a, b + B = Function('B') + G = Function('G') + t = Tuple(* + (a, a + S.Half, 2*a, b, 2*a - b + 1, (sqrt(z)/2)**(-2*a + 1)*B(2*a - + b, sqrt(z))*B(b - 1, sqrt(z))*G(b)*G(2*a - b + 1), + sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b, + sqrt(z))*G(b)*G(2*a - b + 1), sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b - 1, + sqrt(z))*B(2*a - b + 1, sqrt(z))*G(b)*G(2*a - b + 1), + (sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b + 1, + sqrt(z))*G(b)*G(2*a - b + 1), 1, 0, S.Half, z/2, -b + 1, -2*a + b, + -2*a)) + c = cse(t) + ans = ( + [(x0, 2*a), (x1, -b + x0), (x2, x1 + 1), (x3, b - 1), (x4, sqrt(z)), + (x5, B(x3, x4)), (x6, (x4/2)**(1 - x0)*G(b)*G(x2)), (x7, x6*B(x1, x4)), + (x8, B(b, x4)), (x9, x6*B(x2, x4))], + [(a, a + S.Half, x0, b, x2, x5*x7, x4*x7*x8, x4*x5*x9, x8*x9, + 1, 0, S.Half, z/2, -x3, -x1, -x0)]) + assert ans == c + + +def test_issue_6169(): + r = CRootOf(x**6 - 4*x**5 - 2, 1) + assert cse(r) == ([], [r]) + # and a check that the right thing is done with the new + # mechanism + assert sub_post(sub_pre((-x - y)*z - x - y)) == -z*(x + y) - x - y + + +def test_cse_Indexed(): + len_y = 5 + y = IndexedBase('y', shape=(len_y,)) + x = IndexedBase('x', shape=(len_y,)) + i = Idx('i', len_y-1) + + expr1 = (y[i+1]-y[i])/(x[i+1]-x[i]) + expr2 = 1/(x[i+1]-x[i]) + replacements, reduced_exprs = cse([expr1, expr2]) + assert len(replacements) > 0 + + +def test_cse_MatrixSymbol(): + # MatrixSymbols have non-Basic args, so make sure that works + A = MatrixSymbol("A", 3, 3) + assert cse(A) == ([], [A]) + + n = symbols('n', integer=True) + B = MatrixSymbol("B", n, n) + assert cse(B) == ([], [B]) + + assert cse(A[0] * A[0]) == ([], [A[0]*A[0]]) + + assert cse(A[0,0]*A[0,1] + A[0,0]*A[0,1]*A[0,2]) == ([(x0, A[0, 0]*A[0, 1])], [x0*A[0, 2] + x0]) + +def test_cse_MatrixExpr(): + A = MatrixSymbol('A', 3, 3) + y = MatrixSymbol('y', 3, 1) + + expr1 = (A.T*A).I * A * y + expr2 = (A.T*A) * A * y + replacements, reduced_exprs = cse([expr1, expr2]) + assert len(replacements) > 0 + + replacements, reduced_exprs = cse([expr1 + expr2, expr1]) + assert replacements + + replacements, reduced_exprs = cse([A**2, A + A**2]) + assert replacements + + +def test_Piecewise(): + f = Piecewise((-z + x*y, Eq(y, 0)), (-z - x*y, True)) + ans = cse(f) + actual_ans = ([(x0, x*y)], + [Piecewise((x0 - z, Eq(y, 0)), (-z - x0, True))]) + assert ans == actual_ans + + +def test_ignore_order_terms(): + eq = exp(x).series(x,0,3) + sin(y+x**3) - 1 + assert cse(eq) == ([], [sin(x**3 + y) + x + x**2/2 + O(x**3)]) + + +def test_name_conflict(): + z1 = x0 + y + z2 = x2 + x3 + l = [cos(z1) + z1, cos(z2) + z2, x0 + x2] + substs, reduced = cse(l) + assert [e.subs(reversed(substs)) for e in reduced] == l + + +def test_name_conflict_cust_symbols(): + z1 = x0 + y + z2 = x2 + x3 + l = [cos(z1) + z1, cos(z2) + z2, x0 + x2] + substs, reduced = cse(l, symbols("x:10")) + assert [e.subs(reversed(substs)) for e in reduced] == l + + +def test_symbols_exhausted_error(): + l = cos(x+y)+x+y+cos(w+y)+sin(w+y) + sym = [x, y, z] + with raises(ValueError): + cse(l, symbols=sym) + + +def test_issue_7840(): + # daveknippers' example + C393 = sympify( \ + 'Piecewise((C391 - 1.65, C390 < 0.5), (Piecewise((C391 - 1.65, \ + C391 > 2.35), (C392, True)), True))' + ) + C391 = sympify( \ + 'Piecewise((2.05*C390**(-1.03), C390 < 0.5), (2.5*C390**(-0.625), True))' + ) + C393 = C393.subs('C391',C391) + # simple substitution + sub = {} + sub['C390'] = 0.703451854 + sub['C392'] = 1.01417794 + ss_answer = C393.subs(sub) + # cse + substitutions,new_eqn = cse(C393) + for pair in substitutions: + sub[pair[0].name] = pair[1].subs(sub) + cse_answer = new_eqn[0].subs(sub) + # both methods should be the same + assert ss_answer == cse_answer + + # GitRay's example + expr = sympify( + "Piecewise((Symbol('ON'), Equality(Symbol('mode'), Symbol('ON'))), \ + (Piecewise((Piecewise((Symbol('OFF'), StrictLessThan(Symbol('x'), \ + Symbol('threshold'))), (Symbol('ON'), true)), Equality(Symbol('mode'), \ + Symbol('AUTO'))), (Symbol('OFF'), true)), true))" + ) + substitutions, new_eqn = cse(expr) + # this Piecewise should be exactly the same + assert new_eqn[0] == expr + # there should not be any replacements + assert len(substitutions) < 1 + + +def test_issue_8891(): + for cls in (MutableDenseMatrix, MutableSparseMatrix, + ImmutableDenseMatrix, ImmutableSparseMatrix): + m = cls(2, 2, [x + y, 0, 0, 0]) + res = cse([x + y, m]) + ans = ([(x0, x + y)], [x0, cls([[x0, 0], [0, 0]])]) + assert res == ans + assert isinstance(res[1][-1], cls) + + +def test_issue_11230(): + # a specific test that always failed + a, b, f, k, l, i = symbols('a b f k l i') + p = [a*b*f*k*l, a*i*k**2*l, f*i*k**2*l] + R, C = cse(p) + assert not any(i.is_Mul for a in C for i in a.args) + + # random tests for the issue + from sympy.core.random import choice + from sympy.core.function import expand_mul + s = symbols('a:m') + # 35 Mul tests, none of which should ever fail + ex = [Mul(*[choice(s) for i in range(5)]) for i in range(7)] + for p in subsets(ex, 3): + p = list(p) + R, C = cse(p) + assert not any(i.is_Mul for a in C for i in a.args) + for ri in reversed(R): + for i in range(len(C)): + C[i] = C[i].subs(*ri) + assert p == C + # 35 Add tests, none of which should ever fail + ex = [Add(*[choice(s[:7]) for i in range(5)]) for i in range(7)] + for p in subsets(ex, 3): + p = list(p) + R, C = cse(p) + assert not any(i.is_Add for a in C for i in a.args) + for ri in reversed(R): + for i in range(len(C)): + C[i] = C[i].subs(*ri) + # use expand_mul to handle cases like this: + # p = [a + 2*b + 2*e, 2*b + c + 2*e, b + 2*c + 2*g] + # x0 = 2*(b + e) is identified giving a rebuilt p that + # is now `[a + 2*(b + e), c + 2*(b + e), b + 2*c + 2*g]` + assert p == [expand_mul(i) for i in C] + + +@XFAIL +def test_issue_11577(): + def check(eq): + r, c = cse(eq) + assert eq.count_ops() >= \ + len(r) + sum(i[1].count_ops() for i in r) + \ + count_ops(c) + + eq = x**5*y**2 + x**5*y + x**5 + assert cse(eq) == ( + [(x0, x**4), (x1, x*y)], [x**5 + x0*x1*y + x0*x1]) + # ([(x0, x**5*y)], [x0*y + x0 + x**5]) or + # ([(x0, x**5)], [x0*y**2 + x0*y + x0]) + check(eq) + + eq = x**2/(y + 1)**2 + x/(y + 1) + assert cse(eq) == ( + [(x0, y + 1)], [x**2/x0**2 + x/x0]) + # ([(x0, x/(y + 1))], [x0**2 + x0]) + check(eq) + + +def test_hollow_rejection(): + eq = [x + 3, x + 4] + assert cse(eq) == ([], eq) + + +def test_cse_ignore(): + exprs = [exp(y)*(3*y + 3*sqrt(x+1)), exp(y)*(5*y + 5*sqrt(x+1))] + subst1, red1 = cse(exprs) + assert any(y in sub.free_symbols for _, sub in subst1), "cse failed to identify any term with y" + + subst2, red2 = cse(exprs, ignore=(y,)) # y is not allowed in substitutions + assert not any(y in sub.free_symbols for _, sub in subst2), "Sub-expressions containing y must be ignored" + assert any(sub - sqrt(x + 1) == 0 for _, sub in subst2), "cse failed to identify sqrt(x + 1) as sub-expression" + + +def test_cse_ignore_issue_15002(): + l = [ + w*exp(x)*exp(-z), + exp(y)*exp(x)*exp(-z) + ] + substs, reduced = cse(l, ignore=(x,)) + rl = [e.subs(reversed(substs)) for e in reduced] + assert rl == l + + +def test_cse_unevaluated(): + xp1 = UnevaluatedExpr(x + 1) + # This used to cause RecursionError + [(x0, ue)], [red] = cse([(-1 - xp1) / (1 - xp1)]) + if ue == xp1: + assert red == (-1 - x0) / (1 - x0) + elif ue == -xp1: + assert red == (-1 + x0) / (1 + x0) + else: + msg = f'Expected common subexpression {xp1} or {-xp1}, instead got {ue}' + assert False, msg + + +def test_cse__performance(): + nexprs, nterms = 3, 20 + x = symbols('x:%d' % nterms) + exprs = [ + reduce(add, [x[j]*(-1)**(i+j) for j in range(nterms)]) + for i in range(nexprs) + ] + assert (exprs[0] + exprs[1]).simplify() == 0 + subst, red = cse(exprs) + assert len(subst) > 0, "exprs[0] == -exprs[2], i.e. a CSE" + for i, e in enumerate(red): + assert (e.subs(reversed(subst)) - exprs[i]).simplify() == 0 + + +def test_issue_12070(): + exprs = [x + y, 2 + x + y, x + y + z, 3 + x + y + z] + subst, red = cse(exprs) + assert 6 >= (len(subst) + sum(v.count_ops() for k, v in subst) + + count_ops(red)) + + +def test_issue_13000(): + eq = x/(-4*x**2 + y**2) + cse_eq = cse(eq)[1][0] + assert cse_eq == eq + + +def test_issue_18203(): + eq = CRootOf(x**5 + 11*x - 2, 0) + CRootOf(x**5 + 11*x - 2, 1) + assert cse(eq) == ([], [eq]) + + +def test_unevaluated_mul(): + eq = Mul(x + y, x + y, evaluate=False) + assert cse(eq) == ([(x0, x + y)], [x0**2]) + + +def test_cse_release_variables(): + from sympy.simplify.cse_main import cse_release_variables + _0, _1, _2, _3, _4 = symbols('_:5') + eqs = [(x + y - 1)**2, x, + x + y, (x + y)/(2*x + 1) + (x + y - 1)**2, + (2*x + 1)**(x + y)] + r, e = cse(eqs, postprocess=cse_release_variables) + # this can change in keeping with the intention of the function + assert r, e == ([ + (x0, x + y), (x1, (x0 - 1)**2), (x2, 2*x + 1), + (_3, x0/x2 + x1), (_4, x2**x0), (x2, None), (_0, x1), + (x1, None), (_2, x0), (x0, None), (_1, x)], (_0, _1, _2, _3, _4)) + r.reverse() + r = [(s, v) for s, v in r if v is not None] + assert eqs == [i.subs(r) for i in e] + + +def test_cse_list(): + _cse = lambda x: cse(x, list=False) + assert _cse(x) == ([], x) + assert _cse('x') == ([], 'x') + it = [x] + for c in (list, tuple, set): + assert _cse(c(it)) == ([], c(it)) + #Tuple works different from tuple: + assert _cse(Tuple(*it)) == ([], Tuple(*it)) + d = {x: 1} + assert _cse(d) == ([], d) + +def test_issue_18991(): + A = MatrixSymbol('A', 2, 2) + assert signsimp(-A * A - A) == -A * A - A + + +def test_unevaluated_Mul(): + m = [Mul(1, 2, evaluate=False)] + assert cse(m) == ([], m) + + +def test_cse_matrix_expression_inverse(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = Inverse(A) + cse_expr = cse(x) + assert cse_expr == ([], [Inverse(A)]) + + +def test_cse_matrix_expression_matmul_inverse(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + b = ImmutableDenseMatrix(symbols('b:2')) + x = MatMul(Inverse(A), b) + cse_expr = cse(x) + assert cse_expr == ([], [x]) + + +def test_cse_matrix_negate_matrix(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = MatMul(S.NegativeOne, A) + cse_expr = cse(x) + assert cse_expr == ([], [x]) + + +def test_cse_matrix_negate_matmul_not_extracted(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + B = ImmutableDenseMatrix(symbols('B:4')).reshape(2, 2) + x = MatMul(S.NegativeOne, A, B) + cse_expr = cse(x) + assert cse_expr == ([], [x]) + + +@XFAIL # No simplification rule for nested associative operations +def test_cse_matrix_nested_matmul_collapsed(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + B = ImmutableDenseMatrix(symbols('B:4')).reshape(2, 2) + x = MatMul(S.NegativeOne, MatMul(A, B)) + cse_expr = cse(x) + assert cse_expr == ([], [MatMul(S.NegativeOne, A, B)]) + + +def test_cse_matrix_optimize_out_single_argument_mul(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = MatMul(MatMul(MatMul(A))) + cse_expr = cse(x) + assert cse_expr == ([], [A]) + + +@XFAIL # Multiple simplification passed not supported in CSE +def test_cse_matrix_optimize_out_single_argument_mul_combined(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = MatAdd(MatMul(MatMul(MatMul(A))), MatMul(MatMul(A)), MatMul(A), A) + cse_expr = cse(x) + assert cse_expr == ([], [MatMul(4, A)]) + + +def test_cse_matrix_optimize_out_single_argument_add(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = MatAdd(MatAdd(MatAdd(MatAdd(A)))) + cse_expr = cse(x) + assert cse_expr == ([], [A]) + + +@XFAIL # Multiple simplification passed not supported in CSE +def test_cse_matrix_optimize_out_single_argument_add_combined(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = MatMul(MatAdd(MatAdd(MatAdd(A))), MatAdd(MatAdd(A)), MatAdd(A), A) + cse_expr = cse(x) + assert cse_expr == ([], [MatMul(4, A)]) + + +def test_cse_matrix_expression_matrix_solve(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + b = ImmutableDenseMatrix(symbols('b:2')) + x = MatrixSolve(A, b) + cse_expr = cse(x) + assert cse_expr == ([], [x]) + + +def test_cse_matrix_matrix_expression(): + X = ImmutableDenseMatrix(symbols('X:4')).reshape(2, 2) + y = ImmutableDenseMatrix(symbols('y:2')) + b = MatMul(Inverse(MatMul(Transpose(X), X)), Transpose(X), y) + cse_expr = cse(b) + x0 = MatrixSymbol('x0', 2, 2) + reduced_expr_expected = MatMul(Inverse(MatMul(x0, X)), x0, y) + assert cse_expr == ([(x0, Transpose(X))], [reduced_expr_expected]) + + +def test_cse_matrix_kalman_filter(): + """Kalman Filter example from Matthew Rocklin's SciPy 2013 talk. + + Talk titled: "Matrix Expressions and BLAS/LAPACK; SciPy 2013 Presentation" + + Video: https://pyvideo.org/scipy-2013/matrix-expressions-and-blaslapack-scipy-2013-pr.html + + Notes + ===== + + Equations are: + + new_mu = mu + Sigma*H.T * (R + H*Sigma*H.T).I * (H*mu - data) + = MatAdd(mu, MatMul(Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data)))) + new_Sigma = Sigma - Sigma*H.T * (R + H*Sigma*H.T).I * H * Sigma + = MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, Transpose(H)), Inverse(MatAdd(R, MatMul(H*Sigma*Transpose(H)))), H, Sigma)) + + """ + N = 2 + mu = ImmutableDenseMatrix(symbols(f'mu:{N}')) + Sigma = ImmutableDenseMatrix(symbols(f'Sigma:{N * N}')).reshape(N, N) + H = ImmutableDenseMatrix(symbols(f'H:{N * N}')).reshape(N, N) + R = ImmutableDenseMatrix(symbols(f'R:{N * N}')).reshape(N, N) + data = ImmutableDenseMatrix(symbols(f'data:{N}')) + new_mu = MatAdd(mu, MatMul(Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data)))) + new_Sigma = MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), H, Sigma)) + cse_expr = cse([new_mu, new_Sigma]) + x0 = MatrixSymbol('x0', N, N) + x1 = MatrixSymbol('x1', N, N) + replacements_expected = [ + (x0, Transpose(H)), + (x1, Inverse(MatAdd(R, MatMul(H, Sigma, x0)))), + ] + reduced_exprs_expected = [ + MatAdd(mu, MatMul(Sigma, x0, x1, MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data)))), + MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, x0, x1, H, Sigma)), + ] + assert cse_expr == (replacements_expected, reduced_exprs_expected) diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_cse_diff.py b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_cse_diff.py new file mode 100644 index 0000000000000000000000000000000000000000..92b2d3d6bbaafb838a5e75f32a214511a1d39567 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_cse_diff.py @@ -0,0 +1,206 @@ +"""Tests for the ``sympy.simplify._cse_diff.py`` module.""" + +import pytest + +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.numbers import Integer +from sympy.core.function import Function +from sympy.core import Derivative +from sympy.functions.elementary.exponential import exp +from sympy.matrices.immutable import ImmutableDenseMatrix +from sympy.physics.mechanics import dynamicsymbols +from sympy.simplify._cse_diff import (_forward_jacobian, + _remove_cse_from_derivative, + _forward_jacobian_cse, + _forward_jacobian_norm_in_cse_out) +from sympy.simplify.simplify import simplify +from sympy.matrices import Matrix, eye + +from sympy.testing.pytest import raises +from sympy.functions.elementary.trigonometric import (cos, sin, tan) +from sympy.simplify.trigsimp import trigsimp + +from sympy import cse + + +w = Symbol('w') +x = Symbol('x') +y = Symbol('y') +z = Symbol('z') + +q1, q2, q3 = dynamicsymbols('q1 q2 q3') + +# Define the custom functions +k = Function('k')(x, y) +f = Function('f')(k, z) + +zero = Integer(0) +one = Integer(1) +two = Integer(2) +neg_one = Integer(-1) + + +@pytest.mark.parametrize( + 'expr, wrt', + [ + ([zero], [x]), + ([one], [x]), + ([two], [x]), + ([neg_one], [x]), + ([x], [x]), + ([y], [x]), + ([x + y], [x]), + ([x*y], [x]), + ([x**2], [x]), + ([x**y], [x]), + ([exp(x)], [x]), + ([sin(x)], [x]), + ([tan(x)], [x]), + ([zero, one, x, y, x*y, x + y], [x, y]), + ([((x/y) + sin(x/y) - exp(y))*((x/y) - exp(y))], [x, y]), + ([w*tan(y*z)/(x - tan(y*z)), w*x*tan(y*z)/(x - tan(y*z))], [w, x, y, z]), + ([q1**2 + q2, q2**2 + q3, q3**2 + q1], [q1, q2, q3]), + ([f + Derivative(f, x) + k + 2*x], [x]) + ] +) + + +def test_forward_jacobian(expr, wrt): + expr = ImmutableDenseMatrix([expr]).T + wrt = ImmutableDenseMatrix([wrt]).T + jacobian = _forward_jacobian(expr, wrt) + zeros = ImmutableDenseMatrix.zeros(*jacobian.shape) + assert simplify(jacobian - expr.jacobian(wrt)) == zeros + + +def test_process_cse(): + x, y, z = symbols('x y z') + f = Function('f') + k = Function('k') + expr = Matrix([f(k(x,y), z) + Derivative(f(k(x,y), z), x) + k(x,y) + 2*x]) + repl, reduced = cse(expr) + p_repl, p_reduced = _remove_cse_from_derivative(repl, reduced) + + x0 = symbols('x0') + x1 = symbols('x1') + + expected_output = ( + [(x0, k(x, y)), (x1, f(x0, z))], + [Matrix([2 * x + x0 + x1 + Derivative(f(k(x, y), z), x)])] + ) + + assert p_repl == expected_output[0], f"Expected {expected_output[0]}, but got {p_repl}" + assert p_reduced == expected_output[1], f"Expected {expected_output[1]}, but got {p_reduced}" + + +def test_io_matrix_type(): + x, y, z = symbols('x y z') + expr = ImmutableDenseMatrix([ + x * y + y * z + x * y * z, + x ** 2 + y ** 2 + z ** 2, + x * y + x * z + y * z + ]) + wrt = ImmutableDenseMatrix([x, y, z]) + + replacements, reduced_expr = cse(expr) + + # Test _forward_jacobian_core + replacements_core, jacobian_core, precomputed_fs_core = _forward_jacobian_cse(replacements, reduced_expr, wrt) + assert isinstance(jacobian_core[0], type(reduced_expr[0])), "Jacobian should be a Matrix of the same type as the input" + + # Test _forward_jacobian_norm_in_dag_out + replacements_norm, jacobian_norm, precomputed_fs_norm = _forward_jacobian_norm_in_cse_out( + expr, wrt) + assert isinstance(jacobian_norm[0], type(reduced_expr[0])), "Jacobian should be a Matrix of the same type as the input" + + # Test _forward_jacobian + jacobian = _forward_jacobian(expr, wrt) + assert isinstance(jacobian, type(expr)), "Jacobian should be a Matrix of the same type as the input" + + +def test_forward_jacobian_input_output(): + x, y, z = symbols('x y z') + expr = Matrix([ + x * y + y * z + x * y * z, + x ** 2 + y ** 2 + z ** 2, + x * y + x * z + y * z + ]) + wrt = Matrix([x, y, z]) + + replacements, reduced_expr = cse(expr) + + # Test _forward_jacobian_core + replacements_core, jacobian_core, precomputed_fs_core = _forward_jacobian_cse(replacements, reduced_expr, wrt) + assert isinstance(replacements_core, type(replacements)), "Replacements should be a list" + assert isinstance(jacobian_core, type(reduced_expr)), "Jacobian should be a list" + assert isinstance(precomputed_fs_core, list), "Precomputed free symbols should be a list" + assert len(replacements_core) == len(replacements), "Length of replacements does not match" + assert len(jacobian_core) == 1, "Jacobian should have one element" + assert len(precomputed_fs_core) == len(replacements), "Length of precomputed free symbols does not match" + + # Test _forward_jacobian_norm_in_dag_out + replacements_norm, jacobian_norm, precomputed_fs_norm = _forward_jacobian_norm_in_cse_out(expr, wrt) + assert isinstance(replacements_norm, type(replacements)), "Replacements should be a list" + assert isinstance(jacobian_norm, type(reduced_expr)), "Jacobian should be a list" + assert isinstance(precomputed_fs_norm, list), "Precomputed free symbols should be a list" + assert len(replacements_norm) == len(replacements), "Length of replacements does not match" + assert len(jacobian_norm) == 1, "Jacobian should have one element" + assert len(precomputed_fs_norm) == len(replacements), "Length of precomputed free symbols does not match" + + +def test_jacobian_hessian(): + L = Matrix(1, 2, [x**2*y, 2*y**2 + x*y]) + syms = [x, y] + assert _forward_jacobian(L, syms) == Matrix([[2*x*y, x**2], [y, 4*y + x]]) + + L = Matrix(1, 2, [x, x**2*y**3]) + assert _forward_jacobian(L, syms) == Matrix([[1, 0], [2*x*y**3, x**2*3*y**2]]) + + +def test_jacobian_metrics(): + rho, phi = symbols("rho,phi") + X = Matrix([rho * cos(phi), rho * sin(phi)]) + Y = Matrix([rho, phi]) + J = _forward_jacobian(X, Y) + assert J == X.jacobian(Y.T) + assert J == (X.T).jacobian(Y) + assert J == (X.T).jacobian(Y.T) + g = J.T * eye(J.shape[0]) * J + g = g.applyfunc(trigsimp) + assert g == Matrix([[1, 0], [0, rho ** 2]]) + + +def test_jacobian2(): + rho, phi = symbols("rho,phi") + X = Matrix([rho * cos(phi), rho * sin(phi), rho ** 2]) + Y = Matrix([rho, phi]) + J = Matrix([ + [cos(phi), -rho * sin(phi)], + [sin(phi), rho * cos(phi)], + [2 * rho, 0], + ]) + assert _forward_jacobian(X, Y) == J + + +def test_issue_4564(): + X = Matrix([exp(x + y + z), exp(x + y + z), exp(x + y + z)]) + Y = Matrix([x, y, z]) + for i in range(1, 3): + for j in range(1, 3): + X_slice = X[:i, :] + Y_slice = Y[:j, :] + J = _forward_jacobian(X_slice, Y_slice) + assert J.rows == i + assert J.cols == j + for k in range(j): + assert J[:, k] == X_slice + + +def test_nonvectorJacobian(): + X = Matrix([[exp(x + y + z), exp(x + y + z)], + [exp(x + y + z), exp(x + y + z)]]) + raises(TypeError, lambda: _forward_jacobian(X, Matrix([x, y, z]))) + X = X[0, :] + Y = Matrix([[x, y], [x, z]]) + raises(TypeError, lambda: _forward_jacobian(X, Y)) + raises(TypeError, lambda: _forward_jacobian(X, Matrix([[x, y], [x, z]]))) diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_epathtools.py b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_epathtools.py new file mode 100644 index 0000000000000000000000000000000000000000..a8bb47b2f2ff624077ab9905677b181c587ab5a7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_epathtools.py @@ -0,0 +1,90 @@ +"""Tests for tools for manipulation of expressions using paths. """ + +from sympy.simplify.epathtools import epath, EPath +from sympy.testing.pytest import raises + +from sympy.core.numbers import E +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.abc import x, y, z, t + + +def test_epath_select(): + expr = [((x, 1, t), 2), ((3, y, 4), z)] + + assert epath("/*", expr) == [((x, 1, t), 2), ((3, y, 4), z)] + assert epath("/*/*", expr) == [(x, 1, t), 2, (3, y, 4), z] + assert epath("/*/*/*", expr) == [x, 1, t, 3, y, 4] + assert epath("/*/*/*/*", expr) == [] + + assert epath("/[:]", expr) == [((x, 1, t), 2), ((3, y, 4), z)] + assert epath("/[:]/[:]", expr) == [(x, 1, t), 2, (3, y, 4), z] + assert epath("/[:]/[:]/[:]", expr) == [x, 1, t, 3, y, 4] + assert epath("/[:]/[:]/[:]/[:]", expr) == [] + + assert epath("/*/[:]", expr) == [(x, 1, t), 2, (3, y, 4), z] + + assert epath("/*/[0]", expr) == [(x, 1, t), (3, y, 4)] + assert epath("/*/[1]", expr) == [2, z] + assert epath("/*/[2]", expr) == [] + + assert epath("/*/int", expr) == [2] + assert epath("/*/Symbol", expr) == [z] + assert epath("/*/tuple", expr) == [(x, 1, t), (3, y, 4)] + assert epath("/*/__iter__?", expr) == [(x, 1, t), (3, y, 4)] + + assert epath("/*/int|tuple", expr) == [(x, 1, t), 2, (3, y, 4)] + assert epath("/*/Symbol|tuple", expr) == [(x, 1, t), (3, y, 4), z] + assert epath("/*/int|Symbol|tuple", expr) == [(x, 1, t), 2, (3, y, 4), z] + + assert epath("/*/int|__iter__?", expr) == [(x, 1, t), 2, (3, y, 4)] + assert epath("/*/Symbol|__iter__?", expr) == [(x, 1, t), (3, y, 4), z] + assert epath( + "/*/int|Symbol|__iter__?", expr) == [(x, 1, t), 2, (3, y, 4), z] + + assert epath("/*/[0]/int", expr) == [1, 3, 4] + assert epath("/*/[0]/Symbol", expr) == [x, t, y] + + assert epath("/*/[0]/int[1:]", expr) == [1, 4] + assert epath("/*/[0]/Symbol[1:]", expr) == [t, y] + + assert epath("/Symbol", x + y + z + 1) == [x, y, z] + assert epath("/*/*/Symbol", t + sin(x + 1) + cos(x + y + E)) == [x, x, y] + + +def test_epath_apply(): + expr = [((x, 1, t), 2), ((3, y, 4), z)] + func = lambda expr: expr**2 + + assert epath("/*", expr, list) == [[(x, 1, t), 2], [(3, y, 4), z]] + + assert epath("/*/[0]", expr, list) == [([x, 1, t], 2), ([3, y, 4], z)] + assert epath("/*/[1]", expr, func) == [((x, 1, t), 4), ((3, y, 4), z**2)] + assert epath("/*/[2]", expr, list) == expr + + assert epath("/*/[0]/int", expr, func) == [((x, 1, t), 2), ((9, y, 16), z)] + assert epath("/*/[0]/Symbol", expr, func) == [((x**2, 1, t**2), 2), + ((3, y**2, 4), z)] + assert epath( + "/*/[0]/int[1:]", expr, func) == [((x, 1, t), 2), ((3, y, 16), z)] + assert epath("/*/[0]/Symbol[1:]", expr, func) == [((x, 1, t**2), + 2), ((3, y**2, 4), z)] + + assert epath("/Symbol", x + y + z + 1, func) == x**2 + y**2 + z**2 + 1 + assert epath("/*/*/Symbol", t + sin(x + 1) + cos(x + y + E), func) == \ + t + sin(x**2 + 1) + cos(x**2 + y**2 + E) + + +def test_EPath(): + assert EPath("/*/[0]")._path == "/*/[0]" + assert EPath(EPath("/*/[0]"))._path == "/*/[0]" + assert isinstance(epath("/*/[0]"), EPath) is True + + assert repr(EPath("/*/[0]")) == "EPath('/*/[0]')" + + raises(ValueError, lambda: EPath("")) + raises(ValueError, lambda: EPath("/")) + raises(ValueError, lambda: EPath("/|x")) + raises(ValueError, lambda: EPath("/[")) + raises(ValueError, lambda: EPath("/[0]%")) + + raises(NotImplementedError, lambda: EPath("Symbol")) diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_fu.py b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_fu.py new file mode 100644 index 0000000000000000000000000000000000000000..2de2126b7333195fceeffe72dc9cb642e7eba9a9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_fu.py @@ -0,0 +1,492 @@ +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.parameters import evaluate +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.elementary.hyperbolic import (cosh, coth, csch, sech, sinh, tanh) +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import (cos, cot, csc, sec, sin, tan) +from sympy.simplify.powsimp import powsimp +from sympy.simplify.fu import ( + L, TR1, TR10, TR10i, TR11, _TR11, TR12, TR12i, TR13, TR14, TR15, TR16, + TR111, TR2, TR2i, TR3, TR4, TR5, TR6, TR7, TR8, TR9, TRmorrie, _TR56 as T, + TRpower, hyper_as_trig, fu, process_common_addends, trig_split, + as_f_sign_1) +from sympy.core.random import verify_numerically +from sympy.abc import a, b, c, x, y, z + + +def test_TR1(): + assert TR1(2*csc(x) + sec(x)) == 1/cos(x) + 2/sin(x) + + +def test_TR2(): + assert TR2(tan(x)) == sin(x)/cos(x) + assert TR2(cot(x)) == cos(x)/sin(x) + assert TR2(tan(tan(x) - sin(x)/cos(x))) == 0 + + +def test_TR2i(): + # just a reminder that ratios of powers only simplify if both + # numerator and denominator satisfy the condition that each + # has a positive base or an integer exponent; e.g. the following, + # at y=-1, x=1/2 gives sqrt(2)*I != -sqrt(2)*I + assert powsimp(2**x/y**x) != (2/y)**x + + assert TR2i(sin(x)/cos(x)) == tan(x) + assert TR2i(sin(x)*sin(y)/cos(x)) == tan(x)*sin(y) + assert TR2i(1/(sin(x)/cos(x))) == 1/tan(x) + assert TR2i(1/(sin(x)*sin(y)/cos(x))) == 1/tan(x)/sin(y) + assert TR2i(sin(x)/2/(cos(x) + 1)) == sin(x)/(cos(x) + 1)/2 + + assert TR2i(sin(x)/2/(cos(x) + 1), half=True) == tan(x/2)/2 + assert TR2i(sin(1)/(cos(1) + 1), half=True) == tan(S.Half) + assert TR2i(sin(2)/(cos(2) + 1), half=True) == tan(1) + assert TR2i(sin(4)/(cos(4) + 1), half=True) == tan(2) + assert TR2i(sin(5)/(cos(5) + 1), half=True) == tan(5*S.Half) + assert TR2i((cos(1) + 1)/sin(1), half=True) == 1/tan(S.Half) + assert TR2i((cos(2) + 1)/sin(2), half=True) == 1/tan(1) + assert TR2i((cos(4) + 1)/sin(4), half=True) == 1/tan(2) + assert TR2i((cos(5) + 1)/sin(5), half=True) == 1/tan(5*S.Half) + assert TR2i((cos(1) + 1)**(-a)*sin(1)**a, half=True) == tan(S.Half)**a + assert TR2i((cos(2) + 1)**(-a)*sin(2)**a, half=True) == tan(1)**a + assert TR2i((cos(4) + 1)**(-a)*sin(4)**a, half=True) == (cos(4) + 1)**(-a)*sin(4)**a + assert TR2i((cos(5) + 1)**(-a)*sin(5)**a, half=True) == (cos(5) + 1)**(-a)*sin(5)**a + assert TR2i((cos(1) + 1)**a*sin(1)**(-a), half=True) == tan(S.Half)**(-a) + assert TR2i((cos(2) + 1)**a*sin(2)**(-a), half=True) == tan(1)**(-a) + assert TR2i((cos(4) + 1)**a*sin(4)**(-a), half=True) == (cos(4) + 1)**a*sin(4)**(-a) + assert TR2i((cos(5) + 1)**a*sin(5)**(-a), half=True) == (cos(5) + 1)**a*sin(5)**(-a) + + i = symbols('i', integer=True) + assert TR2i(((cos(5) + 1)**i*sin(5)**(-i)), half=True) == tan(5*S.Half)**(-i) + assert TR2i(1/((cos(5) + 1)**i*sin(5)**(-i)), half=True) == tan(5*S.Half)**i + + +def test_TR3(): + assert TR3(cos(y - x*(y - x))) == cos(x*(x - y) + y) + assert cos(pi/2 + x) == -sin(x) + assert cos(30*pi/2 + x) == -cos(x) + + for f in (cos, sin, tan, cot, csc, sec): + i = f(pi*Rational(3, 7)) + j = TR3(i) + assert verify_numerically(i, j) and i.func != j.func + + with evaluate(False): + eq = cos(9*pi/22) + assert eq.has(9*pi) and TR3(eq) == sin(pi/11) + + +def test_TR4(): + for i in [0, pi/6, pi/4, pi/3, pi/2]: + with evaluate(False): + eq = cos(i) + assert isinstance(eq, cos) and TR4(eq) == cos(i) + + +def test__TR56(): + h = lambda x: 1 - x + assert T(sin(x)**3, sin, cos, h, 4, False) == sin(x)*(-cos(x)**2 + 1) + assert T(sin(x)**10, sin, cos, h, 4, False) == sin(x)**10 + assert T(sin(x)**6, sin, cos, h, 6, False) == (-cos(x)**2 + 1)**3 + assert T(sin(x)**6, sin, cos, h, 6, True) == sin(x)**6 + assert T(sin(x)**8, sin, cos, h, 10, True) == (-cos(x)**2 + 1)**4 + + # issue 17137 + assert T(sin(x)**I, sin, cos, h, 4, True) == sin(x)**I + assert T(sin(x)**(2*I + 1), sin, cos, h, 4, True) == sin(x)**(2*I + 1) + + +def test_TR5(): + assert TR5(sin(x)**2) == -cos(x)**2 + 1 + assert TR5(sin(x)**-2) == sin(x)**(-2) + assert TR5(sin(x)**4) == (-cos(x)**2 + 1)**2 + + +def test_TR6(): + assert TR6(cos(x)**2) == -sin(x)**2 + 1 + assert TR6(cos(x)**-2) == cos(x)**(-2) + assert TR6(cos(x)**4) == (-sin(x)**2 + 1)**2 + + +def test_TR7(): + assert TR7(cos(x)**2) == cos(2*x)/2 + S.Half + assert TR7(cos(x)**2 + 1) == cos(2*x)/2 + Rational(3, 2) + + +def test_TR8(): + assert TR8(cos(2)*cos(3)) == cos(5)/2 + cos(1)/2 + assert TR8(cos(2)*sin(3)) == sin(5)/2 + sin(1)/2 + assert TR8(sin(2)*sin(3)) == -cos(5)/2 + cos(1)/2 + assert TR8(sin(1)*sin(2)*sin(3)) == sin(4)/4 - sin(6)/4 + sin(2)/4 + assert TR8(cos(2)*cos(3)*cos(4)*cos(5)) == \ + cos(4)/4 + cos(10)/8 + cos(2)/8 + cos(8)/8 + cos(14)/8 + \ + cos(6)/8 + Rational(1, 8) + assert TR8(cos(2)*cos(3)*cos(4)*cos(5)*cos(6)) == \ + cos(10)/8 + cos(4)/8 + 3*cos(2)/16 + cos(16)/16 + cos(8)/8 + \ + cos(14)/16 + cos(20)/16 + cos(12)/16 + Rational(1, 16) + cos(6)/8 + assert TR8(sin(pi*Rational(3, 7))**2*cos(pi*Rational(3, 7))**2/(16*sin(pi/7)**2)) == Rational(1, 64) + +def test_TR9(): + a = S.Half + b = 3*a + assert TR9(a) == a + assert TR9(cos(1) + cos(2)) == 2*cos(a)*cos(b) + assert TR9(cos(1) - cos(2)) == 2*sin(a)*sin(b) + assert TR9(sin(1) - sin(2)) == -2*sin(a)*cos(b) + assert TR9(sin(1) + sin(2)) == 2*sin(b)*cos(a) + assert TR9(cos(1) + 2*sin(1) + 2*sin(2)) == cos(1) + 4*sin(b)*cos(a) + assert TR9(cos(4) + cos(2) + 2*cos(1)*cos(3)) == 4*cos(1)*cos(3) + assert TR9((cos(4) + cos(2))/cos(3)/2 + cos(3)) == 2*cos(1)*cos(2) + assert TR9(cos(3) + cos(4) + cos(5) + cos(6)) == \ + 4*cos(S.Half)*cos(1)*cos(Rational(9, 2)) + assert TR9(cos(3) + cos(3)*cos(2)) == cos(3) + cos(2)*cos(3) + assert TR9(-cos(y) + cos(x*y)) == -2*sin(x*y/2 - y/2)*sin(x*y/2 + y/2) + assert TR9(-sin(y) + sin(x*y)) == 2*sin(x*y/2 - y/2)*cos(x*y/2 + y/2) + c = cos(x) + s = sin(x) + for si in ((1, 1), (1, -1), (-1, 1), (-1, -1)): + for a in ((c, s), (s, c), (cos(x), cos(x*y)), (sin(x), sin(x*y))): + args = zip(si, a) + ex = Add(*[Mul(*ai) for ai in args]) + t = TR9(ex) + assert not (a[0].func == a[1].func and ( + not verify_numerically(ex, t.expand(trig=True)) or t.is_Add) + or a[1].func != a[0].func and ex != t) + + +def test_TR10(): + assert TR10(cos(a + b)) == -sin(a)*sin(b) + cos(a)*cos(b) + assert TR10(sin(a + b)) == sin(a)*cos(b) + sin(b)*cos(a) + assert TR10(sin(a + b + c)) == \ + (-sin(a)*sin(b) + cos(a)*cos(b))*sin(c) + \ + (sin(a)*cos(b) + sin(b)*cos(a))*cos(c) + assert TR10(cos(a + b + c)) == \ + (-sin(a)*sin(b) + cos(a)*cos(b))*cos(c) - \ + (sin(a)*cos(b) + sin(b)*cos(a))*sin(c) + + +def test_TR10i(): + assert TR10i(cos(1)*cos(3) + sin(1)*sin(3)) == cos(2) + assert TR10i(cos(1)*cos(3) - sin(1)*sin(3)) == cos(4) + assert TR10i(cos(1)*sin(3) - sin(1)*cos(3)) == sin(2) + assert TR10i(cos(1)*sin(3) + sin(1)*cos(3)) == sin(4) + assert TR10i(cos(1)*sin(3) + sin(1)*cos(3) + 7) == sin(4) + 7 + assert TR10i(cos(1)*sin(3) + sin(1)*cos(3) + cos(3)) == cos(3) + sin(4) + assert TR10i(2*cos(1)*sin(3) + 2*sin(1)*cos(3) + cos(3)) == \ + 2*sin(4) + cos(3) + assert TR10i(cos(2)*cos(3) + sin(2)*(cos(1)*sin(2) + cos(2)*sin(1))) == \ + cos(1) + eq = (cos(2)*cos(3) + sin(2)*( + cos(1)*sin(2) + cos(2)*sin(1)))*cos(5) + sin(1)*sin(5) + assert TR10i(eq) == TR10i(eq.expand()) == cos(4) + assert TR10i(sqrt(2)*cos(x)*x + sqrt(6)*sin(x)*x) == \ + 2*sqrt(2)*x*sin(x + pi/6) + assert TR10i(cos(x)/sqrt(6) + sin(x)/sqrt(2) + + cos(x)/sqrt(6)/3 + sin(x)/sqrt(2)/3) == 4*sqrt(6)*sin(x + pi/6)/9 + assert TR10i(cos(x)/sqrt(6) + sin(x)/sqrt(2) + + cos(y)/sqrt(6)/3 + sin(y)/sqrt(2)/3) == \ + sqrt(6)*sin(x + pi/6)/3 + sqrt(6)*sin(y + pi/6)/9 + assert TR10i(cos(x) + sqrt(3)*sin(x) + 2*sqrt(3)*cos(x + pi/6)) == 4*cos(x) + assert TR10i(cos(x) + sqrt(3)*sin(x) + + 2*sqrt(3)*cos(x + pi/6) + 4*sin(x)) == 4*sqrt(2)*sin(x + pi/4) + assert TR10i(cos(2)*sin(3) + sin(2)*cos(4)) == \ + sin(2)*cos(4) + sin(3)*cos(2) + + A = Symbol('A', commutative=False) + assert TR10i(sqrt(2)*cos(x)*A + sqrt(6)*sin(x)*A) == \ + 2*sqrt(2)*sin(x + pi/6)*A + + + c = cos(x) + s = sin(x) + h = sin(y) + r = cos(y) + for si in ((1, 1), (1, -1), (-1, 1), (-1, -1)): + for argsi in ((c*r, s*h), (c*h, s*r)): # explicit 2-args + args = zip(si, argsi) + ex = Add(*[Mul(*ai) for ai in args]) + t = TR10i(ex) + assert not (ex - t.expand(trig=True) or t.is_Add) + + c = cos(x) + s = sin(x) + h = sin(pi/6) + r = cos(pi/6) + for si in ((1, 1), (1, -1), (-1, 1), (-1, -1)): + for argsi in ((c*r, s*h), (c*h, s*r)): # induced + args = zip(si, argsi) + ex = Add(*[Mul(*ai) for ai in args]) + t = TR10i(ex) + assert not (ex - t.expand(trig=True) or t.is_Add) + + +def test_TR11(): + + assert TR11(sin(2*x)) == 2*sin(x)*cos(x) + assert TR11(sin(4*x)) == 4*((-sin(x)**2 + cos(x)**2)*sin(x)*cos(x)) + assert TR11(sin(x*Rational(4, 3))) == \ + 4*((-sin(x/3)**2 + cos(x/3)**2)*sin(x/3)*cos(x/3)) + + assert TR11(cos(2*x)) == -sin(x)**2 + cos(x)**2 + assert TR11(cos(4*x)) == \ + (-sin(x)**2 + cos(x)**2)**2 - 4*sin(x)**2*cos(x)**2 + + assert TR11(cos(2)) == cos(2) + + assert TR11(cos(pi*Rational(3, 7)), pi*Rational(2, 7)) == -cos(pi*Rational(2, 7))**2 + sin(pi*Rational(2, 7))**2 + assert TR11(cos(4), 2) == -sin(2)**2 + cos(2)**2 + assert TR11(cos(6), 2) == cos(6) + assert TR11(sin(x)/cos(x/2), x/2) == 2*sin(x/2) + +def test__TR11(): + + assert _TR11(sin(x/3)*sin(2*x)*sin(x/4)/(cos(x/6)*cos(x/8))) == \ + 4*sin(x/8)*sin(x/6)*sin(2*x),_TR11(sin(x/3)*sin(2*x)*sin(x/4)/(cos(x/6)*cos(x/8))) + assert _TR11(sin(x/3)/cos(x/6)) == 2*sin(x/6) + + assert _TR11(cos(x/6)/sin(x/3)) == 1/(2*sin(x/6)) + assert _TR11(sin(2*x)*cos(x/8)/sin(x/4)) == sin(2*x)/(2*sin(x/8)), _TR11(sin(2*x)*cos(x/8)/sin(x/4)) + assert _TR11(sin(x)/sin(x/2)) == 2*cos(x/2) + + +def test_TR12(): + assert TR12(tan(x + y)) == (tan(x) + tan(y))/(-tan(x)*tan(y) + 1) + assert TR12(tan(x + y + z)) ==\ + (tan(z) + (tan(x) + tan(y))/(-tan(x)*tan(y) + 1))/( + 1 - (tan(x) + tan(y))*tan(z)/(-tan(x)*tan(y) + 1)) + assert TR12(tan(x*y)) == tan(x*y) + + +def test_TR13(): + assert TR13(tan(3)*tan(2)) == -tan(2)/tan(5) - tan(3)/tan(5) + 1 + assert TR13(cot(3)*cot(2)) == 1 + cot(3)*cot(5) + cot(2)*cot(5) + assert TR13(tan(1)*tan(2)*tan(3)) == \ + (-tan(2)/tan(5) - tan(3)/tan(5) + 1)*tan(1) + assert TR13(tan(1)*tan(2)*cot(3)) == \ + (-tan(2)/tan(3) + 1 - tan(1)/tan(3))*cot(3) + + +def test_L(): + assert L(cos(x) + sin(x)) == 2 + + +def test_fu(): + + assert fu(sin(50)**2 + cos(50)**2 + sin(pi/6)) == Rational(3, 2) + assert fu(sqrt(6)*cos(x) + sqrt(2)*sin(x)) == 2*sqrt(2)*sin(x + pi/3) + + + eq = sin(x)**4 - cos(y)**2 + sin(y)**2 + 2*cos(x)**2 + assert fu(eq) == cos(x)**4 - 2*cos(y)**2 + 2 + + assert fu(S.Half - cos(2*x)/2) == sin(x)**2 + + assert fu(sin(a)*(cos(b) - sin(b)) + cos(a)*(sin(b) + cos(b))) == \ + sqrt(2)*sin(a + b + pi/4) + + assert fu(sqrt(3)*cos(x)/2 + sin(x)/2) == sin(x + pi/3) + + assert fu(1 - sin(2*x)**2/4 - sin(y)**2 - cos(x)**4) == \ + -cos(x)**2 + cos(y)**2 + + assert fu(cos(pi*Rational(4, 9))) == sin(pi/18) + assert fu(cos(pi/9)*cos(pi*Rational(2, 9))*cos(pi*Rational(3, 9))*cos(pi*Rational(4, 9))) == Rational(1, 16) + + assert fu( + tan(pi*Rational(7, 18)) + tan(pi*Rational(5, 18)) - sqrt(3)*tan(pi*Rational(5, 18))*tan(pi*Rational(7, 18))) == \ + -sqrt(3) + + assert fu(tan(1)*tan(2)) == tan(1)*tan(2) + + expr = Mul(*[cos(2**i) for i in range(10)]) + assert fu(expr) == sin(1024)/(1024*sin(1)) + + # issue #18059: + assert fu(cos(x) + sqrt(sin(x)**2)) == cos(x) + sqrt(sin(x)**2) + + assert fu((-14*sin(x)**3 + 35*sin(x) + 6*sqrt(3)*cos(x)**3 + 9*sqrt(3)*cos(x))/((cos(2*x) + 4))) == \ + 7*sin(x) + 3*sqrt(3)*cos(x) + + +def test_objective(): + assert fu(sin(x)/cos(x), measure=lambda x: x.count_ops()) == \ + tan(x) + assert fu(sin(x)/cos(x), measure=lambda x: -x.count_ops()) == \ + sin(x)/cos(x) + + +def test_process_common_addends(): + # this tests that the args are not evaluated as they are given to do + # and that key2 works when key1 is False + do = lambda x: Add(*[i**(i%2) for i in x.args]) + assert process_common_addends(Add(*[1, 2, 3, 4], evaluate=False), do, + key2=lambda x: x%2, key1=False) == 1**1 + 3**1 + 2**0 + 4**0 + + +def test_trig_split(): + assert trig_split(cos(x), cos(y)) == (1, 1, 1, x, y, True) + assert trig_split(2*cos(x), -2*cos(y)) == (2, 1, -1, x, y, True) + assert trig_split(cos(x)*sin(y), cos(y)*sin(y)) == \ + (sin(y), 1, 1, x, y, True) + + assert trig_split(cos(x), -sqrt(3)*sin(x), two=True) == \ + (2, 1, -1, x, pi/6, False) + assert trig_split(cos(x), sin(x), two=True) == \ + (sqrt(2), 1, 1, x, pi/4, False) + assert trig_split(cos(x), -sin(x), two=True) == \ + (sqrt(2), 1, -1, x, pi/4, False) + assert trig_split(sqrt(2)*cos(x), -sqrt(6)*sin(x), two=True) == \ + (2*sqrt(2), 1, -1, x, pi/6, False) + assert trig_split(-sqrt(6)*cos(x), -sqrt(2)*sin(x), two=True) == \ + (-2*sqrt(2), 1, 1, x, pi/3, False) + assert trig_split(cos(x)/sqrt(6), sin(x)/sqrt(2), two=True) == \ + (sqrt(6)/3, 1, 1, x, pi/6, False) + assert trig_split(-sqrt(6)*cos(x)*sin(y), + -sqrt(2)*sin(x)*sin(y), two=True) == \ + (-2*sqrt(2)*sin(y), 1, 1, x, pi/3, False) + + assert trig_split(cos(x), sin(x)) is None + assert trig_split(cos(x), sin(z)) is None + assert trig_split(2*cos(x), -sin(x)) is None + assert trig_split(cos(x), -sqrt(3)*sin(x)) is None + assert trig_split(cos(x)*cos(y), sin(x)*sin(z)) is None + assert trig_split(cos(x)*cos(y), sin(x)*sin(y)) is None + assert trig_split(-sqrt(6)*cos(x), sqrt(2)*sin(x)*sin(y), two=True) is \ + None + + assert trig_split(sqrt(3)*sqrt(x), cos(3), two=True) is None + assert trig_split(sqrt(3)*root(x, 3), sin(3)*cos(2), two=True) is None + assert trig_split(cos(5)*cos(6), cos(7)*sin(5), two=True) is None + + +def test_TRmorrie(): + assert TRmorrie(7*Mul(*[cos(i) for i in range(10)])) == \ + 7*sin(12)*sin(16)*cos(5)*cos(7)*cos(9)/(64*sin(1)*sin(3)) + assert TRmorrie(x) == x + assert TRmorrie(2*x) == 2*x + e = cos(pi/7)*cos(pi*Rational(2, 7))*cos(pi*Rational(4, 7)) + assert TR8(TRmorrie(e)) == Rational(-1, 8) + e = Mul(*[cos(2**i*pi/17) for i in range(1, 17)]) + assert TR8(TR3(TRmorrie(e))) == Rational(1, 65536) + # issue 17063 + eq = cos(x)/cos(x/2) + assert TRmorrie(eq) == eq + # issue #20430 + eq = cos(x/2)*sin(x/2)*cos(x)**3 + assert TRmorrie(eq) == sin(2*x)*cos(x)**2/4 + + +def test_TRpower(): + assert TRpower(1/sin(x)**2) == 1/sin(x)**2 + assert TRpower(cos(x)**3*sin(x/2)**4) == \ + (3*cos(x)/4 + cos(3*x)/4)*(-cos(x)/2 + cos(2*x)/8 + Rational(3, 8)) + for k in range(2, 8): + assert verify_numerically(sin(x)**k, TRpower(sin(x)**k)) + assert verify_numerically(cos(x)**k, TRpower(cos(x)**k)) + + +def test_hyper_as_trig(): + from sympy.simplify.fu import _osborne, _osbornei + + eq = sinh(x)**2 + cosh(x)**2 + t, f = hyper_as_trig(eq) + assert f(fu(t)) == cosh(2*x) + e, f = hyper_as_trig(tanh(x + y)) + assert f(TR12(e)) == (tanh(x) + tanh(y))/(tanh(x)*tanh(y) + 1) + + d = Dummy() + assert _osborne(sinh(x), d) == I*sin(x*d) + assert _osborne(tanh(x), d) == I*tan(x*d) + assert _osborne(coth(x), d) == cot(x*d)/I + assert _osborne(cosh(x), d) == cos(x*d) + assert _osborne(sech(x), d) == sec(x*d) + assert _osborne(csch(x), d) == csc(x*d)/I + for func in (sinh, cosh, tanh, coth, sech, csch): + h = func(pi) + assert _osbornei(_osborne(h, d), d) == h + # /!\ the _osborne functions are not meant to work + # in the o(i(trig, d), d) direction so we just check + # that they work as they are supposed to work + assert _osbornei(cos(x*y + z), y) == cosh(x + z*I) + assert _osbornei(sin(x*y + z), y) == sinh(x + z*I)/I + assert _osbornei(tan(x*y + z), y) == tanh(x + z*I)/I + assert _osbornei(cot(x*y + z), y) == coth(x + z*I)*I + assert _osbornei(sec(x*y + z), y) == sech(x + z*I) + assert _osbornei(csc(x*y + z), y) == csch(x + z*I)*I + + +def test_TR12i(): + ta, tb, tc = [tan(i) for i in (a, b, c)] + assert TR12i((ta + tb)/(-ta*tb + 1)) == tan(a + b) + assert TR12i((ta + tb)/(ta*tb - 1)) == -tan(a + b) + assert TR12i((-ta - tb)/(ta*tb - 1)) == tan(a + b) + eq = (ta + tb)/(-ta*tb + 1)**2*(-3*ta - 3*tc)/(2*(ta*tc - 1)) + assert TR12i(eq.expand()) == \ + -3*tan(a + b)*tan(a + c)/(tan(a) + tan(b) - 1)/2 + assert TR12i(tan(x)/sin(x)) == tan(x)/sin(x) + eq = (ta + cos(2))/(-ta*tb + 1) + assert TR12i(eq) == eq + eq = (ta + tb + 2)**2/(-ta*tb + 1) + assert TR12i(eq) == eq + eq = ta/(-ta*tb + 1) + assert TR12i(eq) == eq + eq = (((ta + tb)*(a + 1)).expand())**2/(ta*tb - 1) + assert TR12i(eq) == -(a + 1)**2*tan(a + b) + + +def test_TR14(): + eq = (cos(x) - 1)*(cos(x) + 1) + ans = -sin(x)**2 + assert TR14(eq) == ans + assert TR14(1/eq) == 1/ans + assert TR14((cos(x) - 1)**2*(cos(x) + 1)**2) == ans**2 + assert TR14((cos(x) - 1)**2*(cos(x) + 1)**3) == ans**2*(cos(x) + 1) + assert TR14((cos(x) - 1)**3*(cos(x) + 1)**2) == ans**2*(cos(x) - 1) + eq = (cos(x) - 1)**y*(cos(x) + 1)**y + assert TR14(eq) == eq + eq = (cos(x) - 2)**y*(cos(x) + 1) + assert TR14(eq) == eq + eq = (tan(x) - 2)**2*(cos(x) + 1) + assert TR14(eq) == eq + i = symbols('i', integer=True) + assert TR14((cos(x) - 1)**i*(cos(x) + 1)**i) == ans**i + assert TR14((sin(x) - 1)**i*(sin(x) + 1)**i) == (-cos(x)**2)**i + # could use extraction in this case + eq = (cos(x) - 1)**(i + 1)*(cos(x) + 1)**i + assert TR14(eq) in [(cos(x) - 1)*ans**i, eq] + + assert TR14((sin(x) - 1)*(sin(x) + 1)) == -cos(x)**2 + p1 = (cos(x) + 1)*(cos(x) - 1) + p2 = (cos(y) - 1)*2*(cos(y) + 1) + p3 = (3*(cos(y) - 1))*(3*(cos(y) + 1)) + assert TR14(p1*p2*p3*(x - 1)) == -18*((x - 1)*sin(x)**2*sin(y)**4) + + +def test_TR15_16_17(): + assert TR15(1 - 1/sin(x)**2) == -cot(x)**2 + assert TR16(1 - 1/cos(x)**2) == -tan(x)**2 + assert TR111(1 - 1/tan(x)**2) == 1 - cot(x)**2 + + +def test_as_f_sign_1(): + assert as_f_sign_1(x + 1) == (1, x, 1) + assert as_f_sign_1(x - 1) == (1, x, -1) + assert as_f_sign_1(-x + 1) == (-1, x, -1) + assert as_f_sign_1(-x - 1) == (-1, x, 1) + assert as_f_sign_1(2*x + 2) == (2, x, 1) + assert as_f_sign_1(x*y - y) == (y, x, -1) + assert as_f_sign_1(-x*y + y) == (-y, x, -1) + + +def test_issue_25590(): + A = Symbol('A', commutative=False) + B = Symbol('B', commutative=False) + + assert TR8(2*cos(x)*sin(x)*B*A) == sin(2*x)*B*A + assert TR13(tan(2)*tan(3)*B*A) == (-tan(2)/tan(5) - tan(3)/tan(5) + 1)*B*A + + # XXX The result may not be optimal than + # sin(2*x)*B*A + cos(x)**2 and may change in the future + assert (2*cos(x)*sin(x)*B*A + cos(x)**2).simplify() == sin(2*x)*B*A + cos(2*x)/2 + S.One/2 diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_function.py b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_function.py new file mode 100644 index 0000000000000000000000000000000000000000..441b9faf1bb3c5e7f2279b2a61066d050e45f773 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_function.py @@ -0,0 +1,54 @@ +""" Unit tests for Hyper_Function""" +from sympy.core import symbols, Dummy, Tuple, S, Rational +from sympy.functions import hyper + +from sympy.simplify.hyperexpand import Hyper_Function + +def test_attrs(): + a, b = symbols('a, b', cls=Dummy) + f = Hyper_Function([2, a], [b]) + assert f.ap == Tuple(2, a) + assert f.bq == Tuple(b) + assert f.args == (Tuple(2, a), Tuple(b)) + assert f.sizes == (2, 1) + +def test_call(): + a, b, x = symbols('a, b, x', cls=Dummy) + f = Hyper_Function([2, a], [b]) + assert f(x) == hyper([2, a], [b], x) + +def test_has(): + a, b, c = symbols('a, b, c', cls=Dummy) + f = Hyper_Function([2, -a], [b]) + assert f.has(a) + assert f.has(Tuple(b)) + assert not f.has(c) + +def test_eq(): + assert Hyper_Function([1], []) == Hyper_Function([1], []) + assert (Hyper_Function([1], []) != Hyper_Function([1], [])) is False + assert Hyper_Function([1], []) != Hyper_Function([2], []) + assert Hyper_Function([1], []) != Hyper_Function([1, 2], []) + assert Hyper_Function([1], []) != Hyper_Function([1], [2]) + +def test_gamma(): + assert Hyper_Function([2, 3], [-1]).gamma == 0 + assert Hyper_Function([-2, -3], [-1]).gamma == 2 + n = Dummy(integer=True) + assert Hyper_Function([-1, n, 1], []).gamma == 1 + assert Hyper_Function([-1, -n, 1], []).gamma == 1 + p = Dummy(integer=True, positive=True) + assert Hyper_Function([-1, p, 1], []).gamma == 1 + assert Hyper_Function([-1, -p, 1], []).gamma == 2 + +def test_suitable_origin(): + assert Hyper_Function((S.Half,), (Rational(3, 2),))._is_suitable_origin() is True + assert Hyper_Function((S.Half,), (S.Half,))._is_suitable_origin() is False + assert Hyper_Function((S.Half,), (Rational(-1, 2),))._is_suitable_origin() is False + assert Hyper_Function((S.Half,), (0,))._is_suitable_origin() is False + assert Hyper_Function((S.Half,), (-1, 1,))._is_suitable_origin() is False + assert Hyper_Function((S.Half, 0), (1,))._is_suitable_origin() is False + assert Hyper_Function((S.Half, 1), + (2, Rational(-2, 3)))._is_suitable_origin() is True + assert Hyper_Function((S.Half, 1), + (2, Rational(-2, 3), Rational(3, 2)))._is_suitable_origin() is True diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_gammasimp.py b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_gammasimp.py new file mode 100644 index 0000000000000000000000000000000000000000..e4c73093250b279510e3c2274db22818a9adffd8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_gammasimp.py @@ -0,0 +1,127 @@ +from sympy.core.function import Function +from sympy.core.numbers import (Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.combinatorial.factorials import (rf, binomial, factorial) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.gamma_functions import gamma +from sympy.simplify.gammasimp import gammasimp +from sympy.simplify.powsimp import powsimp +from sympy.simplify.simplify import simplify + +from sympy.abc import x, y, n, k + + +def test_gammasimp(): + R = Rational + + # was part of test_combsimp_gamma() in test_combsimp.py + assert gammasimp(gamma(x)) == gamma(x) + assert gammasimp(gamma(x + 1)/x) == gamma(x) + assert gammasimp(gamma(x)/(x - 1)) == gamma(x - 1) + assert gammasimp(x*gamma(x)) == gamma(x + 1) + assert gammasimp((x + 1)*gamma(x + 1)) == gamma(x + 2) + assert gammasimp(gamma(x + y)*(x + y)) == gamma(x + y + 1) + assert gammasimp(x/gamma(x + 1)) == 1/gamma(x) + assert gammasimp((x + 1)**2/gamma(x + 2)) == (x + 1)/gamma(x + 1) + assert gammasimp(x*gamma(x) + gamma(x + 3)/(x + 2)) == \ + (x + 2)*gamma(x + 1) + + assert gammasimp(gamma(2*x)*x) == gamma(2*x + 1)/2 + assert gammasimp(gamma(2*x)/(x - S.Half)) == 2*gamma(2*x - 1) + + assert gammasimp(gamma(x)*gamma(1 - x)) == pi/sin(pi*x) + assert gammasimp(gamma(x)*gamma(-x)) == -pi/(x*sin(pi*x)) + assert gammasimp(1/gamma(x + 3)/gamma(1 - x)) == \ + sin(pi*x)/(pi*x*(x + 1)*(x + 2)) + + assert gammasimp(factorial(n + 2)) == gamma(n + 3) + assert gammasimp(binomial(n, k)) == \ + gamma(n + 1)/(gamma(k + 1)*gamma(-k + n + 1)) + + assert powsimp(gammasimp( + gamma(x)*gamma(x + S.Half)*gamma(y)/gamma(x + y))) == \ + 2**(-2*x + 1)*sqrt(pi)*gamma(2*x)*gamma(y)/gamma(x + y) + assert gammasimp(1/gamma(x)/gamma(x - Rational(1, 3))/gamma(x + Rational(1, 3))) == \ + 3**(3*x - Rational(3, 2))/(2*pi*gamma(3*x - 1)) + assert simplify( + gamma(S.Half + x/2)*gamma(1 + x/2)/gamma(1 + x)/sqrt(pi)*2**x) == 1 + assert gammasimp(gamma(Rational(-1, 4))*gamma(Rational(-3, 4))) == 16*sqrt(2)*pi/3 + + assert powsimp(gammasimp(gamma(2*x)/gamma(x))) == \ + 2**(2*x - 1)*gamma(x + S.Half)/sqrt(pi) + + # issue 6792 + e = (-gamma(k)*gamma(k + 2) + gamma(k + 1)**2)/gamma(k)**2 + assert gammasimp(e) == -k + assert gammasimp(1/e) == -1/k + e = (gamma(x) + gamma(x + 1))/gamma(x) + assert gammasimp(e) == x + 1 + assert gammasimp(1/e) == 1/(x + 1) + e = (gamma(x) + gamma(x + 2))*(gamma(x - 1) + gamma(x))/gamma(x) + assert gammasimp(e) == (x**2 + x + 1)*gamma(x + 1)/(x - 1) + e = (-gamma(k)*gamma(k + 2) + gamma(k + 1)**2)/gamma(k)**2 + assert gammasimp(e**2) == k**2 + assert gammasimp(e**2/gamma(k + 1)) == k/gamma(k) + a = R(1, 2) + R(1, 3) + b = a + R(1, 3) + assert gammasimp(gamma(2*k)/gamma(k)*gamma(k + a)*gamma(k + b) + ) == 3*2**(2*k + 1)*3**(-3*k - 2)*sqrt(pi)*gamma(3*k + R(3, 2))/2 + + # issue 9699 + assert gammasimp((x + 1)*factorial(x)/gamma(y)) == gamma(x + 2)/gamma(y) + assert gammasimp(rf(x + n, k)*binomial(n, k)).simplify() == Piecewise( + (gamma(n + 1)*gamma(k + n + x)/(gamma(k + 1)*gamma(n + x)*gamma(-k + n + 1)), n > -x), + ((-1)**k*gamma(n + 1)*gamma(-n - x + 1)/(gamma(k + 1)*gamma(-k + n + 1)*gamma(-k - n - x + 1)), True)) + + A, B = symbols('A B', commutative=False) + assert gammasimp(e*B*A) == gammasimp(e)*B*A + + # check iteration + assert gammasimp(gamma(2*k)/gamma(k)*gamma(-k - R(1, 2))) == ( + -2**(2*k + 1)*sqrt(pi)/(2*((2*k + 1)*cos(pi*k)))) + assert gammasimp( + gamma(k)*gamma(k + R(1, 3))*gamma(k + R(2, 3))/gamma(k*R(3, 2))) == ( + 3*2**(3*k + 1)*3**(-3*k - S.Half)*sqrt(pi)*gamma(k*R(3, 2) + S.Half)/2) + + # issue 6153 + assert gammasimp(gamma(Rational(1, 4))/gamma(Rational(5, 4))) == 4 + + # was part of test_combsimp() in test_combsimp.py + assert gammasimp(binomial(n + 2, k + S.Half)) == gamma(n + 3)/ \ + (gamma(k + R(3, 2))*gamma(-k + n + R(5, 2))) + assert gammasimp(binomial(n + 2, k + 2.0)) == \ + gamma(n + 3)/(gamma(k + 3.0)*gamma(-k + n + 1)) + + # issue 11548 + assert gammasimp(binomial(0, x)) == sin(pi*x)/(pi*x) + + e = gamma(n + Rational(1, 3))*gamma(n + R(2, 3)) + assert gammasimp(e) == e + assert gammasimp(gamma(4*n + S.Half)/gamma(2*n - R(3, 4))) == \ + 2**(4*n - R(5, 2))*(8*n - 3)*gamma(2*n + R(3, 4))/sqrt(pi) + + i, m = symbols('i m', integer = True) + e = gamma(exp(i)) + assert gammasimp(e) == e + e = gamma(m + 3) + assert gammasimp(e) == e + e = gamma(m + 1)/(gamma(i + 1)*gamma(-i + m + 1)) + assert gammasimp(e) == e + + p = symbols("p", integer=True, positive=True) + assert gammasimp(gamma(-p + 4)) == gamma(-p + 4) + + +def test_issue_22606(): + fx = Function('f')(x) + eq = x + gamma(y) + # seems like ans should be `eq`, not `(x*y + gamma(y + 1))/y` + ans = gammasimp(eq) + assert gammasimp(eq.subs(x, fx)).subs(fx, x) == ans + assert gammasimp(eq.subs(x, cos(x))).subs(cos(x), x) == ans + assert 1/gammasimp(1/eq) == ans + assert gammasimp(fx.subs(x, eq)).args[0] == ans diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_hyperexpand.py b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_hyperexpand.py new file mode 100644 index 0000000000000000000000000000000000000000..c703c228a13201de13cfd4c3413fc75a2cf5bdb6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_hyperexpand.py @@ -0,0 +1,1063 @@ +from sympy.core.random import randrange + +from sympy.simplify.hyperexpand import (ShiftA, ShiftB, UnShiftA, UnShiftB, + MeijerShiftA, MeijerShiftB, MeijerShiftC, MeijerShiftD, + MeijerUnShiftA, MeijerUnShiftB, MeijerUnShiftC, + MeijerUnShiftD, + ReduceOrder, reduce_order, apply_operators, + devise_plan, make_derivative_operator, Formula, + hyperexpand, Hyper_Function, G_Function, + reduce_order_meijer, + build_hypergeometric_formula) +from sympy.concrete.summations import Sum +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.numbers import I +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.combinatorial.factorials import binomial +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.hyper import (hyper, meijerg) +from sympy.abc import z, a, b, c +from sympy.testing.pytest import XFAIL, raises, slow, tooslow +from sympy.core.random import verify_numerically as tn + +from sympy.core.numbers import (Rational, pi) +from sympy.functions.elementary.exponential import (exp, exp_polar, log) +from sympy.functions.elementary.hyperbolic import atanh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (asin, cos, sin) +from sympy.functions.special.bessel import besseli +from sympy.functions.special.error_functions import erf +from sympy.functions.special.gamma_functions import (gamma, lowergamma) + + +def test_branch_bug(): + assert hyperexpand(hyper((Rational(-1, 3), S.Half), (Rational(2, 3), Rational(3, 2)), -z)) == \ + -z**S('1/3')*lowergamma(exp_polar(I*pi)/3, z)/5 \ + + sqrt(pi)*erf(sqrt(z))/(5*sqrt(z)) + assert hyperexpand(meijerg([Rational(7, 6), 1], [], [Rational(2, 3)], [Rational(1, 6), 0], z)) == \ + 2*z**S('2/3')*(2*sqrt(pi)*erf(sqrt(z))/sqrt(z) - 2*lowergamma( + Rational(2, 3), z)/z**S('2/3'))*gamma(Rational(2, 3))/gamma(Rational(5, 3)) + + +def test_hyperexpand(): + # Luke, Y. L. (1969), The Special Functions and Their Approximations, + # Volume 1, section 6.2 + + assert hyperexpand(hyper([], [], z)) == exp(z) + assert hyperexpand(hyper([1, 1], [2], -z)*z) == log(1 + z) + assert hyperexpand(hyper([], [S.Half], -z**2/4)) == cos(z) + assert hyperexpand(z*hyper([], [S('3/2')], -z**2/4)) == sin(z) + assert hyperexpand(hyper([S('1/2'), S('1/2')], [S('3/2')], z**2)*z) \ + == asin(z) + assert isinstance(Sum(binomial(2, z)*z**2, (z, 0, a)).doit(), Expr) + + +def can_do(ap, bq, numerical=True, div=1, lowerplane=False): + r = hyperexpand(hyper(ap, bq, z)) + if r.has(hyper): + return False + if not numerical: + return True + repl = {} + randsyms = r.free_symbols - {z} + while randsyms: + # Only randomly generated parameters are checked. + for n, ai in enumerate(randsyms): + repl[ai] = randcplx(n)/div + if not any(b.is_Integer and b <= 0 for b in Tuple(*bq).subs(repl)): + break + [a, b, c, d] = [2, -1, 3, 1] + if lowerplane: + [a, b, c, d] = [2, -2, 3, -1] + return tn( + hyper(ap, bq, z).subs(repl), + r.replace(exp_polar, exp).subs(repl), + z, a=a, b=b, c=c, d=d) + + +def test_roach(): + # Kelly B. Roach. Meijer G Function Representations. + # Section "Gallery" + assert can_do([S.Half], [Rational(9, 2)]) + assert can_do([], [1, Rational(5, 2), 4]) + assert can_do([Rational(-1, 2), 1, 2], [3, 4]) + assert can_do([Rational(1, 3)], [Rational(-2, 3), Rational(-1, 2), S.Half, 1]) + assert can_do([Rational(-3, 2), Rational(-1, 2)], [Rational(-5, 2), 1]) + assert can_do([Rational(-3, 2), ], [Rational(-1, 2), S.Half]) # shine-integral + assert can_do([Rational(-3, 2), Rational(-1, 2)], [2]) # elliptic integrals + + +@XFAIL +def test_roach_fail(): + assert can_do([Rational(-1, 2), 1], [Rational(1, 4), S.Half, Rational(3, 4)]) # PFDD + assert can_do([Rational(3, 2)], [Rational(5, 2), 5]) # struve function + assert can_do([Rational(-1, 2), S.Half, 1], [Rational(3, 2), Rational(5, 2)]) # polylog, pfdd + assert can_do([1, 2, 3], [S.Half, 4]) # XXX ? + assert can_do([S.Half], [Rational(-1, 3), Rational(-1, 2), Rational(-2, 3)]) # PFDD ? + +# For the long table tests, see end of file + + +def test_polynomial(): + from sympy.core.numbers import oo + assert hyperexpand(hyper([], [-1], z)) is oo + assert hyperexpand(hyper([-2], [-1], z)) is oo + assert hyperexpand(hyper([0, 0], [-1], z)) == 1 + assert can_do([-5, -2, randcplx(), randcplx()], [-10, randcplx()]) + assert hyperexpand(hyper((-1, 1), (-2,), z)) == 1 + z/2 + + +def test_hyperexpand_bases(): + assert hyperexpand(hyper([2], [a], z)) == \ + a + z**(-a + 1)*(-a**2 + 3*a + z*(a - 1) - 2)*exp(z)* \ + lowergamma(a - 1, z) - 1 + # TODO [a+1, aRational(-1, 2)], [2*a] + assert hyperexpand(hyper([1, 2], [3], z)) == -2/z - 2*log(-z + 1)/z**2 + assert hyperexpand(hyper([S.Half, 2], [Rational(3, 2)], z)) == \ + -1/(2*z - 2) + atanh(sqrt(z))/sqrt(z)/2 + assert hyperexpand(hyper([S.Half, S.Half], [Rational(5, 2)], z)) == \ + (-3*z + 3)/4/(z*sqrt(-z + 1)) \ + + (6*z - 3)*asin(sqrt(z))/(4*z**Rational(3, 2)) + assert hyperexpand(hyper([1, 2], [Rational(3, 2)], z)) == -1/(2*z - 2) \ + - asin(sqrt(z))/(sqrt(z)*(2*z - 2)*sqrt(-z + 1)) + assert hyperexpand(hyper([Rational(-1, 2) - 1, 1, 2], [S.Half, 3], z)) == \ + sqrt(z)*(z*Rational(6, 7) - Rational(6, 5))*atanh(sqrt(z)) \ + + (-30*z**2 + 32*z - 6)/35/z - 6*log(-z + 1)/(35*z**2) + assert hyperexpand(hyper([1 + S.Half, 1, 1], [2, 2], z)) == \ + -4*log(sqrt(-z + 1)/2 + S.Half)/z + # TODO hyperexpand(hyper([a], [2*a + 1], z)) + # TODO [S.Half, a], [Rational(3, 2), a+1] + assert hyperexpand(hyper([2], [b, 1], z)) == \ + z**(-b/2 + S.Half)*besseli(b - 1, 2*sqrt(z))*gamma(b) \ + + z**(-b/2 + 1)*besseli(b, 2*sqrt(z))*gamma(b) + # TODO [a], [a - S.Half, 2*a] + + +def test_hyperexpand_parametric(): + assert hyperexpand(hyper([a, S.Half + a], [S.Half], z)) \ + == (1 + sqrt(z))**(-2*a)/2 + (1 - sqrt(z))**(-2*a)/2 + assert hyperexpand(hyper([a, Rational(-1, 2) + a], [2*a], z)) \ + == 2**(2*a - 1)*((-z + 1)**S.Half + 1)**(-2*a + 1) + + +def test_shifted_sum(): + from sympy.simplify.simplify import simplify + assert simplify(hyperexpand(z**4*hyper([2], [3, S('3/2')], -z**2))) \ + == z*sin(2*z) + (-z**2 + S.Half)*cos(2*z) - S.Half + + +def _randrat(): + """ Steer clear of integers. """ + return S(randrange(25) + 10)/50 + + +def randcplx(offset=-1): + """ Polys is not good with real coefficients. """ + return _randrat() + I*_randrat() + I*(1 + offset) + + +@slow +def test_formulae(): + from sympy.simplify.hyperexpand import FormulaCollection + formulae = FormulaCollection().formulae + for formula in formulae: + h = formula.func(formula.z) + rep = {} + for n, sym in enumerate(formula.symbols): + rep[sym] = randcplx(n) + + # NOTE hyperexpand returns truly branched functions. We know we are + # on the main sheet, but numerical evaluation can still go wrong + # (e.g. if exp_polar cannot be evalf'd). + # Just replace all exp_polar by exp, this usually works. + + # first test if the closed-form is actually correct + h = h.subs(rep) + closed_form = formula.closed_form.subs(rep).rewrite('nonrepsmall') + z = formula.z + assert tn(h, closed_form.replace(exp_polar, exp), z) + + # now test the computed matrix + cl = (formula.C * formula.B)[0].subs(rep).rewrite('nonrepsmall') + assert tn(closed_form.replace( + exp_polar, exp), cl.replace(exp_polar, exp), z) + deriv1 = z*formula.B.applyfunc(lambda t: t.rewrite( + 'nonrepsmall')).diff(z) + deriv2 = formula.M * formula.B + for d1, d2 in zip(deriv1, deriv2): + assert tn(d1.subs(rep).replace(exp_polar, exp), + d2.subs(rep).rewrite('nonrepsmall').replace(exp_polar, exp), z) + + +def test_meijerg_formulae(): + from sympy.simplify.hyperexpand import MeijerFormulaCollection + formulae = MeijerFormulaCollection().formulae + for sig in formulae: + for formula in formulae[sig]: + g = meijerg(formula.func.an, formula.func.ap, + formula.func.bm, formula.func.bq, + formula.z) + rep = {} + for sym in formula.symbols: + rep[sym] = randcplx() + + # first test if the closed-form is actually correct + g = g.subs(rep) + closed_form = formula.closed_form.subs(rep) + z = formula.z + assert tn(g, closed_form, z) + + # now test the computed matrix + cl = (formula.C * formula.B)[0].subs(rep) + assert tn(closed_form, cl, z) + deriv1 = z*formula.B.diff(z) + deriv2 = formula.M * formula.B + for d1, d2 in zip(deriv1, deriv2): + assert tn(d1.subs(rep), d2.subs(rep), z) + + +def op(f): + return z*f.diff(z) + + +def test_plan(): + assert devise_plan(Hyper_Function([0], ()), + Hyper_Function([0], ()), z) == [] + with raises(ValueError): + devise_plan(Hyper_Function([1], ()), Hyper_Function((), ()), z) + with raises(ValueError): + devise_plan(Hyper_Function([2], [1]), Hyper_Function([2], [2]), z) + with raises(ValueError): + devise_plan(Hyper_Function([2], []), Hyper_Function([S("1/2")], []), z) + + # We cannot use pi/(10000 + n) because polys is insanely slow. + a1, a2, b1 = (randcplx(n) for n in range(3)) + b1 += 2*I + h = hyper([a1, a2], [b1], z) + + h2 = hyper((a1 + 1, a2), [b1], z) + assert tn(apply_operators(h, + devise_plan(Hyper_Function((a1 + 1, a2), [b1]), + Hyper_Function((a1, a2), [b1]), z), op), + h2, z) + + h2 = hyper((a1 + 1, a2 - 1), [b1], z) + assert tn(apply_operators(h, + devise_plan(Hyper_Function((a1 + 1, a2 - 1), [b1]), + Hyper_Function((a1, a2), [b1]), z), op), + h2, z) + + +def test_plan_derivatives(): + a1, a2, a3 = 1, 2, S('1/2') + b1, b2 = 3, S('5/2') + h = Hyper_Function((a1, a2, a3), (b1, b2)) + h2 = Hyper_Function((a1 + 1, a2 + 1, a3 + 2), (b1 + 1, b2 + 1)) + ops = devise_plan(h2, h, z) + f = Formula(h, z, h(z), []) + deriv = make_derivative_operator(f.M, z) + assert tn((apply_operators(f.C, ops, deriv)*f.B)[0], h2(z), z) + + h2 = Hyper_Function((a1, a2 - 1, a3 - 2), (b1 - 1, b2 - 1)) + ops = devise_plan(h2, h, z) + assert tn((apply_operators(f.C, ops, deriv)*f.B)[0], h2(z), z) + + +def test_reduction_operators(): + a1, a2, b1 = (randcplx(n) for n in range(3)) + h = hyper([a1], [b1], z) + + assert ReduceOrder(2, 0) is None + assert ReduceOrder(2, -1) is None + assert ReduceOrder(1, S('1/2')) is None + + h2 = hyper((a1, a2), (b1, a2), z) + assert tn(ReduceOrder(a2, a2).apply(h, op), h2, z) + + h2 = hyper((a1, a2 + 1), (b1, a2), z) + assert tn(ReduceOrder(a2 + 1, a2).apply(h, op), h2, z) + + h2 = hyper((a2 + 4, a1), (b1, a2), z) + assert tn(ReduceOrder(a2 + 4, a2).apply(h, op), h2, z) + + # test several step order reduction + ap = (a2 + 4, a1, b1 + 1) + bq = (a2, b1, b1) + func, ops = reduce_order(Hyper_Function(ap, bq)) + assert func.ap == (a1,) + assert func.bq == (b1,) + assert tn(apply_operators(h, ops, op), hyper(ap, bq, z), z) + + +def test_shift_operators(): + a1, a2, b1, b2, b3 = (randcplx(n) for n in range(5)) + h = hyper((a1, a2), (b1, b2, b3), z) + + raises(ValueError, lambda: ShiftA(0)) + raises(ValueError, lambda: ShiftB(1)) + + assert tn(ShiftA(a1).apply(h, op), hyper((a1 + 1, a2), (b1, b2, b3), z), z) + assert tn(ShiftA(a2).apply(h, op), hyper((a1, a2 + 1), (b1, b2, b3), z), z) + assert tn(ShiftB(b1).apply(h, op), hyper((a1, a2), (b1 - 1, b2, b3), z), z) + assert tn(ShiftB(b2).apply(h, op), hyper((a1, a2), (b1, b2 - 1, b3), z), z) + assert tn(ShiftB(b3).apply(h, op), hyper((a1, a2), (b1, b2, b3 - 1), z), z) + + +def test_ushift_operators(): + a1, a2, b1, b2, b3 = (randcplx(n) for n in range(5)) + h = hyper((a1, a2), (b1, b2, b3), z) + + raises(ValueError, lambda: UnShiftA((1,), (), 0, z)) + raises(ValueError, lambda: UnShiftB((), (-1,), 0, z)) + raises(ValueError, lambda: UnShiftA((1,), (0, -1, 1), 0, z)) + raises(ValueError, lambda: UnShiftB((0, 1), (1,), 0, z)) + + s = UnShiftA((a1, a2), (b1, b2, b3), 0, z) + assert tn(s.apply(h, op), hyper((a1 - 1, a2), (b1, b2, b3), z), z) + s = UnShiftA((a1, a2), (b1, b2, b3), 1, z) + assert tn(s.apply(h, op), hyper((a1, a2 - 1), (b1, b2, b3), z), z) + + s = UnShiftB((a1, a2), (b1, b2, b3), 0, z) + assert tn(s.apply(h, op), hyper((a1, a2), (b1 + 1, b2, b3), z), z) + s = UnShiftB((a1, a2), (b1, b2, b3), 1, z) + assert tn(s.apply(h, op), hyper((a1, a2), (b1, b2 + 1, b3), z), z) + s = UnShiftB((a1, a2), (b1, b2, b3), 2, z) + assert tn(s.apply(h, op), hyper((a1, a2), (b1, b2, b3 + 1), z), z) + + +def can_do_meijer(a1, a2, b1, b2, numeric=True): + """ + This helper function tries to hyperexpand() the meijer g-function + corresponding to the parameters a1, a2, b1, b2. + It returns False if this expansion still contains g-functions. + If numeric is True, it also tests the so-obtained formula numerically + (at random values) and returns False if the test fails. + Else it returns True. + """ + from sympy.core.function import expand + from sympy.functions.elementary.complexes import unpolarify + r = hyperexpand(meijerg(a1, a2, b1, b2, z)) + if r.has(meijerg): + return False + # NOTE hyperexpand() returns a truly branched function, whereas numerical + # evaluation only works on the main branch. Since we are evaluating on + # the main branch, this should not be a problem, but expressions like + # exp_polar(I*pi/2*x)**a are evaluated incorrectly. We thus have to get + # rid of them. The expand heuristically does this... + r = unpolarify(expand(r, force=True, power_base=True, power_exp=False, + mul=False, log=False, multinomial=False, basic=False)) + + if not numeric: + return True + + repl = {} + for n, ai in enumerate(meijerg(a1, a2, b1, b2, z).free_symbols - {z}): + repl[ai] = randcplx(n) + return tn(meijerg(a1, a2, b1, b2, z).subs(repl), r.subs(repl), z) + + +@slow +def test_meijerg_expand(): + from sympy.simplify.gammasimp import gammasimp + from sympy.simplify.simplify import simplify + # from mpmath docs + assert hyperexpand(meijerg([[], []], [[0], []], -z)) == exp(z) + + assert hyperexpand(meijerg([[1, 1], []], [[1], [0]], z)) == \ + log(z + 1) + assert hyperexpand(meijerg([[1, 1], []], [[1], [1]], z)) == \ + z/(z + 1) + assert hyperexpand(meijerg([[], []], [[S.Half], [0]], (z/2)**2)) \ + == sin(z)/sqrt(pi) + assert hyperexpand(meijerg([[], []], [[0], [S.Half]], (z/2)**2)) \ + == cos(z)/sqrt(pi) + assert can_do_meijer([], [a], [a - 1, a - S.Half], []) + assert can_do_meijer([], [], [a/2], [-a/2], False) # branches... + assert can_do_meijer([a], [b], [a], [b, a - 1]) + + # wikipedia + assert hyperexpand(meijerg([1], [], [], [0], z)) == \ + Piecewise((0, abs(z) < 1), (1, abs(1/z) < 1), + (meijerg([1], [], [], [0], z), True)) + assert hyperexpand(meijerg([], [1], [0], [], z)) == \ + Piecewise((1, abs(z) < 1), (0, abs(1/z) < 1), + (meijerg([], [1], [0], [], z), True)) + + # The Special Functions and their Approximations + assert can_do_meijer([], [], [a + b/2], [a, a - b/2, a + S.Half]) + assert can_do_meijer( + [], [], [a], [b], False) # branches only agree for small z + assert can_do_meijer([], [S.Half], [a], [-a]) + assert can_do_meijer([], [], [a, b], []) + assert can_do_meijer([], [], [a, b], []) + assert can_do_meijer([], [], [a, a + S.Half], [b, b + S.Half]) + assert can_do_meijer([], [], [a, -a], [0, S.Half], False) # dito + assert can_do_meijer([], [], [a, a + S.Half, b, b + S.Half], []) + assert can_do_meijer([S.Half], [], [0], [a, -a]) + assert can_do_meijer([S.Half], [], [a], [0, -a], False) # dito + assert can_do_meijer([], [a - S.Half], [a, b], [a - S.Half], False) + assert can_do_meijer([], [a + S.Half], [a + b, a - b, a], [], False) + assert can_do_meijer([a + S.Half], [], [b, 2*a - b, a], [], False) + + # This for example is actually zero. + assert can_do_meijer([], [], [], [a, b]) + + # Testing a bug: + assert hyperexpand(meijerg([0, 2], [], [], [-1, 1], z)) == \ + Piecewise((0, abs(z) < 1), + (z*(1 - 1/z**2)/2, abs(1/z) < 1), + (meijerg([0, 2], [], [], [-1, 1], z), True)) + + # Test that the simplest possible answer is returned: + assert gammasimp(simplify(hyperexpand( + meijerg([1], [1 - a], [-a/2, -a/2 + S.Half], [], 1/z)))) == \ + -2*sqrt(pi)*(sqrt(z + 1) + 1)**a/a + + # Test that hyper is returned + assert hyperexpand(meijerg([1], [], [a], [0, 0], z)) == hyper( + (a,), (a + 1, a + 1), z*exp_polar(I*pi))*z**a*gamma(a)/gamma(a + 1)**2 + + # Test place option + f = meijerg(((0, 1), ()), ((S.Half,), (0,)), z**2) + assert hyperexpand(f) == sqrt(pi)/sqrt(1 + z**(-2)) + assert hyperexpand(f, place=0) == sqrt(pi)*z/sqrt(z**2 + 1) + + +def test_meijerg_lookup(): + from sympy.functions.special.error_functions import (Ci, Si) + from sympy.functions.special.gamma_functions import uppergamma + assert hyperexpand(meijerg([a], [], [b, a], [], z)) == \ + z**b*exp(z)*gamma(-a + b + 1)*uppergamma(a - b, z) + assert hyperexpand(meijerg([0], [], [0, 0], [], z)) == \ + exp(z)*uppergamma(0, z) + assert can_do_meijer([a], [], [b, a + 1], []) + assert can_do_meijer([a], [], [b + 2, a], []) + assert can_do_meijer([a], [], [b - 2, a], []) + + assert hyperexpand(meijerg([a], [], [a, a, a - S.Half], [], z)) == \ + -sqrt(pi)*z**(a - S.Half)*(2*cos(2*sqrt(z))*(Si(2*sqrt(z)) - pi/2) + - 2*sin(2*sqrt(z))*Ci(2*sqrt(z))) == \ + hyperexpand(meijerg([a], [], [a, a - S.Half, a], [], z)) == \ + hyperexpand(meijerg([a], [], [a - S.Half, a, a], [], z)) + assert can_do_meijer([a - 1], [], [a + 2, a - Rational(3, 2), a + 1], []) + + +@XFAIL +def test_meijerg_expand_fail(): + # These basically test hyper([], [1/2 - a, 1/2 + 1, 1/2], z), + # which is *very* messy. But since the meijer g actually yields a + # sum of bessel functions, things can sometimes be simplified a lot and + # are then put into tables... + assert can_do_meijer([], [], [a + S.Half], [a, a - b/2, a + b/2]) + assert can_do_meijer([], [], [0, S.Half], [a, -a]) + assert can_do_meijer([], [], [3*a - S.Half, a, -a - S.Half], [a - S.Half]) + assert can_do_meijer([], [], [0, a - S.Half, -a - S.Half], [S.Half]) + assert can_do_meijer([], [], [a, b + S.Half, b], [2*b - a]) + assert can_do_meijer([], [], [a, b + S.Half, b, 2*b - a]) + assert can_do_meijer([S.Half], [], [-a, a], [0]) + + +@slow +def test_meijerg(): + # carefully set up the parameters. + # NOTE: this used to fail sometimes. I believe it is fixed, but if you + # hit an inexplicable test failure here, please let me know the seed. + a1, a2 = (randcplx(n) - 5*I - n*I for n in range(2)) + b1, b2 = (randcplx(n) + 5*I + n*I for n in range(2)) + b3, b4, b5, a3, a4, a5 = (randcplx() for n in range(6)) + g = meijerg([a1], [a3, a4], [b1], [b3, b4], z) + + assert ReduceOrder.meijer_minus(3, 4) is None + assert ReduceOrder.meijer_plus(4, 3) is None + + g2 = meijerg([a1, a2], [a3, a4], [b1], [b3, b4, a2], z) + assert tn(ReduceOrder.meijer_plus(a2, a2).apply(g, op), g2, z) + + g2 = meijerg([a1, a2], [a3, a4], [b1], [b3, b4, a2 + 1], z) + assert tn(ReduceOrder.meijer_plus(a2, a2 + 1).apply(g, op), g2, z) + + g2 = meijerg([a1, a2 - 1], [a3, a4], [b1], [b3, b4, a2 + 2], z) + assert tn(ReduceOrder.meijer_plus(a2 - 1, a2 + 2).apply(g, op), g2, z) + + g2 = meijerg([a1], [a3, a4, b2 - 1], [b1, b2 + 2], [b3, b4], z) + assert tn(ReduceOrder.meijer_minus( + b2 + 2, b2 - 1).apply(g, op), g2, z, tol=1e-6) + + # test several-step reduction + an = [a1, a2] + bq = [b3, b4, a2 + 1] + ap = [a3, a4, b2 - 1] + bm = [b1, b2 + 1] + niq, ops = reduce_order_meijer(G_Function(an, ap, bm, bq)) + assert niq.an == (a1,) + assert set(niq.ap) == {a3, a4} + assert niq.bm == (b1,) + assert set(niq.bq) == {b3, b4} + assert tn(apply_operators(g, ops, op), meijerg(an, ap, bm, bq, z), z) + + +def test_meijerg_shift_operators(): + # carefully set up the parameters. XXX this still fails sometimes + a1, a2, a3, a4, a5, b1, b2, b3, b4, b5 = (randcplx(n) for n in range(10)) + g = meijerg([a1], [a3, a4], [b1], [b3, b4], z) + + assert tn(MeijerShiftA(b1).apply(g, op), + meijerg([a1], [a3, a4], [b1 + 1], [b3, b4], z), z) + assert tn(MeijerShiftB(a1).apply(g, op), + meijerg([a1 - 1], [a3, a4], [b1], [b3, b4], z), z) + assert tn(MeijerShiftC(b3).apply(g, op), + meijerg([a1], [a3, a4], [b1], [b3 + 1, b4], z), z) + assert tn(MeijerShiftD(a3).apply(g, op), + meijerg([a1], [a3 - 1, a4], [b1], [b3, b4], z), z) + + s = MeijerUnShiftA([a1], [a3, a4], [b1], [b3, b4], 0, z) + assert tn( + s.apply(g, op), meijerg([a1], [a3, a4], [b1 - 1], [b3, b4], z), z) + + s = MeijerUnShiftC([a1], [a3, a4], [b1], [b3, b4], 0, z) + assert tn( + s.apply(g, op), meijerg([a1], [a3, a4], [b1], [b3 - 1, b4], z), z) + + s = MeijerUnShiftB([a1], [a3, a4], [b1], [b3, b4], 0, z) + assert tn( + s.apply(g, op), meijerg([a1 + 1], [a3, a4], [b1], [b3, b4], z), z) + + s = MeijerUnShiftD([a1], [a3, a4], [b1], [b3, b4], 0, z) + assert tn( + s.apply(g, op), meijerg([a1], [a3 + 1, a4], [b1], [b3, b4], z), z) + + +@slow +def test_meijerg_confluence(): + def t(m, a, b): + from sympy.core.sympify import sympify + a, b = sympify([a, b]) + m_ = m + m = hyperexpand(m) + if not m == Piecewise((a, abs(z) < 1), (b, abs(1/z) < 1), (m_, True)): + return False + if not (m.args[0].args[0] == a and m.args[1].args[0] == b): + return False + z0 = randcplx()/10 + if abs(m.subs(z, z0).n() - a.subs(z, z0).n()).n() > 1e-10: + return False + if abs(m.subs(z, 1/z0).n() - b.subs(z, 1/z0).n()).n() > 1e-10: + return False + return True + + assert t(meijerg([], [1, 1], [0, 0], [], z), -log(z), 0) + assert t(meijerg( + [], [3, 1], [0, 0], [], z), -z**2/4 + z - log(z)/2 - Rational(3, 4), 0) + assert t(meijerg([], [3, 1], [-1, 0], [], z), + z**2/12 - z/2 + log(z)/2 + Rational(1, 4) + 1/(6*z), 0) + assert t(meijerg([], [1, 1, 1, 1], [0, 0, 0, 0], [], z), -log(z)**3/6, 0) + assert t(meijerg([1, 1], [], [], [0, 0], z), 0, -log(1/z)) + assert t(meijerg([1, 1], [2, 2], [1, 1], [0, 0], z), + -z*log(z) + 2*z, -log(1/z) + 2) + assert t(meijerg([S.Half], [1, 1], [0, 0], [Rational(3, 2)], z), log(z)/2 - 1, 0) + + def u(an, ap, bm, bq): + m = meijerg(an, ap, bm, bq, z) + m2 = hyperexpand(m, allow_hyper=True) + if m2.has(meijerg) and not (m2.is_Piecewise and len(m2.args) == 3): + return False + return tn(m, m2, z) + assert u([], [1], [0, 0], []) + assert u([1, 1], [], [], [0]) + assert u([1, 1], [2, 2, 5], [1, 1, 6], [0, 0]) + assert u([1, 1], [2, 2, 5], [1, 1, 6], [0]) + + +def test_meijerg_with_Floats(): + # see issue #10681 + from sympy.polys.domains.realfield import RR + f = meijerg(((3.0, 1), ()), ((Rational(3, 2),), (0,)), z) + a = -2.3632718012073 + g = a*z**Rational(3, 2)*hyper((-0.5, Rational(3, 2)), (Rational(5, 2),), z*exp_polar(I*pi)) + assert RR.almosteq((hyperexpand(f)/g).n(), 1.0, 1e-12) + + +def test_lerchphi(): + from sympy.functions.special.zeta_functions import (lerchphi, polylog) + from sympy.simplify.gammasimp import gammasimp + assert hyperexpand(hyper([1, a], [a + 1], z)/a) == lerchphi(z, 1, a) + assert hyperexpand( + hyper([1, a, a], [a + 1, a + 1], z)/a**2) == lerchphi(z, 2, a) + assert hyperexpand(hyper([1, a, a, a], [a + 1, a + 1, a + 1], z)/a**3) == \ + lerchphi(z, 3, a) + assert hyperexpand(hyper([1] + [a]*10, [a + 1]*10, z)/a**10) == \ + lerchphi(z, 10, a) + assert gammasimp(hyperexpand(meijerg([0, 1 - a], [], [0], + [-a], exp_polar(-I*pi)*z))) == lerchphi(z, 1, a) + assert gammasimp(hyperexpand(meijerg([0, 1 - a, 1 - a], [], [0], + [-a, -a], exp_polar(-I*pi)*z))) == lerchphi(z, 2, a) + assert gammasimp(hyperexpand(meijerg([0, 1 - a, 1 - a, 1 - a], [], [0], + [-a, -a, -a], exp_polar(-I*pi)*z))) == lerchphi(z, 3, a) + + assert hyperexpand(z*hyper([1, 1], [2], z)) == -log(1 + -z) + assert hyperexpand(z*hyper([1, 1, 1], [2, 2], z)) == polylog(2, z) + assert hyperexpand(z*hyper([1, 1, 1, 1], [2, 2, 2], z)) == polylog(3, z) + + assert hyperexpand(hyper([1, a, 1 + S.Half], [a + 1, S.Half], z)) == \ + -2*a/(z - 1) + (-2*a**2 + a)*lerchphi(z, 1, a) + + # Now numerical tests. These make sure reductions etc are carried out + # correctly + + # a rational function (polylog at negative integer order) + assert can_do([2, 2, 2], [1, 1]) + + # NOTE these contain log(1-x) etc ... better make sure we have |z| < 1 + # reduction of order for polylog + assert can_do([1, 1, 1, b + 5], [2, 2, b], div=10) + + # reduction of order for lerchphi + # XXX lerchphi in mpmath is flaky + assert can_do( + [1, a, a, a, b + 5], [a + 1, a + 1, a + 1, b], numerical=False) + + # test a bug + from sympy.functions.elementary.complexes import Abs + assert hyperexpand(hyper([S.Half, S.Half, S.Half, 1], + [Rational(3, 2), Rational(3, 2), Rational(3, 2)], Rational(1, 4))) == \ + Abs(-polylog(3, exp_polar(I*pi)/2) + polylog(3, S.Half)) + + +def test_partial_simp(): + # First test that hypergeometric function formulae work. + a, b, c, d, e = (randcplx() for _ in range(5)) + for func in [Hyper_Function([a, b, c], [d, e]), + Hyper_Function([], [a, b, c, d, e])]: + f = build_hypergeometric_formula(func) + z = f.z + assert f.closed_form == func(z) + deriv1 = f.B.diff(z)*z + deriv2 = f.M*f.B + for func1, func2 in zip(deriv1, deriv2): + assert tn(func1, func2, z) + + # Now test that formulae are partially simplified. + a, b, z = symbols('a b z') + assert hyperexpand(hyper([3, a], [1, b], z)) == \ + (-a*b/2 + a*z/2 + 2*a)*hyper([a + 1], [b], z) \ + + (a*b/2 - 2*a + 1)*hyper([a], [b], z) + assert tn( + hyperexpand(hyper([3, d], [1, e], z)), hyper([3, d], [1, e], z), z) + assert hyperexpand(hyper([3], [1, a, b], z)) == \ + hyper((), (a, b), z) \ + + z*hyper((), (a + 1, b), z)/(2*a) \ + - z*(b - 4)*hyper((), (a + 1, b + 1), z)/(2*a*b) + assert tn( + hyperexpand(hyper([3], [1, d, e], z)), hyper([3], [1, d, e], z), z) + + +def test_hyperexpand_special(): + assert hyperexpand(hyper([a, b], [c], 1)) == \ + gamma(c)*gamma(c - a - b)/gamma(c - a)/gamma(c - b) + assert hyperexpand(hyper([a, b], [1 + a - b], -1)) == \ + gamma(1 + a/2)*gamma(1 + a - b)/gamma(1 + a)/gamma(1 + a/2 - b) + assert hyperexpand(hyper([a, b], [1 + b - a], -1)) == \ + gamma(1 + b/2)*gamma(1 + b - a)/gamma(1 + b)/gamma(1 + b/2 - a) + assert hyperexpand(meijerg([1 - z - a/2], [1 - z + a/2], [b/2], [-b/2], 1)) == \ + gamma(1 - 2*z)*gamma(z + a/2 + b/2)/gamma(1 - z + a/2 - b/2) \ + /gamma(1 - z - a/2 + b/2)/gamma(1 - z + a/2 + b/2) + assert hyperexpand(hyper([a], [b], 0)) == 1 + assert hyper([a], [b], 0) != 0 + + +def test_Mod1_behavior(): + from sympy.core.symbol import Symbol + from sympy.simplify.simplify import simplify + n = Symbol('n', integer=True) + # Note: this should not hang. + assert simplify(hyperexpand(meijerg([1], [], [n + 1], [0], z))) == \ + lowergamma(n + 1, z) + + +@slow +def test_prudnikov_misc(): + assert can_do([1, (3 + I)/2, (3 - I)/2], [Rational(3, 2), 2]) + assert can_do([S.Half, a - 1], [Rational(3, 2), a + 1], lowerplane=True) + assert can_do([], [b + 1]) + assert can_do([a], [a - 1, b + 1]) + + assert can_do([a], [a - S.Half, 2*a]) + assert can_do([a], [a - S.Half, 2*a + 1]) + assert can_do([a], [a - S.Half, 2*a - 1]) + assert can_do([a], [a + S.Half, 2*a]) + assert can_do([a], [a + S.Half, 2*a + 1]) + assert can_do([a], [a + S.Half, 2*a - 1]) + assert can_do([S.Half], [b, 2 - b]) + assert can_do([S.Half], [b, 3 - b]) + assert can_do([1], [2, b]) + + assert can_do([a, a + S.Half], [2*a, b, 2*a - b + 1]) + assert can_do([a, a + S.Half], [S.Half, 2*a, 2*a + S.Half]) + assert can_do([a], [a + 1], lowerplane=True) # lowergamma + + +def test_prudnikov_1(): + # A. P. Prudnikov, Yu. A. Brychkov and O. I. Marichev (1990). + # Integrals and Series: More Special Functions, Vol. 3,. + # Gordon and Breach Science Publisher + + # 7.3.1 + assert can_do([a, -a], [S.Half]) + assert can_do([a, 1 - a], [S.Half]) + assert can_do([a, 1 - a], [Rational(3, 2)]) + assert can_do([a, 2 - a], [S.Half]) + assert can_do([a, 2 - a], [Rational(3, 2)]) + assert can_do([a, 2 - a], [Rational(3, 2)]) + assert can_do([a, a + S.Half], [2*a - 1]) + assert can_do([a, a + S.Half], [2*a]) + assert can_do([a, a + S.Half], [2*a + 1]) + assert can_do([a, a + S.Half], [S.Half]) + assert can_do([a, a + S.Half], [Rational(3, 2)]) + assert can_do([a, a/2 + 1], [a/2]) + assert can_do([1, b], [2]) + assert can_do([1, b], [b + 1], numerical=False) # Lerch Phi + # NOTE: branches are complicated for |z| > 1 + + assert can_do([a], [2*a]) + assert can_do([a], [2*a + 1]) + assert can_do([a], [2*a - 1]) + + +@slow +def test_prudnikov_2(): + h = S.Half + assert can_do([-h, -h], [h]) + assert can_do([-h, h], [3*h]) + assert can_do([-h, h], [5*h]) + assert can_do([-h, h], [7*h]) + assert can_do([-h, 1], [h]) + + for p in [-h, h]: + for n in [-h, h, 1, 3*h, 2, 5*h, 3, 7*h, 4]: + for m in [-h, h, 3*h, 5*h, 7*h]: + assert can_do([p, n], [m]) + for n in [1, 2, 3, 4]: + for m in [1, 2, 3, 4]: + assert can_do([p, n], [m]) + + +def test_prudnikov_3(): + h = S.Half + assert can_do([Rational(1, 4), Rational(3, 4)], [h]) + assert can_do([Rational(1, 4), Rational(3, 4)], [3*h]) + assert can_do([Rational(1, 3), Rational(2, 3)], [3*h]) + assert can_do([Rational(3, 4), Rational(5, 4)], [h]) + assert can_do([Rational(3, 4), Rational(5, 4)], [3*h]) + + +@tooslow +def test_prudnikov_3_slow(): + # XXX: This is marked as tooslow and hence skipped in CI. None of the + # individual cases below fails or hangs. Some cases are slow and the loops + # below generate 280 different cases. Is it really necessary to test all + # 280 cases here? + h = S.Half + for p in [1, 2, 3, 4]: + for n in [-h, h, 1, 3*h, 2, 5*h, 3, 7*h, 4, 9*h]: + for m in [1, 3*h, 2, 5*h, 3, 7*h, 4]: + assert can_do([p, m], [n]) + + +@slow +def test_prudnikov_4(): + h = S.Half + for p in [3*h, 5*h, 7*h]: + for n in [-h, h, 3*h, 5*h, 7*h]: + for m in [3*h, 2, 5*h, 3, 7*h, 4]: + assert can_do([p, m], [n]) + for n in [1, 2, 3, 4]: + for m in [2, 3, 4]: + assert can_do([p, m], [n]) + + +@slow +def test_prudnikov_5(): + h = S.Half + + for p in [1, 2, 3]: + for q in range(p, 4): + for r in [1, 2, 3]: + for s in range(r, 4): + assert can_do([-h, p, q], [r, s]) + + for p in [h, 1, 3*h, 2, 5*h, 3]: + for q in [h, 3*h, 5*h]: + for r in [h, 3*h, 5*h]: + for s in [h, 3*h, 5*h]: + if s <= q and s <= r: + assert can_do([-h, p, q], [r, s]) + + for p in [h, 1, 3*h, 2, 5*h, 3]: + for q in [1, 2, 3]: + for r in [h, 3*h, 5*h]: + for s in [1, 2, 3]: + assert can_do([-h, p, q], [r, s]) + + +@slow +def test_prudnikov_6(): + h = S.Half + + for m in [3*h, 5*h]: + for n in [1, 2, 3]: + for q in [h, 1, 2]: + for p in [1, 2, 3]: + assert can_do([h, q, p], [m, n]) + for q in [1, 2, 3]: + for p in [3*h, 5*h]: + assert can_do([h, q, p], [m, n]) + + for q in [1, 2]: + for p in [1, 2, 3]: + for m in [1, 2, 3]: + for n in [1, 2, 3]: + assert can_do([h, q, p], [m, n]) + + assert can_do([h, h, 5*h], [3*h, 3*h]) + assert can_do([h, 1, 5*h], [3*h, 3*h]) + assert can_do([h, 2, 2], [1, 3]) + + # pages 435 to 457 contain more PFDD and stuff like this + + +@slow +def test_prudnikov_7(): + assert can_do([3], [6]) + + h = S.Half + for n in [h, 3*h, 5*h, 7*h]: + assert can_do([-h], [n]) + for m in [-h, h, 1, 3*h, 2, 5*h, 3, 7*h, 4]: # HERE + for n in [-h, h, 3*h, 5*h, 7*h, 1, 2, 3, 4]: + assert can_do([m], [n]) + + +@slow +def test_prudnikov_8(): + h = S.Half + + # 7.12.2 + for ai in [1, 2, 3]: + for bi in [1, 2, 3]: + for ci in range(1, ai + 1): + for di in [h, 1, 3*h, 2, 5*h, 3]: + assert can_do([ai, bi], [ci, di]) + for bi in [3*h, 5*h]: + for ci in [h, 1, 3*h, 2, 5*h, 3]: + for di in [1, 2, 3]: + assert can_do([ai, bi], [ci, di]) + + for ai in [-h, h, 3*h, 5*h]: + for bi in [1, 2, 3]: + for ci in [h, 1, 3*h, 2, 5*h, 3]: + for di in [1, 2, 3]: + assert can_do([ai, bi], [ci, di]) + for bi in [h, 3*h, 5*h]: + for ci in [h, 3*h, 5*h, 3]: + for di in [h, 1, 3*h, 2, 5*h, 3]: + if ci <= bi: + assert can_do([ai, bi], [ci, di]) + + +def test_prudnikov_9(): + # 7.13.1 [we have a general formula ... so this is a bit pointless] + for i in range(9): + assert can_do([], [(S(i) + 1)/2]) + for i in range(5): + assert can_do([], [-(2*S(i) + 1)/2]) + + +@slow +def test_prudnikov_10(): + # 7.14.2 + h = S.Half + for p in [-h, h, 1, 3*h, 2, 5*h, 3, 7*h, 4]: + for m in [1, 2, 3, 4]: + for n in range(m, 5): + assert can_do([p], [m, n]) + + for p in [1, 2, 3, 4]: + for n in [h, 3*h, 5*h, 7*h]: + for m in [1, 2, 3, 4]: + assert can_do([p], [n, m]) + + for p in [3*h, 5*h, 7*h]: + for m in [h, 1, 2, 5*h, 3, 7*h, 4]: + assert can_do([p], [h, m]) + assert can_do([p], [3*h, m]) + + for m in [h, 1, 2, 5*h, 3, 7*h, 4]: + assert can_do([7*h], [5*h, m]) + + assert can_do([Rational(-1, 2)], [S.Half, S.Half]) # shine-integral shi + + +def test_prudnikov_11(): + # 7.15 + assert can_do([a, a + S.Half], [2*a, b, 2*a - b]) + assert can_do([a, a + S.Half], [Rational(3, 2), 2*a, 2*a - S.Half]) + + assert can_do([Rational(1, 4), Rational(3, 4)], [S.Half, S.Half, 1]) + assert can_do([Rational(5, 4), Rational(3, 4)], [Rational(3, 2), S.Half, 2]) + assert can_do([Rational(5, 4), Rational(3, 4)], [Rational(3, 2), Rational(3, 2), 1]) + assert can_do([Rational(5, 4), Rational(7, 4)], [Rational(3, 2), Rational(5, 2), 2]) + + assert can_do([1, 1], [Rational(3, 2), 2, 2]) # cosh-integral chi + + +def test_prudnikov_12(): + # 7.16 + assert can_do( + [], [a, a + S.Half, 2*a], False) # branches only agree for some z! + assert can_do([], [a, a + S.Half, 2*a + 1], False) # dito + assert can_do([], [S.Half, a, a + S.Half]) + assert can_do([], [Rational(3, 2), a, a + S.Half]) + + assert can_do([], [Rational(1, 4), S.Half, Rational(3, 4)]) + assert can_do([], [S.Half, S.Half, 1]) + assert can_do([], [S.Half, Rational(3, 2), 1]) + assert can_do([], [Rational(3, 4), Rational(3, 2), Rational(5, 4)]) + assert can_do([], [1, 1, Rational(3, 2)]) + assert can_do([], [1, 2, Rational(3, 2)]) + assert can_do([], [1, Rational(3, 2), Rational(3, 2)]) + assert can_do([], [Rational(5, 4), Rational(3, 2), Rational(7, 4)]) + assert can_do([], [2, Rational(3, 2), Rational(3, 2)]) + + +@slow +def test_prudnikov_2F1(): + h = S.Half + # Elliptic integrals + for p in [-h, h]: + for m in [h, 3*h, 5*h, 7*h]: + for n in [1, 2, 3, 4]: + assert can_do([p, m], [n]) + + +@XFAIL +def test_prudnikov_fail_2F1(): + assert can_do([a, b], [b + 1]) # incomplete beta function + assert can_do([-1, b], [c]) # Poly. also -2, -3 etc + + # TODO polys + + # Legendre functions: + assert can_do([a, b], [a + b + S.Half]) + assert can_do([a, b], [a + b - S.Half]) + assert can_do([a, b], [a + b + Rational(3, 2)]) + assert can_do([a, b], [(a + b + 1)/2]) + assert can_do([a, b], [(a + b)/2 + 1]) + assert can_do([a, b], [a - b + 1]) + assert can_do([a, b], [a - b + 2]) + assert can_do([a, b], [2*b]) + assert can_do([a, b], [S.Half]) + assert can_do([a, b], [Rational(3, 2)]) + assert can_do([a, 1 - a], [c]) + assert can_do([a, 2 - a], [c]) + assert can_do([a, 3 - a], [c]) + assert can_do([a, a + S.Half], [c]) + assert can_do([1, b], [c]) + assert can_do([1, b], [Rational(3, 2)]) + + assert can_do([Rational(1, 4), Rational(3, 4)], [1]) + + # PFDD + o = S.One + assert can_do([o/8, 1], [o/8*9]) + assert can_do([o/6, 1], [o/6*7]) + assert can_do([o/6, 1], [o/6*13]) + assert can_do([o/5, 1], [o/5*6]) + assert can_do([o/5, 1], [o/5*11]) + assert can_do([o/4, 1], [o/4*5]) + assert can_do([o/4, 1], [o/4*9]) + assert can_do([o/3, 1], [o/3*4]) + assert can_do([o/3, 1], [o/3*7]) + assert can_do([o/8*3, 1], [o/8*11]) + assert can_do([o/5*2, 1], [o/5*7]) + assert can_do([o/5*2, 1], [o/5*12]) + assert can_do([o/5*3, 1], [o/5*8]) + assert can_do([o/5*3, 1], [o/5*13]) + assert can_do([o/8*5, 1], [o/8*13]) + assert can_do([o/4*3, 1], [o/4*7]) + assert can_do([o/4*3, 1], [o/4*11]) + assert can_do([o/3*2, 1], [o/3*5]) + assert can_do([o/3*2, 1], [o/3*8]) + assert can_do([o/5*4, 1], [o/5*9]) + assert can_do([o/5*4, 1], [o/5*14]) + assert can_do([o/6*5, 1], [o/6*11]) + assert can_do([o/6*5, 1], [o/6*17]) + assert can_do([o/8*7, 1], [o/8*15]) + + +@XFAIL +def test_prudnikov_fail_3F2(): + assert can_do([a, a + Rational(1, 3), a + Rational(2, 3)], [Rational(1, 3), Rational(2, 3)]) + assert can_do([a, a + Rational(1, 3), a + Rational(2, 3)], [Rational(2, 3), Rational(4, 3)]) + assert can_do([a, a + Rational(1, 3), a + Rational(2, 3)], [Rational(4, 3), Rational(5, 3)]) + + # page 421 + assert can_do([a, a + Rational(1, 3), a + Rational(2, 3)], [a*Rational(3, 2), (3*a + 1)/2]) + + # pages 422 ... + assert can_do([Rational(-1, 2), S.Half, S.Half], [1, 1]) # elliptic integrals + assert can_do([Rational(-1, 2), S.Half, 1], [Rational(3, 2), Rational(3, 2)]) + # TODO LOTS more + + # PFDD + assert can_do([Rational(1, 8), Rational(3, 8), 1], [Rational(9, 8), Rational(11, 8)]) + assert can_do([Rational(1, 8), Rational(5, 8), 1], [Rational(9, 8), Rational(13, 8)]) + assert can_do([Rational(1, 8), Rational(7, 8), 1], [Rational(9, 8), Rational(15, 8)]) + assert can_do([Rational(1, 6), Rational(1, 3), 1], [Rational(7, 6), Rational(4, 3)]) + assert can_do([Rational(1, 6), Rational(2, 3), 1], [Rational(7, 6), Rational(5, 3)]) + assert can_do([Rational(1, 6), Rational(2, 3), 1], [Rational(5, 3), Rational(13, 6)]) + assert can_do([S.Half, 1, 1], [Rational(1, 4), Rational(3, 4)]) + # LOTS more + + +@XFAIL +def test_prudnikov_fail_other(): + # 7.11.2 + + # 7.12.1 + assert can_do([1, a], [b, 1 - 2*a + b]) # ??? + + # 7.14.2 + assert can_do([Rational(-1, 2)], [S.Half, 1]) # struve + assert can_do([1], [S.Half, S.Half]) # struve + assert can_do([Rational(1, 4)], [S.Half, Rational(5, 4)]) # PFDD + assert can_do([Rational(3, 4)], [Rational(3, 2), Rational(7, 4)]) # PFDD + assert can_do([1], [Rational(1, 4), Rational(3, 4)]) # PFDD + assert can_do([1], [Rational(3, 4), Rational(5, 4)]) # PFDD + assert can_do([1], [Rational(5, 4), Rational(7, 4)]) # PFDD + # TODO LOTS more + + # 7.15.2 + assert can_do([S.Half, 1], [Rational(3, 4), Rational(5, 4), Rational(3, 2)]) # PFDD + assert can_do([S.Half, 1], [Rational(7, 4), Rational(5, 4), Rational(3, 2)]) # PFDD + + # 7.16.1 + assert can_do([], [Rational(1, 3), S(2/3)]) # PFDD + assert can_do([], [Rational(2, 3), S(4/3)]) # PFDD + assert can_do([], [Rational(5, 3), S(4/3)]) # PFDD + + # XXX this does not *evaluate* right?? + assert can_do([], [a, a + S.Half, 2*a - 1]) + + +def test_bug(): + h = hyper([-1, 1], [z], -1) + assert hyperexpand(h) == (z + 1)/z + + +def test_omgissue_203(): + h = hyper((-5, -3, -4), (-6, -6), 1) + assert hyperexpand(h) == Rational(1, 30) + h = hyper((-6, -7, -5), (-6, -6), 1) + assert hyperexpand(h) == Rational(-1, 6) diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_powsimp.py b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_powsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..61bdc93d052baf4b1e80da8f5864cf22b8fa383e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_powsimp.py @@ -0,0 +1,368 @@ +from sympy.core.function import Function +from sympy.core.mul import Mul +from sympy.core.numbers import (E, I, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import sin +from sympy.functions.special.gamma_functions import gamma +from sympy.functions.special.hyper import hyper +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.simplify.powsimp import (powdenest, powsimp) +from sympy.simplify.simplify import (signsimp, simplify) +from sympy.core.symbol import Str + +from sympy.abc import x, y, z, a, b + + +def test_powsimp(): + x, y, z, n = symbols('x,y,z,n') + f = Function('f') + assert powsimp( 4**x * 2**(-x) * 2**(-x) ) == 1 + assert powsimp( (-4)**x * (-2)**(-x) * 2**(-x) ) == 1 + + assert powsimp( + f(4**x * 2**(-x) * 2**(-x)) ) == f(4**x * 2**(-x) * 2**(-x)) + assert powsimp( f(4**x * 2**(-x) * 2**(-x)), deep=True ) == f(1) + assert exp(x)*exp(y) == exp(x)*exp(y) + assert powsimp(exp(x)*exp(y)) == exp(x + y) + assert powsimp(exp(x)*exp(y)*2**x*2**y) == (2*E)**(x + y) + assert powsimp(exp(x)*exp(y)*2**x*2**y, combine='exp') == \ + exp(x + y)*2**(x + y) + assert powsimp(exp(x)*exp(y)*exp(2)*sin(x) + sin(y) + 2**x*2**y) == \ + exp(2 + x + y)*sin(x) + sin(y) + 2**(x + y) + assert powsimp(sin(exp(x)*exp(y))) == sin(exp(x)*exp(y)) + assert powsimp(sin(exp(x)*exp(y)), deep=True) == sin(exp(x + y)) + assert powsimp(x**2*x**y) == x**(2 + y) + # This should remain factored, because 'exp' with deep=True is supposed + # to act like old automatic exponent combining. + assert powsimp((1 + E*exp(E))*exp(-E), combine='exp', deep=True) == \ + (1 + exp(1 + E))*exp(-E) + assert powsimp((1 + E*exp(E))*exp(-E), deep=True) == \ + (1 + exp(1 + E))*exp(-E) + assert powsimp((1 + E*exp(E))*exp(-E)) == (1 + exp(1 + E))*exp(-E) + assert powsimp((1 + E*exp(E))*exp(-E), combine='exp') == \ + (1 + exp(1 + E))*exp(-E) + assert powsimp((1 + E*exp(E))*exp(-E), combine='base') == \ + (1 + E*exp(E))*exp(-E) + x, y = symbols('x,y', nonnegative=True) + n = Symbol('n', real=True) + assert powsimp(y**n * (y/x)**(-n)) == x**n + assert powsimp(x**(x**(x*y)*y**(x*y))*y**(x**(x*y)*y**(x*y)), deep=True) \ + == (x*y)**(x*y)**(x*y) + assert powsimp(2**(2**(2*x)*x), deep=False) == 2**(2**(2*x)*x) + assert powsimp(2**(2**(2*x)*x), deep=True) == 2**(x*4**x) + assert powsimp( + exp(-x + exp(-x)*exp(-x*log(x))), deep=False, combine='exp') == \ + exp(-x + exp(-x)*exp(-x*log(x))) + assert powsimp( + exp(-x + exp(-x)*exp(-x*log(x))), deep=False, combine='exp') == \ + exp(-x + exp(-x)*exp(-x*log(x))) + assert powsimp((x + y)/(3*z), deep=False, combine='exp') == (x + y)/(3*z) + assert powsimp((x/3 + y/3)/z, deep=True, combine='exp') == (x/3 + y/3)/z + assert powsimp(exp(x)/(1 + exp(x)*exp(y)), deep=True) == \ + exp(x)/(1 + exp(x + y)) + assert powsimp(x*y**(z**x*z**y), deep=True) == x*y**(z**(x + y)) + assert powsimp((z**x*z**y)**x, deep=True) == (z**(x + y))**x + assert powsimp(x*(z**x*z**y)**x, deep=True) == x*(z**(x + y))**x + p = symbols('p', positive=True) + assert powsimp((1/x)**log(2)/x) == (1/x)**(1 + log(2)) + assert powsimp((1/p)**log(2)/p) == p**(-1 - log(2)) + + # coefficient of exponent can only be simplified for positive bases + assert powsimp(2**(2*x)) == 4**x + assert powsimp((-1)**(2*x)) == (-1)**(2*x) + i = symbols('i', integer=True) + assert powsimp((-1)**(2*i)) == 1 + assert powsimp((-1)**(-x)) != (-1)**x # could be 1/((-1)**x), but is not + # force=True overrides assumptions + assert powsimp((-1)**(2*x), force=True) == 1 + + # rational exponents allow combining of negative terms + w, n, m = symbols('w n m', negative=True) + e = i/a # not a rational exponent if `a` is unknown + ex = w**e*n**e*m**e + assert powsimp(ex) == m**(i/a)*n**(i/a)*w**(i/a) + e = i/3 + ex = w**e*n**e*m**e + assert powsimp(ex) == (-1)**i*(-m*n*w)**(i/3) + e = (3 + i)/i + ex = w**e*n**e*m**e + assert powsimp(ex) == (-1)**(3*e)*(-m*n*w)**e + + eq = x**(a*Rational(2, 3)) + # eq != (x**a)**(2/3) (try x = -1 and a = 3 to see) + assert powsimp(eq).exp == eq.exp == a*Rational(2, 3) + # powdenest goes the other direction + assert powsimp(2**(2*x)) == 4**x + + assert powsimp(exp(p/2)) == exp(p/2) + + # issue 6368 + eq = Mul(*[sqrt(Dummy(imaginary=True)) for i in range(3)]) + assert powsimp(eq) == eq and eq.is_Mul + + assert all(powsimp(e) == e for e in (sqrt(x**a), sqrt(x**2))) + + # issue 8836 + assert str( powsimp(exp(I*pi/3)*root(-1,3)) ) == '(-1)**(2/3)' + + # issue 9183 + assert powsimp(-0.1**x) == -0.1**x + + # issue 10095 + assert powsimp((1/(2*E))**oo) == (exp(-1)/2)**oo + + # PR 13131 + eq = sin(2*x)**2*sin(2.0*x)**2 + assert powsimp(eq) == eq + + # issue 14615 + assert powsimp(x**2*y**3*(x*y**2)**Rational(3, 2) + ) == x*y*(x*y**2)**Rational(5, 2) + + #issue 27380 + assert powsimp(1.0**(x+1)/1.0**x) == 1.0 + +def test_powsimp_negated_base(): + assert powsimp((-x + y)/sqrt(x - y)) == -sqrt(x - y) + assert powsimp((-x + y)*(-z + y)/sqrt(x - y)/sqrt(z - y)) == sqrt(x - y)*sqrt(z - y) + p = symbols('p', positive=True) + reps = {p: 2, a: S.Half} + assert powsimp((-p)**a/p**a).subs(reps) == ((-1)**a).subs(reps) + assert powsimp((-p)**a*p**a).subs(reps) == ((-p**2)**a).subs(reps) + n = symbols('n', negative=True) + reps = {p: -2, a: S.Half} + assert powsimp((-n)**a/n**a).subs(reps) == (-1)**(-a).subs(a, S.Half) + assert powsimp((-n)**a*n**a).subs(reps) == ((-n**2)**a).subs(reps) + # if x is 0 then the lhs is 0**a*oo**a which is not (-1)**a + eq = (-x)**a/x**a + assert powsimp(eq) == eq + + +def test_powsimp_nc(): + x, y, z = symbols('x,y,z') + A, B, C = symbols('A B C', commutative=False) + + assert powsimp(A**x*A**y, combine='all') == A**(x + y) + assert powsimp(A**x*A**y, combine='base') == A**x*A**y + assert powsimp(A**x*A**y, combine='exp') == A**(x + y) + + assert powsimp(A**x*B**x, combine='all') == A**x*B**x + assert powsimp(A**x*B**x, combine='base') == A**x*B**x + assert powsimp(A**x*B**x, combine='exp') == A**x*B**x + + assert powsimp(B**x*A**x, combine='all') == B**x*A**x + assert powsimp(B**x*A**x, combine='base') == B**x*A**x + assert powsimp(B**x*A**x, combine='exp') == B**x*A**x + + assert powsimp(A**x*A**y*A**z, combine='all') == A**(x + y + z) + assert powsimp(A**x*A**y*A**z, combine='base') == A**x*A**y*A**z + assert powsimp(A**x*A**y*A**z, combine='exp') == A**(x + y + z) + + assert powsimp(A**x*B**x*C**x, combine='all') == A**x*B**x*C**x + assert powsimp(A**x*B**x*C**x, combine='base') == A**x*B**x*C**x + assert powsimp(A**x*B**x*C**x, combine='exp') == A**x*B**x*C**x + + assert powsimp(B**x*A**x*C**x, combine='all') == B**x*A**x*C**x + assert powsimp(B**x*A**x*C**x, combine='base') == B**x*A**x*C**x + assert powsimp(B**x*A**x*C**x, combine='exp') == B**x*A**x*C**x + + +def test_issue_6440(): + assert powsimp(16*2**a*8**b) == 2**(a + 3*b + 4) + + +def test_powdenest(): + x, y = symbols('x,y') + p, q = symbols('p q', positive=True) + i, j = symbols('i,j', integer=True) + + assert powdenest(x) == x + assert powdenest(x + 2*(x**(a*Rational(2, 3)))**(3*x)) == (x + 2*(x**(a*Rational(2, 3)))**(3*x)) + assert powdenest((exp(a*Rational(2, 3)))**(3*x)) # -X-> (exp(a/3))**(6*x) + assert powdenest((x**(a*Rational(2, 3)))**(3*x)) == ((x**(a*Rational(2, 3)))**(3*x)) + assert powdenest(exp(3*x*log(2))) == 2**(3*x) + assert powdenest(sqrt(p**2)) == p + eq = p**(2*i)*q**(4*i) + assert powdenest(eq) == (p*q**2)**(2*i) + # -X-> (x**x)**i*(x**x)**j == x**(x*(i + j)) + assert powdenest((x**x)**(i + j)) + assert powdenest(exp(3*y*log(x))) == x**(3*y) + assert powdenest(exp(y*(log(a) + log(b)))) == (a*b)**y + assert powdenest(exp(3*(log(a) + log(b)))) == a**3*b**3 + assert powdenest(((x**(2*i))**(3*y))**x) == ((x**(2*i))**(3*y))**x + assert powdenest(((x**(2*i))**(3*y))**x, force=True) == x**(6*i*x*y) + assert powdenest(((x**(a*Rational(2, 3)))**(3*y/i))**x) == \ + (((x**(a*Rational(2, 3)))**(3*y/i))**x) + assert powdenest((x**(2*i)*y**(4*i))**z, force=True) == (x*y**2)**(2*i*z) + assert powdenest((p**(2*i)*q**(4*i))**j) == (p*q**2)**(2*i*j) + e = ((p**(2*a))**(3*y))**x + assert powdenest(e) == e + e = ((x**2*y**4)**a)**(x*y) + assert powdenest(e) == e + e = (((x**2*y**4)**a)**(x*y))**3 + assert powdenest(e) == ((x**2*y**4)**a)**(3*x*y) + assert powdenest((((x**2*y**4)**a)**(x*y)), force=True) == \ + (x*y**2)**(2*a*x*y) + assert powdenest((((x**2*y**4)**a)**(x*y))**3, force=True) == \ + (x*y**2)**(6*a*x*y) + assert powdenest((x**2*y**6)**i) != (x*y**3)**(2*i) + x, y = symbols('x,y', positive=True) + assert powdenest((x**2*y**6)**i) == (x*y**3)**(2*i) + + assert powdenest((x**(i*Rational(2, 3))*y**(i/2))**(2*i)) == (x**Rational(4, 3)*y)**(i**2) + assert powdenest(sqrt(x**(2*i)*y**(6*i))) == (x*y**3)**i + + assert powdenest(4**x) == 2**(2*x) + assert powdenest((4**x)**y) == 2**(2*x*y) + assert powdenest(4**x*y) == 2**(2*x)*y + + +def test_powdenest_polar(): + x, y, z = symbols('x y z', polar=True) + a, b, c = symbols('a b c') + assert powdenest((x*y*z)**a) == x**a*y**a*z**a + assert powdenest((x**a*y**b)**c) == x**(a*c)*y**(b*c) + assert powdenest(((x**a)**b*y**c)**c) == x**(a*b*c)*y**(c**2) + + +def test_issue_5805(): + arg = ((gamma(x)*hyper((), (), x))*pi)**2 + assert powdenest(arg) == (pi*gamma(x)*hyper((), (), x))**2 + assert arg.is_positive is None + + +def test_issue_9324_powsimp_on_matrix_symbol(): + M = MatrixSymbol('M', 10, 10) + expr = powsimp(M, deep=True) + assert expr == M + assert expr.args[0] == Str('M') + + +def test_issue_6367(): + z = -5*sqrt(2)/(2*sqrt(2*sqrt(29) + 29)) + sqrt(-sqrt(29)/29 + S.Half) + assert Mul(*[powsimp(a) for a in Mul.make_args(z.normal())]) == 0 + assert powsimp(z.normal()) == 0 + assert simplify(z) == 0 + assert powsimp(sqrt(2 + sqrt(3))*sqrt(2 - sqrt(3)) + 1) == 2 + assert powsimp(z) != 0 + + +def test_powsimp_polar(): + from sympy.functions.elementary.complexes import polar_lift + from sympy.functions.elementary.exponential import exp_polar + x, y, z = symbols('x y z') + p, q, r = symbols('p q r', polar=True) + + assert (polar_lift(-1))**(2*x) == exp_polar(2*pi*I*x) + assert powsimp(p**x * q**x) == (p*q)**x + assert p**x * (1/p)**x == 1 + assert (1/p)**x == p**(-x) + + assert exp_polar(x)*exp_polar(y) == exp_polar(x)*exp_polar(y) + assert powsimp(exp_polar(x)*exp_polar(y)) == exp_polar(x + y) + assert powsimp(exp_polar(x)*exp_polar(y)*p**x*p**y) == \ + (p*exp_polar(1))**(x + y) + assert powsimp(exp_polar(x)*exp_polar(y)*p**x*p**y, combine='exp') == \ + exp_polar(x + y)*p**(x + y) + assert powsimp( + exp_polar(x)*exp_polar(y)*exp_polar(2)*sin(x) + sin(y) + p**x*p**y) \ + == p**(x + y) + sin(x)*exp_polar(2 + x + y) + sin(y) + assert powsimp(sin(exp_polar(x)*exp_polar(y))) == \ + sin(exp_polar(x)*exp_polar(y)) + assert powsimp(sin(exp_polar(x)*exp_polar(y)), deep=True) == \ + sin(exp_polar(x + y)) + + +def test_issue_5728(): + b = x*sqrt(y) + a = sqrt(b) + c = sqrt(sqrt(x)*y) + assert powsimp(a*b) == sqrt(b)**3 + assert powsimp(a*b**2*sqrt(y)) == sqrt(y)*a**5 + assert powsimp(a*x**2*c**3*y) == c**3*a**5 + assert powsimp(a*x*c**3*y**2) == c**7*a + assert powsimp(x*c**3*y**2) == c**7 + assert powsimp(x*c**3*y) == x*y*c**3 + assert powsimp(sqrt(x)*c**3*y) == c**5 + assert powsimp(sqrt(x)*a**3*sqrt(y)) == sqrt(x)*sqrt(y)*a**3 + assert powsimp(Mul(sqrt(x)*c**3*sqrt(y), y, evaluate=False)) == \ + sqrt(x)*sqrt(y)**3*c**3 + assert powsimp(a**2*a*x**2*y) == a**7 + + # symbolic powers work, too + b = x**y*y + a = b*sqrt(b) + assert a.is_Mul is True + assert powsimp(a) == sqrt(b)**3 + + # as does exp + a = x*exp(y*Rational(2, 3)) + assert powsimp(a*sqrt(a)) == sqrt(a)**3 + assert powsimp(a**2*sqrt(a)) == sqrt(a)**5 + assert powsimp(a**2*sqrt(sqrt(a))) == sqrt(sqrt(a))**9 + + +def test_issue_from_PR1599(): + n1, n2, n3, n4 = symbols('n1 n2 n3 n4', negative=True) + assert (powsimp(sqrt(n1)*sqrt(n2)*sqrt(n3)) == + -I*sqrt(-n1)*sqrt(-n2)*sqrt(-n3)) + assert (powsimp(root(n1, 3)*root(n2, 3)*root(n3, 3)*root(n4, 3)) == + -(-1)**Rational(1, 3)* + (-n1)**Rational(1, 3)*(-n2)**Rational(1, 3)*(-n3)**Rational(1, 3)*(-n4)**Rational(1, 3)) + + +def test_issue_10195(): + a = Symbol('a', integer=True) + l = Symbol('l', even=True, nonzero=True) + n = Symbol('n', odd=True) + e_x = (-1)**(n/2 - S.Half) - (-1)**(n*Rational(3, 2) - S.Half) + assert powsimp((-1)**(l/2)) == I**l + assert powsimp((-1)**(n/2)) == I**n + assert powsimp((-1)**(n*Rational(3, 2))) == -I**n + assert powsimp(e_x) == (-1)**(n/2 - S.Half) + (-1)**(n*Rational(3, 2) + + S.Half) + assert powsimp((-1)**(a*Rational(3, 2))) == (-I)**a + +def test_issue_15709(): + assert powsimp(3**x*Rational(2, 3)) == 2*3**(x-1) + assert powsimp(2*3**x/3) == 2*3**(x-1) + + +def test_issue_11981(): + x, y = symbols('x y', commutative=False) + assert powsimp((x*y)**2 * (y*x)**2) == (x*y)**2 * (y*x)**2 + + +def test_issue_17524(): + a = symbols("a", real=True) + e = (-1 - a**2)*sqrt(1 + a**2) + assert signsimp(powsimp(e)) == signsimp(e) == -(a**2 + 1)**(S(3)/2) + + +def test_issue_19627(): + # if you use force the user must verify + assert powdenest(sqrt(sin(x)**2), force=True) == sin(x) + assert powdenest((x**(S.Half/y))**(2*y), force=True) == x + from sympy.core.function import expand_power_base + e = 1 - a + expr = (exp(z/e)*x**(b/e)*y**((1 - b)/e))**e + assert powdenest(expand_power_base(expr, force=True), force=True + ) == x**b*y**(1 - b)*exp(z) + + +def test_issue_22546(): + p1, p2 = symbols('p1, p2', positive=True) + ref = powsimp(p1**z/p2**z) + e = z + 1 + ans = ref.subs(z, e) + assert ans.is_Pow + assert powsimp(p1**e/p2**e) == ans + i = symbols('i', integer=True) + ref = powsimp(x**i/y**i) + e = i + 1 + ans = ref.subs(i, e) + assert ans.is_Pow + assert powsimp(x**e/y**e) == ans diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_radsimp.py b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_radsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..f8ff955e48a34536c1752c565c0864dedae6a214 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_radsimp.py @@ -0,0 +1,498 @@ +from sympy.core.add import Add +from sympy.core.function import (Derivative, Function, diff) +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Rational) +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, Wild, symbols) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.polys.polytools import factor +from sympy.series.order import O +from sympy.simplify.radsimp import (collect, collect_const, fraction, radsimp, rcollect) + +from sympy.core.expr import unchanged +from sympy.core.mul import _unevaluated_Mul as umul +from sympy.simplify.radsimp import (_unevaluated_Add, + collect_sqrt, fraction_expand, collect_abs) +from sympy.testing.pytest import raises + +from sympy.abc import x, y, z, a, b, c, d + + +def test_radsimp(): + r2 = sqrt(2) + r3 = sqrt(3) + r5 = sqrt(5) + r7 = sqrt(7) + assert fraction(radsimp(1/r2)) == (sqrt(2), 2) + assert radsimp(1/(1 + r2)) == \ + -1 + sqrt(2) + assert radsimp(1/(r2 + r3)) == \ + -sqrt(2) + sqrt(3) + assert fraction(radsimp(1/(1 + r2 + r3))) == \ + (-sqrt(6) + sqrt(2) + 2, 4) + assert fraction(radsimp(1/(r2 + r3 + r5))) == \ + (-sqrt(30) + 2*sqrt(3) + 3*sqrt(2), 12) + assert fraction(radsimp(1/(1 + r2 + r3 + r5))) == ( + (-34*sqrt(10) - 26*sqrt(15) - 55*sqrt(3) - 61*sqrt(2) + 14*sqrt(30) + + 93 + 46*sqrt(6) + 53*sqrt(5), 71)) + assert fraction(radsimp(1/(r2 + r3 + r5 + r7))) == ( + (-50*sqrt(42) - 133*sqrt(5) - 34*sqrt(70) - 145*sqrt(3) + 22*sqrt(105) + + 185*sqrt(2) + 62*sqrt(30) + 135*sqrt(7), 215)) + z = radsimp(1/(1 + r2/3 + r3/5 + r5 + r7)) + assert len((3616791619821680643598*z).args) == 16 + assert radsimp(1/z) == 1/z + assert radsimp(1/z, max_terms=20).expand() == 1 + r2/3 + r3/5 + r5 + r7 + assert radsimp(1/(r2*3)) == \ + sqrt(2)/6 + assert radsimp(1/(r2*a + r3 + r5 + r7)) == ( + (8*sqrt(2)*a**7 - 8*sqrt(7)*a**6 - 8*sqrt(5)*a**6 - 8*sqrt(3)*a**6 - + 180*sqrt(2)*a**5 + 8*sqrt(30)*a**5 + 8*sqrt(42)*a**5 + 8*sqrt(70)*a**5 + - 24*sqrt(105)*a**4 + 84*sqrt(3)*a**4 + 100*sqrt(5)*a**4 + + 116*sqrt(7)*a**4 - 72*sqrt(70)*a**3 - 40*sqrt(42)*a**3 - + 8*sqrt(30)*a**3 + 782*sqrt(2)*a**3 - 462*sqrt(3)*a**2 - + 302*sqrt(7)*a**2 - 254*sqrt(5)*a**2 + 120*sqrt(105)*a**2 - + 795*sqrt(2)*a - 62*sqrt(30)*a + 82*sqrt(42)*a + 98*sqrt(70)*a - + 118*sqrt(105) + 59*sqrt(7) + 295*sqrt(5) + 531*sqrt(3))/(16*a**8 - + 480*a**6 + 3128*a**4 - 6360*a**2 + 3481)) + assert radsimp(1/(r2*a + r2*b + r3 + r7)) == ( + (sqrt(2)*a*(a + b)**2 - 5*sqrt(2)*a + sqrt(42)*a + sqrt(2)*b*(a + + b)**2 - 5*sqrt(2)*b + sqrt(42)*b - sqrt(7)*(a + b)**2 - sqrt(3)*(a + + b)**2 - 2*sqrt(3) + 2*sqrt(7))/(2*a**4 + 8*a**3*b + 12*a**2*b**2 - + 20*a**2 + 8*a*b**3 - 40*a*b + 2*b**4 - 20*b**2 + 8)) + assert radsimp(1/(r2*a + r2*b + r2*c + r2*d)) == \ + sqrt(2)/(2*a + 2*b + 2*c + 2*d) + assert radsimp(1/(1 + r2*a + r2*b + r2*c + r2*d)) == ( + (sqrt(2)*a + sqrt(2)*b + sqrt(2)*c + sqrt(2)*d - 1)/(2*a**2 + 4*a*b + + 4*a*c + 4*a*d + 2*b**2 + 4*b*c + 4*b*d + 2*c**2 + 4*c*d + 2*d**2 - 1)) + assert radsimp((y**2 - x)/(y - sqrt(x))) == \ + sqrt(x) + y + assert radsimp(-(y**2 - x)/(y - sqrt(x))) == \ + -(sqrt(x) + y) + assert radsimp(1/(1 - I + a*I)) == \ + (-I*a + 1 + I)/(a**2 - 2*a + 2) + assert radsimp(1/((-x + y)*(x - sqrt(y)))) == \ + (-x - sqrt(y))/((x - y)*(x**2 - y)) + e = (3 + 3*sqrt(2))*x*(3*x - 3*sqrt(y)) + assert radsimp(e) == x*(3 + 3*sqrt(2))*(3*x - 3*sqrt(y)) + assert radsimp(1/e) == ( + (-9*x + 9*sqrt(2)*x - 9*sqrt(y) + 9*sqrt(2)*sqrt(y))/(9*x*(9*x**2 - + 9*y))) + assert radsimp(1 + 1/(1 + sqrt(3))) == \ + Mul(S.Half, -1 + sqrt(3), evaluate=False) + 1 + A = symbols("A", commutative=False) + assert radsimp(x**2 + sqrt(2)*x**2 - sqrt(2)*x*A) == \ + x**2 + sqrt(2)*x**2 - sqrt(2)*x*A + assert radsimp(1/sqrt(5 + 2 * sqrt(6))) == -sqrt(2) + sqrt(3) + assert radsimp(1/sqrt(5 + 2 * sqrt(6))**3) == -(-sqrt(3) + sqrt(2))**3 + + # issue 6532 + assert fraction(radsimp(1/sqrt(x))) == (sqrt(x), x) + assert fraction(radsimp(1/sqrt(2*x + 3))) == (sqrt(2*x + 3), 2*x + 3) + assert fraction(radsimp(1/sqrt(2*(x + 3)))) == (sqrt(2*x + 6), 2*x + 6) + + # issue 5994 + e = S('-(2 + 2*sqrt(2) + 4*2**(1/4))/' + '(1 + 2**(3/4) + 3*2**(1/4) + 3*sqrt(2))') + assert radsimp(e).expand() == -2*2**Rational(3, 4) - 2*2**Rational(1, 4) + 2 + 2*sqrt(2) + + # issue 5986 (modifications to radimp didn't initially recognize this so + # the test is included here) + assert radsimp(1/(-sqrt(5)/2 - S.Half + (-sqrt(5)/2 - S.Half)**2)) == 1 + + # from issue 5934 + eq = ( + (-240*sqrt(2)*sqrt(sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) - + 360*sqrt(2)*sqrt(-8*sqrt(5) + 40)*sqrt(-sqrt(5) + 5) - + 120*sqrt(10)*sqrt(-8*sqrt(5) + 40)*sqrt(-sqrt(5) + 5) + + 120*sqrt(2)*sqrt(-sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) + + 120*sqrt(2)*sqrt(-8*sqrt(5) + 40)*sqrt(sqrt(5) + 5) + + 120*sqrt(10)*sqrt(-sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) + + 120*sqrt(10)*sqrt(-8*sqrt(5) + 40)*sqrt(sqrt(5) + 5))/(-36000 - + 7200*sqrt(5) + (12*sqrt(10)*sqrt(sqrt(5) + 5) + + 24*sqrt(10)*sqrt(-sqrt(5) + 5))**2)) + assert radsimp(eq) is S.NaN # it's 0/0 + + # work with normal form + e = 1/sqrt(sqrt(7)/7 + 2*sqrt(2) + 3*sqrt(3) + 5*sqrt(5)) + 3 + assert radsimp(e) == ( + -sqrt(sqrt(7) + 14*sqrt(2) + 21*sqrt(3) + + 35*sqrt(5))*(-11654899*sqrt(35) - 1577436*sqrt(210) - 1278438*sqrt(15) + - 1346996*sqrt(10) + 1635060*sqrt(6) + 5709765 + 7539830*sqrt(14) + + 8291415*sqrt(21))/1300423175 + 3) + + # obey power rules + base = sqrt(3) - sqrt(2) + assert radsimp(1/base**3) == (sqrt(3) + sqrt(2))**3 + assert radsimp(1/(-base)**3) == -(sqrt(2) + sqrt(3))**3 + assert radsimp(1/(-base)**x) == (-base)**(-x) + assert radsimp(1/base**x) == (sqrt(2) + sqrt(3))**x + assert radsimp(root(1/(-1 - sqrt(2)), -x)) == (-1)**(-1/x)*(1 + sqrt(2))**(1/x) + + # recurse + e = cos(1/(1 + sqrt(2))) + assert radsimp(e) == cos(-sqrt(2) + 1) + assert radsimp(e/2) == cos(-sqrt(2) + 1)/2 + assert radsimp(1/e) == 1/cos(-sqrt(2) + 1) + assert radsimp(2/e) == 2/cos(-sqrt(2) + 1) + assert fraction(radsimp(e/sqrt(x))) == (sqrt(x)*cos(-sqrt(2)+1), x) + + # test that symbolic denominators are not processed + r = 1 + sqrt(2) + assert radsimp(x/r, symbolic=False) == -x*(-sqrt(2) + 1) + assert radsimp(x/(y + r), symbolic=False) == x/(y + 1 + sqrt(2)) + assert radsimp(x/(y + r)/r, symbolic=False) == \ + -x*(-sqrt(2) + 1)/(y + 1 + sqrt(2)) + + # issue 7408 + eq = sqrt(x)/sqrt(y) + assert radsimp(eq) == umul(sqrt(x), sqrt(y), 1/y) + assert radsimp(eq, symbolic=False) == eq + + # issue 7498 + assert radsimp(sqrt(x)/sqrt(y)**3) == umul(sqrt(x), sqrt(y**3), 1/y**3) + + # for coverage + eq = sqrt(x)/y**2 + assert radsimp(eq) == eq + + # handle non-Expr args + from sympy.integrals.integrals import Integral + eq = Integral(x/(sqrt(2) - 1), (x, 0, 1/(sqrt(2) + 1))) + assert radsimp(eq) == Integral((sqrt(2) + 1)*x , (x, 0, sqrt(2) - 1)) + + from sympy.sets import FiniteSet + eq = FiniteSet(x/(sqrt(2) - 1)) + assert radsimp(eq) == FiniteSet((sqrt(2) + 1)*x) + +def test_radsimp_issue_3214(): + c, p = symbols('c p', positive=True) + s = sqrt(c**2 - p**2) + b = (c + I*p - s)/(c + I*p + s) + assert radsimp(b) == -I*(c + I*p - sqrt(c**2 - p**2))**2/(2*c*p) + + +def test_collect_1(): + """Collect with respect to Symbol""" + x, y, z, n = symbols('x,y,z,n') + assert collect(1, x) == 1 + assert collect( x + y*x, x ) == x * (1 + y) + assert collect( x + x**2, x ) == x + x**2 + assert collect( x**2 + y*x**2, x ) == (x**2)*(1 + y) + assert collect( x**2 + y*x, x ) == x*y + x**2 + assert collect( 2*x**2 + y*x**2 + 3*x*y, [x] ) == x**2*(2 + y) + 3*x*y + assert collect( 2*x**2 + y*x**2 + 3*x*y, [y] ) == 2*x**2 + y*(x**2 + 3*x) + + assert collect( ((1 + y + x)**4).expand(), x) == ((1 + y)**4).expand() + \ + x*(4*(1 + y)**3).expand() + x**2*(6*(1 + y)**2).expand() + \ + x**3*(4*(1 + y)).expand() + x**4 + # symbols can be given as any iterable + expr = x + y + assert collect(expr, expr.free_symbols) == expr + assert collect(x*exp(x) + sin(x)*y + sin(x)*2 + 3*x, x, exact=None + ) == x*exp(x) + 3*x + (y + 2)*sin(x) + assert collect(x*exp(x) + sin(x)*y + sin(x)*2 + 3*x + y*x + + y*x*exp(x), x, exact=None + ) == x*exp(x)*(y + 1) + (3 + y)*x + (y + 2)*sin(x) + + +def test_collect_2(): + """Collect with respect to a sum""" + a, b, x = symbols('a,b,x') + assert collect(a*(cos(x) + sin(x)) + b*(cos(x) + sin(x)), + sin(x) + cos(x)) == (a + b)*(cos(x) + sin(x)) + + +def test_collect_3(): + """Collect with respect to a product""" + a, b, c = symbols('a,b,c') + f = Function('f') + x, y, z, n = symbols('x,y,z,n') + + assert collect(-x/8 + x*y, -x) == x*(y - Rational(1, 8)) + + assert collect( 1 + x*(y**2), x*y ) == 1 + x*(y**2) + assert collect( x*y + a*x*y, x*y) == x*y*(1 + a) + assert collect( 1 + x*y + a*x*y, x*y) == 1 + x*y*(1 + a) + assert collect(a*x*f(x) + b*(x*f(x)), x*f(x)) == x*(a + b)*f(x) + + assert collect(a*x*log(x) + b*(x*log(x)), x*log(x)) == x*(a + b)*log(x) + assert collect(a*x**2*log(x)**2 + b*(x*log(x))**2, x*log(x)) == \ + x**2*log(x)**2*(a + b) + + # with respect to a product of three symbols + assert collect(y*x*z + a*x*y*z, x*y*z) == (1 + a)*x*y*z + + +def test_collect_4(): + """Collect with respect to a power""" + a, b, c, x = symbols('a,b,c,x') + + assert collect(a*x**c + b*x**c, x**c) == x**c*(a + b) + # issue 6096: 2 stays with c (unless c is integer or x is positive0 + assert collect(a*x**(2*c) + b*x**(2*c), x**c) == x**(2*c)*(a + b) + + +def test_collect_5(): + """Collect with respect to a tuple""" + a, x, y, z, n = symbols('a,x,y,z,n') + assert collect(x**2*y**4 + z*(x*y**2)**2 + z + a*z, [x*y**2, z]) in [ + z*(1 + a + x**2*y**4) + x**2*y**4, + z*(1 + a) + x**2*y**4*(1 + z) ] + assert collect((1 + (x + y) + (x + y)**2).expand(), + [x, y]) == 1 + y + x*(1 + 2*y) + x**2 + y**2 + + +def test_collect_pr19431(): + """Unevaluated collect with respect to a product""" + a = symbols('a') + assert collect(a**2*(a**2 + 1), a**2, evaluate=False)[a**2] == (a**2 + 1) + + +def test_collect_D(): + D = Derivative + f = Function('f') + x, a, b = symbols('x,a,b') + fx = D(f(x), x) + fxx = D(f(x), x, x) + + assert collect(a*fx + b*fx, fx) == (a + b)*fx + assert collect(a*D(fx, x) + b*D(fx, x), fx) == (a + b)*D(fx, x) + assert collect(a*fxx + b*fxx, fx) == (a + b)*D(fx, x) + # issue 4784 + assert collect(5*f(x) + 3*fx, fx) == 5*f(x) + 3*fx + assert collect(f(x) + f(x)*diff(f(x), x) + x*diff(f(x), x)*f(x), f(x).diff(x)) == \ + (x*f(x) + f(x))*D(f(x), x) + f(x) + assert collect(f(x) + f(x)*diff(f(x), x) + x*diff(f(x), x)*f(x), f(x).diff(x), exact=True) == \ + (x*f(x) + f(x))*D(f(x), x) + f(x) + assert collect(1/f(x) + 1/f(x)*diff(f(x), x) + x*diff(f(x), x)/f(x), f(x).diff(x), exact=True) == \ + (1/f(x) + x/f(x))*D(f(x), x) + 1/f(x) + e = (1 + x*fx + fx)/f(x) + assert collect(e.expand(), fx) == fx*(x/f(x) + 1/f(x)) + 1/f(x) + + +def test_collect_func(): + f = ((x + a + 1)**3).expand() + + assert collect(f, x) == a**3 + 3*a**2 + 3*a + x**3 + x**2*(3*a + 3) + \ + x*(3*a**2 + 6*a + 3) + 1 + assert collect(f, x, factor) == x**3 + 3*x**2*(a + 1) + 3*x*(a + 1)**2 + \ + (a + 1)**3 + + assert collect(f, x, evaluate=False) == { + S.One: a**3 + 3*a**2 + 3*a + 1, + x: 3*a**2 + 6*a + 3, x**2: 3*a + 3, + x**3: 1 + } + + assert collect(f, x, factor, evaluate=False) == { + S.One: (a + 1)**3, x: 3*(a + 1)**2, + x**2: umul(S(3), a + 1), x**3: 1} + + +def test_collect_order(): + a, b, x, t = symbols('a,b,x,t') + + assert collect(t + t*x + t*x**2 + O(x**3), t) == t*(1 + x + x**2 + O(x**3)) + assert collect(t + t*x + x**2 + O(x**3), t) == \ + t*(1 + x + O(x**3)) + x**2 + O(x**3) + + f = a*x + b*x + c*x**2 + d*x**2 + O(x**3) + g = x*(a + b) + x**2*(c + d) + O(x**3) + + assert collect(f, x) == g + assert collect(f, x, distribute_order_term=False) == g + + f = sin(a + b).series(b, 0, 10) + + assert collect(f, [sin(a), cos(a)]) == \ + sin(a)*cos(b).series(b, 0, 10) + cos(a)*sin(b).series(b, 0, 10) + assert collect(f, [sin(a), cos(a)], distribute_order_term=False) == \ + sin(a)*cos(b).series(b, 0, 10).removeO() + \ + cos(a)*sin(b).series(b, 0, 10).removeO() + O(b**10) + + +def test_rcollect(): + assert rcollect((x**2*y + x*y + x + y)/(x + y), y) == \ + (x + y*(1 + x + x**2))/(x + y) + assert rcollect(sqrt(-((x + 1)*(y + 1))), z) == sqrt(-((x + 1)*(y + 1))) + + +def test_collect_D_0(): + D = Derivative + f = Function('f') + x, a, b = symbols('x,a,b') + fxx = D(f(x), x, x) + + assert collect(a*fxx + b*fxx, fxx) == (a + b)*fxx + + +def test_collect_Wild(): + """Collect with respect to functions with Wild argument""" + a, b, x, y = symbols('a b x y') + f = Function('f') + w1 = Wild('.1') + w2 = Wild('.2') + assert collect(f(x) + a*f(x), f(w1)) == (1 + a)*f(x) + assert collect(f(x, y) + a*f(x, y), f(w1)) == f(x, y) + a*f(x, y) + assert collect(f(x, y) + a*f(x, y), f(w1, w2)) == (1 + a)*f(x, y) + assert collect(f(x, y) + a*f(x, y), f(w1, w1)) == f(x, y) + a*f(x, y) + assert collect(f(x, x) + a*f(x, x), f(w1, w1)) == (1 + a)*f(x, x) + assert collect(a*(x + 1)**y + (x + 1)**y, w1**y) == (1 + a)*(x + 1)**y + assert collect(a*(x + 1)**y + (x + 1)**y, w1**b) == \ + a*(x + 1)**y + (x + 1)**y + assert collect(a*(x + 1)**y + (x + 1)**y, (x + 1)**w2) == \ + (1 + a)*(x + 1)**y + assert collect(a*(x + 1)**y + (x + 1)**y, w1**w2) == (1 + a)*(x + 1)**y + + +def test_collect_const(): + # coverage not provided by above tests + assert collect_const(2*sqrt(3) + 4*a*sqrt(5)) == \ + 2*(2*sqrt(5)*a + sqrt(3)) # let the primitive reabsorb + assert collect_const(2*sqrt(3) + 4*a*sqrt(5), sqrt(3)) == \ + 2*sqrt(3) + 4*a*sqrt(5) + assert collect_const(sqrt(2)*(1 + sqrt(2)) + sqrt(3) + x*sqrt(2)) == \ + sqrt(2)*(x + 1 + sqrt(2)) + sqrt(3) + + # issue 5290 + assert collect_const(2*x + 2*y + 1, 2) == \ + collect_const(2*x + 2*y + 1) == \ + Add(S.One, Mul(2, x + y, evaluate=False), evaluate=False) + assert collect_const(-y - z) == Mul(-1, y + z, evaluate=False) + assert collect_const(2*x - 2*y - 2*z, 2) == \ + Mul(2, x - y - z, evaluate=False) + assert collect_const(2*x - 2*y - 2*z, -2) == \ + _unevaluated_Add(2*x, Mul(-2, y + z, evaluate=False)) + + # this is why the content_primitive is used + eq = (sqrt(15 + 5*sqrt(2))*x + sqrt(3 + sqrt(2))*y)*2 + assert collect_sqrt(eq + 2) == \ + 2*sqrt(sqrt(2) + 3)*(sqrt(5)*x + y) + 2 + + # issue 16296 + assert collect_const(a + b + x/2 + y/2) == a + b + Mul(S.Half, x + y, evaluate=False) + + +def test_issue_13143(): + f = Function('f') + fx = f(x).diff(x) + e = f(x) + fx + f(x)*fx + # collect function before derivative + assert collect(e, Wild('w')) == f(x)*(fx + 1) + fx + e = f(x) + f(x)*fx + x*fx*f(x) + assert collect(e, fx) == (x*f(x) + f(x))*fx + f(x) + assert collect(e, f(x)) == (x*fx + fx + 1)*f(x) + e = f(x) + fx + f(x)*fx + assert collect(e, [f(x), fx]) == f(x)*(1 + fx) + fx + assert collect(e, [fx, f(x)]) == fx*(1 + f(x)) + f(x) + + +def test_issue_6097(): + assert collect(a*y**(2.0*x) + b*y**(2.0*x), y**x) == (a + b)*(y**x)**2.0 + assert collect(a*2**(2.0*x) + b*2**(2.0*x), 2**x) == (a + b)*(2**x)**2.0 + + +def test_fraction_expand(): + eq = (x + y)*y/x + assert eq.expand(frac=True) == fraction_expand(eq) == (x*y + y**2)/x + assert eq.expand() == y + y**2/x + + +def test_fraction(): + x, y, z = map(Symbol, 'xyz') + A = Symbol('A', commutative=False) + + assert fraction(S.Half) == (1, 2) + + assert fraction(x) == (x, 1) + assert fraction(1/x) == (1, x) + assert fraction(x/y) == (x, y) + assert fraction(x/2) == (x, 2) + + assert fraction(x*y/z) == (x*y, z) + assert fraction(x/(y*z)) == (x, y*z) + + assert fraction(1/y**2) == (1, y**2) + assert fraction(x/y**2) == (x, y**2) + + assert fraction((x**2 + 1)/y) == (x**2 + 1, y) + assert fraction(x*(y + 1)/y**7) == (x*(y + 1), y**7) + + assert fraction(exp(-x), exact=True) == (exp(-x), 1) + assert fraction((1/(x + y))/2, exact=True) == (1, Mul(2,(x + y), evaluate=False)) + + assert fraction(x*A/y) == (x*A, y) + assert fraction(x*A**-1/y) == (x*A**-1, y) + + n = symbols('n', negative=True) + assert fraction(exp(n)) == (1, exp(-n)) + assert fraction(exp(-n)) == (exp(-n), 1) + + p = symbols('p', positive=True) + assert fraction(exp(-p)*log(p), exact=True) == (exp(-p)*log(p), 1) + + m = Mul(1, 1, S.Half, evaluate=False) + assert fraction(m) == (1, 2) + assert fraction(m, exact=True) == (Mul(1, 1, evaluate=False), 2) + + m = Mul(1, 1, S.Half, S.Half, Pow(1, -1, evaluate=False), evaluate=False) + assert fraction(m) == (1, 4) + assert fraction(m, exact=True) == \ + (Mul(1, 1, evaluate=False), Mul(2, 2, 1, evaluate=False)) + + +def test_issue_5615(): + aA, Re, a, b, D = symbols('aA Re a b D') + e = ((D**3*a + b*aA**3)/Re).expand() + assert collect(e, [aA**3/Re, a]) == e + + +def test_issue_5933(): + from sympy.geometry.polygon import (Polygon, RegularPolygon) + from sympy.simplify.radsimp import denom + x = Polygon(*RegularPolygon((0, 0), 1, 5).vertices).centroid.x + assert abs(denom(x).n()) > 1e-12 + assert abs(denom(radsimp(x))) > 1e-12 # in case simplify didn't handle it + + +def test_issue_14608(): + a, b = symbols('a b', commutative=False) + x, y = symbols('x y') + raises(AttributeError, lambda: collect(a*b + b*a, a)) + assert collect(x*y + y*(x+1), a) == x*y + y*(x+1) + assert collect(x*y + y*(x+1) + a*b + b*a, y) == y*(2*x + 1) + a*b + b*a + + +def test_collect_abs(): + s = abs(x) + abs(y) + assert collect_abs(s) == s + assert unchanged(Mul, abs(x), abs(y)) + ans = Abs(x*y) + assert isinstance(ans, Abs) + assert collect_abs(abs(x)*abs(y)) == ans + assert collect_abs(1 + exp(abs(x)*abs(y))) == 1 + exp(ans) + + # See https://github.com/sympy/sympy/issues/12910 + p = Symbol('p', positive=True) + assert collect_abs(p/abs(1-p)).is_commutative is True + + +def test_issue_19149(): + eq = exp(3*x/4) + assert collect(eq, exp(x)) == eq + +def test_issue_19719(): + a, b = symbols('a, b') + expr = a**2 * (b + 1) + (7 + 1/b)/a + collected = collect(expr, (a**2, 1/a), evaluate=False) + # Would return {_Dummy_20**(-2): b + 1, 1/a: 7 + 1/b} without xreplace + assert collected == {a**2: b + 1, 1/a: 7 + 1/b} + + +def test_issue_21355(): + assert radsimp(1/(x + sqrt(x**2))) == 1/(x + sqrt(x**2)) + assert radsimp(1/(x - sqrt(x**2))) == 1/(x - sqrt(x**2)) diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_ratsimp.py b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_ratsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..14e84fd2b227518baff1bda4e5b27ecc40a8bcdd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_ratsimp.py @@ -0,0 +1,78 @@ +from sympy.core.numbers import (Rational, pi) +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.special.error_functions import erf +from sympy.polys.domains import GF +from sympy.simplify.ratsimp import (ratsimp, ratsimpmodprime) + +from sympy.abc import x, y, z, t, a, b, c, d, e + + +def test_ratsimp(): + f, g = 1/x + 1/y, (x + y)/(x*y) + + assert f != g and ratsimp(f) == g + + f, g = 1/(1 + 1/x), 1 - 1/(x + 1) + + assert f != g and ratsimp(f) == g + + f, g = x/(x + y) + y/(x + y), 1 + + assert f != g and ratsimp(f) == g + + f, g = -x - y - y**2/(x + y) + x**2/(x + y), -2*y + + assert f != g and ratsimp(f) == g + + f = (a*c*x*y + a*c*z - b*d*x*y - b*d*z - b*t*x*y - b*t*x - b*t*z + + e*x)/(x*y + z) + G = [a*c - b*d - b*t + (-b*t*x + e*x)/(x*y + z), + a*c - b*d - b*t - ( b*t*x - e*x)/(x*y + z)] + + assert f != g and ratsimp(f) in G + + A = sqrt(pi) + + B = log(erf(x) - 1) + C = log(erf(x) + 1) + + D = 8 - 8*erf(x) + + f = A*B/D - A*C/D + A*C*erf(x)/D - A*B*erf(x)/D + 2*A/D + + assert ratsimp(f) == A*B/8 - A*C/8 - A/(4*erf(x) - 4) + + +def test_ratsimpmodprime(): + a = y**5 + x + y + b = x - y + F = [x*y**5 - x - y] + assert ratsimpmodprime(a/b, F, x, y, order='lex') == \ + (-x**2 - x*y - x - y) / (-x**2 + x*y) + + a = x + y**2 - 2 + b = x + y**2 - y - 1 + F = [x*y - 1] + assert ratsimpmodprime(a/b, F, x, y, order='lex') == \ + (1 + y - x)/(y - x) + + a = 5*x**3 + 21*x**2 + 4*x*y + 23*x + 12*y + 15 + b = 7*x**3 - y*x**2 + 31*x**2 + 2*x*y + 15*y + 37*x + 21 + F = [x**2 + y**2 - 1] + assert ratsimpmodprime(a/b, F, x, y, order='lex') == \ + (1 + 5*y - 5*x)/(8*y - 6*x) + + a = x*y - x - 2*y + 4 + b = x + y**2 - 2*y + F = [x - 2, y - 3] + assert ratsimpmodprime(a/b, F, x, y, order='lex') == \ + Rational(2, 5) + + # Test a bug where denominators would be dropped + assert ratsimpmodprime(x, [y - 2*x], order='lex') == \ + y/2 + + a = (x**5 + 2*x**4 + 2*x**3 + 2*x**2 + x + 2/x + x**(-2)) + assert ratsimpmodprime(a, [x + 1], domain=GF(2)) == 1 + assert ratsimpmodprime(a, [x + 1], domain=GF(3)) == -1 diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_rewrite.py b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_rewrite.py new file mode 100644 index 0000000000000000000000000000000000000000..56d2fb7a85bd959bd4accc2f36127429efbdbe70 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_rewrite.py @@ -0,0 +1,31 @@ +from sympy.core.numbers import I +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.trigonometric import (cos, cot, sin) +from sympy.testing.pytest import _both_exp_pow + +x, y, z, n = symbols('x,y,z,n') + + +@_both_exp_pow +def test_has(): + assert cot(x).has(x) + assert cot(x).has(cot) + assert not cot(x).has(sin) + assert sin(x).has(x) + assert sin(x).has(sin) + assert not sin(x).has(cot) + assert exp(x).has(exp) + + +@_both_exp_pow +def test_sin_exp_rewrite(): + assert sin(x).rewrite(sin, exp) == -I/2*(exp(I*x) - exp(-I*x)) + assert sin(x).rewrite(sin, exp).rewrite(exp, sin) == sin(x) + assert cos(x).rewrite(cos, exp).rewrite(exp, cos) == cos(x) + assert (sin(5*y) - sin( + 2*x)).rewrite(sin, exp).rewrite(exp, sin) == sin(5*y) - sin(2*x) + assert sin(x + y).rewrite(sin, exp).rewrite(exp, sin) == sin(x + y) + assert cos(x + y).rewrite(cos, exp).rewrite(exp, cos) == cos(x + y) + # This next test currently passes... not clear whether it should or not? + assert cos(x).rewrite(cos, exp).rewrite(exp, sin) == cos(x) diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_simplify.py b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_simplify.py new file mode 100644 index 0000000000000000000000000000000000000000..a5bf469f68adf5c5dfbdf7559414681e2fb28ba7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_simplify.py @@ -0,0 +1,1093 @@ +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.expr import unchanged +from sympy.core.function import (count_ops, diff, expand, expand_multinomial, Function, Derivative) +from sympy.core.mul import Mul, _keep_coeff +from sympy.core import GoldenRatio +from sympy.core.numbers import (E, Float, I, oo, pi, Rational, zoo) +from sympy.core.relational import (Eq, Lt, Gt, Ge, Le) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import (binomial, factorial) +from sympy.functions.elementary.complexes import (Abs, sign) +from sympy.functions.elementary.exponential import (exp, exp_polar, log) +from sympy.functions.elementary.hyperbolic import (cosh, csch, sinh) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin, sinc, tan) +from sympy.functions.special.error_functions import erf +from sympy.functions.special.gamma_functions import gamma +from sympy.functions.special.hyper import hyper +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.geometry.polygon import rad +from sympy.integrals.integrals import (Integral, integrate) +from sympy.logic.boolalg import (And, Or) +from sympy.matrices.dense import (Matrix, eye) +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.polys.polytools import (factor, Poly) +from sympy.simplify.simplify import (besselsimp, hypersimp, inversecombine, logcombine, nsimplify, nthroot, posify, separatevars, signsimp, simplify) +from sympy.solvers.solvers import solve + +from sympy.testing.pytest import XFAIL, slow, _both_exp_pow +from sympy.abc import x, y, z, t, a, b, c, d, e, f, g, h, i, n + + +def test_issue_7263(): + assert abs((simplify(30.8**2 - 82.5**2 * sin(rad(11.6))**2)).evalf() - \ + 673.447451402970) < 1e-12 + + +def test_factorial_simplify(): + # There are more tests in test_factorials.py. + x = Symbol('x') + assert simplify(factorial(x)/x) == gamma(x) + assert simplify(factorial(factorial(x))) == factorial(factorial(x)) + + +def test_simplify_expr(): + x, y, z, k, n, m, w, s, A = symbols('x,y,z,k,n,m,w,s,A') + f = Function('f') + + assert all(simplify(tmp) == tmp for tmp in [I, E, oo, x, -x, -oo, -E, -I]) + + e = 1/x + 1/y + assert e != (x + y)/(x*y) + assert simplify(e) == (x + y)/(x*y) + + e = A**2*s**4/(4*pi*k*m**3) + assert simplify(e) == e + + e = (4 + 4*x - 2*(2 + 2*x))/(2 + 2*x) + assert simplify(e) == 0 + + e = (-4*x*y**2 - 2*y**3 - 2*x**2*y)/(x + y)**2 + assert simplify(e) == -2*y + + e = -x - y - (x + y)**(-1)*y**2 + (x + y)**(-1)*x**2 + assert simplify(e) == -2*y + + e = (x + x*y)/x + assert simplify(e) == 1 + y + + e = (f(x) + y*f(x))/f(x) + assert simplify(e) == 1 + y + + e = (2 * (1/n - cos(n * pi)/n))/pi + assert simplify(e) == (-cos(pi*n) + 1)/(pi*n)*2 + + e = integrate(1/(x**3 + 1), x).diff(x) + assert simplify(e) == 1/(x**3 + 1) + + e = integrate(x/(x**2 + 3*x + 1), x).diff(x) + assert simplify(e) == x/(x**2 + 3*x + 1) + + f = Symbol('f') + A = Matrix([[2*k - m*w**2, -k], [-k, k - m*w**2]]).inv() + assert simplify((A*Matrix([0, f]))[1] - + (-f*(2*k - m*w**2)/(k**2 - (k - m*w**2)*(2*k - m*w**2)))) == 0 + + f = -x + y/(z + t) + z*x/(z + t) + z*a/(z + t) + t*x/(z + t) + assert simplify(f) == (y + a*z)/(z + t) + + # issue 10347 + expr = -x*(y**2 - 1)*(2*y**2*(x**2 - 1)/(a*(x**2 - y**2)**2) + (x**2 - 1) + /(a*(x**2 - y**2)))/(a*(x**2 - y**2)) + x*(-2*x**2*sqrt(-x**2*y**2 + x**2 + + y**2 - 1)*sin(z)/(a*(x**2 - y**2)**2) - x**2*sqrt(-x**2*y**2 + x**2 + + y**2 - 1)*sin(z)/(a*(x**2 - 1)*(x**2 - y**2)) + (x**2*sqrt((-x**2 + 1)* + (y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*sin(z)/(x**2 - 1) + sqrt( + (-x**2 + 1)*(y**2 - 1))*(x*(-x*y**2 + x)/sqrt(-x**2*y**2 + x**2 + y**2 - + 1) + sqrt(-x**2*y**2 + x**2 + y**2 - 1))*sin(z))/(a*sqrt((-x**2 + 1)*( + y**2 - 1))*(x**2 - y**2)))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*sin(z)/(a* + (x**2 - y**2)) + x*(-2*x**2*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z)/(a* + (x**2 - y**2)**2) - x**2*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z)/(a* + (x**2 - 1)*(x**2 - y**2)) + (x**2*sqrt((-x**2 + 1)*(y**2 - 1))*sqrt(-x**2 + *y**2 + x**2 + y**2 - 1)*cos(z)/(x**2 - 1) + x*sqrt((-x**2 + 1)*(y**2 - + 1))*(-x*y**2 + x)*cos(z)/sqrt(-x**2*y**2 + x**2 + y**2 - 1) + sqrt((-x**2 + + 1)*(y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z))/(a*sqrt((-x**2 + + 1)*(y**2 - 1))*(x**2 - y**2)))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos( + z)/(a*(x**2 - y**2)) - y*sqrt((-x**2 + 1)*(y**2 - 1))*(-x*y*sqrt(-x**2* + y**2 + x**2 + y**2 - 1)*sin(z)/(a*(x**2 - y**2)*(y**2 - 1)) + 2*x*y*sqrt( + -x**2*y**2 + x**2 + y**2 - 1)*sin(z)/(a*(x**2 - y**2)**2) + (x*y*sqrt(( + -x**2 + 1)*(y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*sin(z)/(y**2 - + 1) + x*sqrt((-x**2 + 1)*(y**2 - 1))*(-x**2*y + y)*sin(z)/sqrt(-x**2*y**2 + + x**2 + y**2 - 1))/(a*sqrt((-x**2 + 1)*(y**2 - 1))*(x**2 - y**2)))*sin( + z)/(a*(x**2 - y**2)) + y*(x**2 - 1)*(-2*x*y*(x**2 - 1)/(a*(x**2 - y**2) + **2) + 2*x*y/(a*(x**2 - y**2)))/(a*(x**2 - y**2)) + y*(x**2 - 1)*(y**2 - + 1)*(-x*y*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z)/(a*(x**2 - y**2)*(y**2 + - 1)) + 2*x*y*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z)/(a*(x**2 - y**2) + **2) + (x*y*sqrt((-x**2 + 1)*(y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - + 1)*cos(z)/(y**2 - 1) + x*sqrt((-x**2 + 1)*(y**2 - 1))*(-x**2*y + y)*cos( + z)/sqrt(-x**2*y**2 + x**2 + y**2 - 1))/(a*sqrt((-x**2 + 1)*(y**2 - 1) + )*(x**2 - y**2)))*cos(z)/(a*sqrt((-x**2 + 1)*(y**2 - 1))*(x**2 - y**2) + ) - x*sqrt((-x**2 + 1)*(y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*sin( + z)**2/(a**2*(x**2 - 1)*(x**2 - y**2)*(y**2 - 1)) - x*sqrt((-x**2 + 1)*( + y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z)**2/(a**2*(x**2 - 1)*( + x**2 - y**2)*(y**2 - 1)) + assert simplify(expr) == 2*x/(a**2*(x**2 - y**2)) + + #issue 17631 + assert simplify('((-1/2)*Boole(True)*Boole(False)-1)*Boole(True)') == \ + Mul(sympify('(2 + Boole(True)*Boole(False))'), sympify('-Boole(True)/2')) + + A, B = symbols('A,B', commutative=False) + + assert simplify(A*B - B*A) == A*B - B*A + assert simplify(A/(1 + y/x)) == x*A/(x + y) + assert simplify(A*(1/x + 1/y)) == A/x + A/y #(x + y)*A/(x*y) + + assert simplify(log(2) + log(3)) == log(6) + assert simplify(log(2*x) - log(2)) == log(x) + + assert simplify(hyper([], [], x)) == exp(x) + + +def test_issue_3557(): + f_1 = x*a + y*b + z*c - 1 + f_2 = x*d + y*e + z*f - 1 + f_3 = x*g + y*h + z*i - 1 + + solutions = solve([f_1, f_2, f_3], x, y, z, simplify=False) + + assert simplify(solutions[y]) == \ + (a*i + c*d + f*g - a*f - c*g - d*i)/ \ + (a*e*i + b*f*g + c*d*h - a*f*h - b*d*i - c*e*g) + + +def test_simplify_other(): + assert simplify(sin(x)**2 + cos(x)**2) == 1 + assert simplify(gamma(x + 1)/gamma(x)) == x + assert simplify(sin(x)**2 + cos(x)**2 + factorial(x)/gamma(x)) == 1 + x + assert simplify( + Eq(sin(x)**2 + cos(x)**2, factorial(x)/gamma(x))) == Eq(x, 1) + nc = symbols('nc', commutative=False) + assert simplify(x + x*nc) == x*(1 + nc) + # issue 6123 + # f = exp(-I*(k*sqrt(t) + x/(2*sqrt(t)))**2) + # ans = integrate(f, (k, -oo, oo), conds='none') + ans = I*(-pi*x*exp(I*pi*Rational(-3, 4) + I*x**2/(4*t))*erf(x*exp(I*pi*Rational(-3, 4))/ + (2*sqrt(t)))/(2*sqrt(t)) + pi*x*exp(I*pi*Rational(-3, 4) + I*x**2/(4*t))/ + (2*sqrt(t)))*exp(-I*x**2/(4*t))/(sqrt(pi)*x) - I*sqrt(pi) * \ + (-erf(x*exp(I*pi/4)/(2*sqrt(t))) + 1)*exp(I*pi/4)/(2*sqrt(t)) + assert simplify(ans) == -(-1)**Rational(3, 4)*sqrt(pi)/sqrt(t) + # issue 6370 + assert simplify(2**(2 + x)/4) == 2**x + + +@_both_exp_pow +def test_simplify_complex(): + cosAsExp = cos(x)._eval_rewrite_as_exp(x) + tanAsExp = tan(x)._eval_rewrite_as_exp(x) + assert simplify(cosAsExp*tanAsExp) == sin(x) # issue 4341 + + # issue 10124 + assert simplify(exp(Matrix([[0, -1], [1, 0]]))) == Matrix([[cos(1), + -sin(1)], [sin(1), cos(1)]]) + + +def test_simplify_ratio(): + # roots of x**3-3*x+5 + roots = ['(1/2 - sqrt(3)*I/2)*(sqrt(21)/2 + 5/2)**(1/3) + 1/((1/2 - ' + 'sqrt(3)*I/2)*(sqrt(21)/2 + 5/2)**(1/3))', + '1/((1/2 + sqrt(3)*I/2)*(sqrt(21)/2 + 5/2)**(1/3)) + ' + '(1/2 + sqrt(3)*I/2)*(sqrt(21)/2 + 5/2)**(1/3)', + '-(sqrt(21)/2 + 5/2)**(1/3) - 1/(sqrt(21)/2 + 5/2)**(1/3)'] + + for r in roots: + r = S(r) + assert count_ops(simplify(r, ratio=1)) <= count_ops(r) + # If ratio=oo, simplify() is always applied: + assert simplify(r, ratio=oo) is not r + + +def test_simplify_measure(): + measure1 = lambda expr: len(str(expr)) + measure2 = lambda expr: -count_ops(expr) + # Return the most complicated result + expr = (x + 1)/(x + sin(x)**2 + cos(x)**2) + assert measure1(simplify(expr, measure=measure1)) <= measure1(expr) + assert measure2(simplify(expr, measure=measure2)) <= measure2(expr) + + expr2 = Eq(sin(x)**2 + cos(x)**2, 1) + assert measure1(simplify(expr2, measure=measure1)) <= measure1(expr2) + assert measure2(simplify(expr2, measure=measure2)) <= measure2(expr2) + + +def test_simplify_rational(): + expr = 2**x*2.**y + assert simplify(expr, rational = True) == 2**(x+y) + assert simplify(expr, rational = None) == 2.0**(x+y) + assert simplify(expr, rational = False) == expr + assert simplify('0.9 - 0.8 - 0.1', rational = True) == 0 + + +def test_simplify_issue_1308(): + assert simplify(exp(Rational(-1, 2)) + exp(Rational(-3, 2))) == \ + (1 + E)*exp(Rational(-3, 2)) + + +def test_issue_5652(): + assert simplify(E + exp(-E)) == exp(-E) + E + n = symbols('n', commutative=False) + assert simplify(n + n**(-n)) == n + n**(-n) + +def test_issue_27380(): + assert simplify(1.0**(x+1)/1.0**x) == 1.0 + +def test_simplify_fail1(): + x = Symbol('x') + y = Symbol('y') + e = (x + y)**2/(-4*x*y**2 - 2*y**3 - 2*x**2*y) + assert simplify(e) == 1 / (-2*y) + + +def test_nthroot(): + assert nthroot(90 + 34*sqrt(7), 3) == sqrt(7) + 3 + q = 1 + sqrt(2) - 2*sqrt(3) + sqrt(6) + sqrt(7) + assert nthroot(expand_multinomial(q**3), 3) == q + assert nthroot(41 + 29*sqrt(2), 5) == 1 + sqrt(2) + assert nthroot(-41 - 29*sqrt(2), 5) == -1 - sqrt(2) + expr = 1320*sqrt(10) + 4216 + 2576*sqrt(6) + 1640*sqrt(15) + assert nthroot(expr, 5) == 1 + sqrt(6) + sqrt(15) + q = 1 + sqrt(2) + sqrt(3) + sqrt(5) + assert expand_multinomial(nthroot(expand_multinomial(q**5), 5)) == q + q = 1 + sqrt(2) + 7*sqrt(6) + 2*sqrt(10) + assert nthroot(expand_multinomial(q**5), 5, 8) == q + q = 1 + sqrt(2) - 2*sqrt(3) + 1171*sqrt(6) + assert nthroot(expand_multinomial(q**3), 3) == q + assert nthroot(expand_multinomial(q**6), 6) == q + + +def test_nthroot1(): + q = 1 + sqrt(2) + sqrt(3) + S.One/10**20 + p = expand_multinomial(q**5) + assert nthroot(p, 5) == q + q = 1 + sqrt(2) + sqrt(3) + S.One/10**30 + p = expand_multinomial(q**5) + assert nthroot(p, 5) == q + + +@_both_exp_pow +def test_separatevars(): + x, y, z, n = symbols('x,y,z,n') + assert separatevars(2*n*x*z + 2*x*y*z) == 2*x*z*(n + y) + assert separatevars(x*z + x*y*z) == x*z*(1 + y) + assert separatevars(pi*x*z + pi*x*y*z) == pi*x*z*(1 + y) + assert separatevars(x*y**2*sin(x) + x*sin(x)*sin(y)) == \ + x*(sin(y) + y**2)*sin(x) + assert separatevars(x*exp(x + y) + x*exp(x)) == x*(1 + exp(y))*exp(x) + assert separatevars((x*(y + 1))**z).is_Pow # != x**z*(1 + y)**z + assert separatevars(1 + x + y + x*y) == (x + 1)*(y + 1) + assert separatevars(y/pi*exp(-(z - x)/cos(n))) == \ + y*exp(x/cos(n))*exp(-z/cos(n))/pi + assert separatevars((x + y)*(x - y) + y**2 + 2*x + 1) == (x + 1)**2 + # issue 4858 + p = Symbol('p', positive=True) + assert separatevars(sqrt(p**2 + x*p**2)) == p*sqrt(1 + x) + assert separatevars(sqrt(y*(p**2 + x*p**2))) == p*sqrt(y*(1 + x)) + assert separatevars(sqrt(y*(p**2 + x*p**2)), force=True) == \ + p*sqrt(y)*sqrt(1 + x) + # issue 4865 + assert separatevars(sqrt(x*y)).is_Pow + assert separatevars(sqrt(x*y), force=True) == sqrt(x)*sqrt(y) + # issue 4957 + # any type sequence for symbols is fine + assert separatevars(((2*x + 2)*y), dict=True, symbols=()) == \ + {'coeff': 1, x: 2*x + 2, y: y} + # separable + assert separatevars(((2*x + 2)*y), dict=True, symbols=[x]) == \ + {'coeff': y, x: 2*x + 2} + assert separatevars(((2*x + 2)*y), dict=True, symbols=[]) == \ + {'coeff': 1, x: 2*x + 2, y: y} + assert separatevars(((2*x + 2)*y), dict=True) == \ + {'coeff': 1, x: 2*x + 2, y: y} + assert separatevars(((2*x + 2)*y), dict=True, symbols=None) == \ + {'coeff': y*(2*x + 2)} + # not separable + assert separatevars(3, dict=True) is None + assert separatevars(2*x + y, dict=True, symbols=()) is None + assert separatevars(2*x + y, dict=True) is None + assert separatevars(2*x + y, dict=True, symbols=None) == {'coeff': 2*x + y} + # issue 4808 + n, m = symbols('n,m', commutative=False) + assert separatevars(m + n*m) == (1 + n)*m + assert separatevars(x + x*n) == x*(1 + n) + # issue 4910 + f = Function('f') + assert separatevars(f(x) + x*f(x)) == f(x) + x*f(x) + # a noncommutable object present + eq = x*(1 + hyper((), (), y*z)) + assert separatevars(eq) == eq + + s = separatevars(abs(x*y)) + assert s == abs(x)*abs(y) and s.is_Mul + z = cos(1)**2 + sin(1)**2 - 1 + a = abs(x*z) + s = separatevars(a) + assert not a.is_Mul and s.is_Mul and s == abs(x)*abs(z) + s = separatevars(abs(x*y*z)) + assert s == abs(x)*abs(y)*abs(z) + + # abs(x+y)/abs(z) would be better but we test this here to + # see that it doesn't raise + assert separatevars(abs((x+y)/z)) == abs((x+y)/z) + + +def test_separatevars_advanced_factor(): + x, y, z = symbols('x,y,z') + assert separatevars(1 + log(x)*log(y) + log(x) + log(y)) == \ + (log(x) + 1)*(log(y) + 1) + assert separatevars(1 + x - log(z) - x*log(z) - exp(y)*log(z) - + x*exp(y)*log(z) + x*exp(y) + exp(y)) == \ + -((x + 1)*(log(z) - 1)*(exp(y) + 1)) + x, y = symbols('x,y', positive=True) + assert separatevars(1 + log(x**log(y)) + log(x*y)) == \ + (log(x) + 1)*(log(y) + 1) + + +def test_hypersimp(): + n, k = symbols('n,k', integer=True) + + assert hypersimp(factorial(k), k) == k + 1 + assert hypersimp(factorial(k**2), k) is None + + assert hypersimp(1/factorial(k), k) == 1/(k + 1) + + assert hypersimp(2**k/factorial(k)**2, k) == 2/(k + 1)**2 + + assert hypersimp(binomial(n, k), k) == (n - k)/(k + 1) + assert hypersimp(binomial(n + 1, k), k) == (n - k + 1)/(k + 1) + + term = (4*k + 1)*factorial(k)/factorial(2*k + 1) + assert hypersimp(term, k) == S.Half*((4*k + 5)/(3 + 14*k + 8*k**2)) + + term = 1/((2*k - 1)*factorial(2*k + 1)) + assert hypersimp(term, k) == (k - S.Half)/((k + 1)*(2*k + 1)*(2*k + 3)) + + term = binomial(n, k)*(-1)**k/factorial(k) + assert hypersimp(term, k) == (k - n)/(k + 1)**2 + + +def test_nsimplify(): + x = Symbol("x") + assert nsimplify(0) == 0 + assert nsimplify(-1) == -1 + assert nsimplify(1) == 1 + assert nsimplify(1 + x) == 1 + x + assert nsimplify(2.7) == Rational(27, 10) + assert nsimplify(1 - GoldenRatio) == (1 - sqrt(5))/2 + assert nsimplify((1 + sqrt(5))/4, [GoldenRatio]) == GoldenRatio/2 + assert nsimplify(2/GoldenRatio, [GoldenRatio]) == 2*GoldenRatio - 2 + assert nsimplify(exp(pi*I*Rational(5, 3), evaluate=False)) == \ + sympify('1/2 - sqrt(3)*I/2') + assert nsimplify(sin(pi*Rational(3, 5), evaluate=False)) == \ + sympify('sqrt(sqrt(5)/8 + 5/8)') + assert nsimplify(sqrt(atan('1', evaluate=False))*(2 + I), [pi]) == \ + sqrt(pi) + sqrt(pi)/2*I + assert nsimplify(2 + exp(2*atan('1/4')*I)) == sympify('49/17 + 8*I/17') + assert nsimplify(pi, tolerance=0.01) == Rational(22, 7) + assert nsimplify(pi, tolerance=0.001) == Rational(355, 113) + assert nsimplify(0.33333, tolerance=1e-4) == Rational(1, 3) + assert nsimplify(2.0**(1/3.), tolerance=0.001) == Rational(635, 504) + assert nsimplify(2.0**(1/3.), tolerance=0.001, full=True) == \ + 2**Rational(1, 3) + assert nsimplify(x + .5, rational=True) == S.Half + x + assert nsimplify(1/.3 + x, rational=True) == Rational(10, 3) + x + assert nsimplify(log(3).n(), rational=True) == \ + sympify('109861228866811/100000000000000') + assert nsimplify(Float(0.272198261287950), [pi, log(2)]) == pi*log(2)/8 + assert nsimplify(Float(0.272198261287950).n(3), [pi, log(2)]) == \ + -pi/4 - log(2) + Rational(7, 4) + assert nsimplify(x/7.0) == x/7 + assert nsimplify(pi/1e2) == pi/100 + assert nsimplify(pi/1e2, rational=False) == pi/100.0 + assert nsimplify(pi/1e-7) == 10000000*pi + assert not nsimplify( + factor(-3.0*z**2*(z**2)**(-2.5) + 3*(z**2)**(-1.5))).atoms(Float) + e = x**0.0 + assert e.is_Pow and nsimplify(x**0.0) == 1 + assert nsimplify(3.333333, tolerance=0.1, rational=True) == Rational(10, 3) + assert nsimplify(3.333333, tolerance=0.01, rational=True) == Rational(10, 3) + assert nsimplify(3.666666, tolerance=0.1, rational=True) == Rational(11, 3) + assert nsimplify(3.666666, tolerance=0.01, rational=True) == Rational(11, 3) + assert nsimplify(33, tolerance=10, rational=True) == Rational(33) + assert nsimplify(33.33, tolerance=10, rational=True) == Rational(30) + assert nsimplify(37.76, tolerance=10, rational=True) == Rational(40) + assert nsimplify(-203.1) == Rational(-2031, 10) + assert nsimplify(.2, tolerance=0) == Rational(1, 5) + assert nsimplify(-.2, tolerance=0) == Rational(-1, 5) + assert nsimplify(.2222, tolerance=0) == Rational(1111, 5000) + assert nsimplify(-.2222, tolerance=0) == Rational(-1111, 5000) + # issue 7211, PR 4112 + assert nsimplify(S(2e-8)) == Rational(1, 50000000) + # issue 7322 direct test + assert nsimplify(1e-42, rational=True) != 0 + # issue 10336 + inf = Float('inf') + infs = (-oo, oo, inf, -inf) + for zi in infs: + ans = sign(zi)*oo + assert nsimplify(zi) == ans + assert nsimplify(zi + x) == x + ans + + assert nsimplify(0.33333333, rational=True, rational_conversion='exact') == Rational(0.33333333) + + # Make sure nsimplify on expressions uses full precision + assert nsimplify(pi.evalf(100)*x, rational_conversion='exact').evalf(100) == pi.evalf(100)*x + + +def test_issue_9448(): + tmp = sympify("1/(1 - (-1)**(2/3) - (-1)**(1/3)) + 1/(1 + (-1)**(2/3) + (-1)**(1/3))") + assert nsimplify(tmp) == S.Half + + +def test_extract_minus_sign(): + x = Symbol("x") + y = Symbol("y") + a = Symbol("a") + b = Symbol("b") + assert simplify(-x/-y) == x/y + assert simplify(-x/y) == -x/y + assert simplify(x/y) == x/y + assert simplify(x/-y) == -x/y + assert simplify(-x/0) == zoo*x + assert simplify(Rational(-5, 0)) is zoo + assert simplify(-a*x/(-y - b)) == a*x/(b + y) + + +def test_diff(): + x = Symbol("x") + y = Symbol("y") + f = Function("f") + g = Function("g") + assert simplify(g(x).diff(x)*f(x).diff(x) - f(x).diff(x)*g(x).diff(x)) == 0 + assert simplify(2*f(x)*f(x).diff(x) - diff(f(x)**2, x)) == 0 + assert simplify(diff(1/f(x), x) + f(x).diff(x)/f(x)**2) == 0 + assert simplify(f(x).diff(x, y) - f(x).diff(y, x)) == 0 + + +def test_logcombine_1(): + x, y = symbols("x,y") + a = Symbol("a") + z, w = symbols("z,w", positive=True) + b = Symbol("b", real=True) + assert logcombine(log(x) + 2*log(y)) == log(x) + 2*log(y) + assert logcombine(log(x) + 2*log(y), force=True) == log(x*y**2) + assert logcombine(a*log(w) + log(z)) == a*log(w) + log(z) + assert logcombine(b*log(z) + b*log(x)) == log(z**b) + b*log(x) + assert logcombine(b*log(z) - log(w)) == log(z**b/w) + assert logcombine(log(x)*log(z)) == log(x)*log(z) + assert logcombine(log(w)*log(x)) == log(w)*log(x) + assert logcombine(cos(-2*log(z) + b*log(w))) in [cos(log(w**b/z**2)), + cos(log(z**2/w**b))] + assert logcombine(log(log(x) - log(y)) - log(z), force=True) == \ + log(log(x/y)/z) + assert logcombine((2 + I)*log(x), force=True) == (2 + I)*log(x) + assert logcombine((x**2 + log(x) - log(y))/(x*y), force=True) == \ + (x**2 + log(x/y))/(x*y) + # the following could also give log(z*x**log(y**2)), what we + # are testing is that a canonical result is obtained + assert logcombine(log(x)*2*log(y) + log(z), force=True) == \ + log(z*y**log(x**2)) + assert logcombine((x*y + sqrt(x**4 + y**4) + log(x) - log(y))/(pi*x**Rational(2, 3)* + sqrt(y)**3), force=True) == ( + x*y + sqrt(x**4 + y**4) + log(x/y))/(pi*x**Rational(2, 3)*y**Rational(3, 2)) + assert logcombine(gamma(-log(x/y))*acos(-log(x/y)), force=True) == \ + acos(-log(x/y))*gamma(-log(x/y)) + + assert logcombine(2*log(z)*log(w)*log(x) + log(z) + log(w)) == \ + log(z**log(w**2))*log(x) + log(w*z) + assert logcombine(3*log(w) + 3*log(z)) == log(w**3*z**3) + assert logcombine(x*(y + 1) + log(2) + log(3)) == x*(y + 1) + log(6) + assert logcombine((x + y)*log(w) + (-x - y)*log(3)) == (x + y)*log(w/3) + # a single unknown can combine + assert logcombine(log(x) + log(2)) == log(2*x) + eq = log(abs(x)) + log(abs(y)) + assert logcombine(eq) == eq + reps = {x: 0, y: 0} + assert log(abs(x)*abs(y)).subs(reps) != eq.subs(reps) + + +def test_logcombine_complex_coeff(): + i = Integral((sin(x**2) + cos(x**3))/x, x) + assert logcombine(i, force=True) == i + assert logcombine(i + 2*log(x), force=True) == \ + i + log(x**2) + + +def test_issue_5950(): + x, y = symbols("x,y", positive=True) + assert logcombine(log(3) - log(2)) == log(Rational(3,2), evaluate=False) + assert logcombine(log(x) - log(y)) == log(x/y) + assert logcombine(log(Rational(3,2), evaluate=False) - log(2)) == \ + log(Rational(3,4), evaluate=False) + + +def test_posify(): + x = symbols('x') + + assert str(posify( + x + + Symbol('p', positive=True) + + Symbol('n', negative=True))) == '(_x + n + p, {_x: x})' + + eq, rep = posify(1/x) + assert log(eq).expand().subs(rep) == -log(x) + assert str(posify([x, 1 + x])) == '([_x, _x + 1], {_x: x})' + + p = symbols('p', positive=True) + n = symbols('n', negative=True) + orig = [x, n, p] + modified, reps = posify(orig) + assert str(modified) == '[_x, n, p]' + assert [w.subs(reps) for w in modified] == orig + + assert str(Integral(posify(1/x + y)[0], (y, 1, 3)).expand()) == \ + 'Integral(1/_x, (y, 1, 3)) + Integral(_y, (y, 1, 3))' + assert str(Sum(posify(1/x**n)[0], (n,1,3)).expand()) == \ + 'Sum(_x**(-n), (n, 1, 3))' + + A = Matrix([[1, 2, 3], [4, 5, 6 * Abs(x)]]) + Ap, rep = posify(A) + assert Ap == A.subs(*reversed(rep.popitem())) + + # issue 16438 + k = Symbol('k', finite=True) + eq, rep = posify(k) + assert eq.assumptions0 == {'positive': True, 'zero': False, 'imaginary': False, + 'nonpositive': False, 'commutative': True, 'hermitian': True, 'real': True, 'nonzero': True, + 'nonnegative': True, 'negative': False, 'complex': True, 'finite': True, + 'infinite': False, 'extended_real':True, 'extended_negative': False, + 'extended_nonnegative': True, 'extended_nonpositive': False, + 'extended_nonzero': True, 'extended_positive': True} + + +def test_issue_4194(): + # simplify should call cancel + f = Function('f') + assert simplify((4*x + 6*f(y))/(2*x + 3*f(y))) == 2 + + +@XFAIL +def test_simplify_float_vs_integer(): + # Test for issue 4473: + # https://github.com/sympy/sympy/issues/4473 + assert simplify(x**2.0 - x**2) == 0 + assert simplify(x**2 - x**2.0) == 0 + + +def test_as_content_primitive(): + assert (x/2 + y).as_content_primitive() == (S.Half, x + 2*y) + assert (x/2 + y).as_content_primitive(clear=False) == (S.One, x/2 + y) + assert (y*(x/2 + y)).as_content_primitive() == (S.Half, y*(x + 2*y)) + assert (y*(x/2 + y)).as_content_primitive(clear=False) == (S.One, y*(x/2 + y)) + + # although the _as_content_primitive methods do not alter the underlying structure, + # the as_content_primitive function will touch up the expression and join + # bases that would otherwise have not been joined. + assert (x*(2 + 2*x)*(3*x + 3)**2).as_content_primitive() == \ + (18, x*(x + 1)**3) + assert (2 + 2*x + 2*y*(3 + 3*y)).as_content_primitive() == \ + (2, x + 3*y*(y + 1) + 1) + assert ((2 + 6*x)**2).as_content_primitive() == \ + (4, (3*x + 1)**2) + assert ((2 + 6*x)**(2*y)).as_content_primitive() == \ + (1, (_keep_coeff(S(2), (3*x + 1)))**(2*y)) + assert (5 + 10*x + 2*y*(3 + 3*y)).as_content_primitive() == \ + (1, 10*x + 6*y*(y + 1) + 5) + assert (5*(x*(1 + y)) + 2*x*(3 + 3*y)).as_content_primitive() == \ + (11, x*(y + 1)) + assert ((5*(x*(1 + y)) + 2*x*(3 + 3*y))**2).as_content_primitive() == \ + (121, x**2*(y + 1)**2) + assert (y**2).as_content_primitive() == \ + (1, y**2) + assert (S.Infinity).as_content_primitive() == (1, oo) + eq = x**(2 + y) + assert (eq).as_content_primitive() == (1, eq) + assert (S.Half**(2 + x)).as_content_primitive() == (Rational(1, 4), 2**-x) + assert (Rational(-1, 2)**(2 + x)).as_content_primitive() == \ + (Rational(1, 4), (Rational(-1, 2))**x) + assert (Rational(-1, 2)**(2 + x)).as_content_primitive() == \ + (Rational(1, 4), Rational(-1, 2)**x) + assert (4**((1 + y)/2)).as_content_primitive() == (2, 4**(y/2)) + assert (3**((1 + y)/2)).as_content_primitive() == \ + (1, 3**(Mul(S.Half, 1 + y, evaluate=False))) + assert (5**Rational(3, 4)).as_content_primitive() == (1, 5**Rational(3, 4)) + assert (5**Rational(7, 4)).as_content_primitive() == (5, 5**Rational(3, 4)) + assert Add(z*Rational(5, 7), 0.5*x, y*Rational(3, 2), evaluate=False).as_content_primitive() == \ + (Rational(1, 14), 7.0*x + 21*y + 10*z) + assert (2**Rational(3, 4) + 2**Rational(1, 4)*sqrt(3)).as_content_primitive(radical=True) == \ + (1, 2**Rational(1, 4)*(sqrt(2) + sqrt(3))) + + +def test_signsimp(): + e = x*(-x + 1) + x*(x - 1) + assert signsimp(Eq(e, 0)) is S.true + assert Abs(x - 1) == Abs(1 - x) + assert signsimp(y - x) == y - x + assert signsimp(y - x, evaluate=False) == Mul(-1, x - y, evaluate=False) + + +def test_besselsimp(): + from sympy.functions.special.bessel import (besseli, besselj, bessely) + from sympy.integrals.transforms import cosine_transform + assert besselsimp(exp(-I*pi*y/2)*besseli(y, z*exp_polar(I*pi/2))) == \ + besselj(y, z) + assert besselsimp(exp(-I*pi*a/2)*besseli(a, 2*sqrt(x)*exp_polar(I*pi/2))) == \ + besselj(a, 2*sqrt(x)) + assert besselsimp(sqrt(2)*sqrt(pi)*x**Rational(1, 4)*exp(I*pi/4)*exp(-I*pi*a/2) * + besseli(Rational(-1, 2), sqrt(x)*exp_polar(I*pi/2)) * + besseli(a, sqrt(x)*exp_polar(I*pi/2))/2) == \ + besselj(a, sqrt(x)) * cos(sqrt(x)) + assert besselsimp(besseli(Rational(-1, 2), z)) == \ + sqrt(2)*cosh(z)/(sqrt(pi)*sqrt(z)) + assert besselsimp(besseli(a, z*exp_polar(-I*pi/2))) == \ + exp(-I*pi*a/2)*besselj(a, z) + assert cosine_transform(1/t*sin(a/t), t, y) == \ + sqrt(2)*sqrt(pi)*besselj(0, 2*sqrt(a)*sqrt(y))/2 + + assert besselsimp(x**2*(a*(-2*besselj(5*I, x) + besselj(-2 + 5*I, x) + + besselj(2 + 5*I, x)) + b*(-2*bessely(5*I, x) + bessely(-2 + 5*I, x) + + bessely(2 + 5*I, x)))/4 + x*(a*(besselj(-1 + 5*I, x)/2 - besselj(1 + 5*I, x)/2) + + b*(bessely(-1 + 5*I, x)/2 - bessely(1 + 5*I, x)/2)) + (x**2 + 25)*(a*besselj(5*I, x) + + b*bessely(5*I, x))) == 0 + + assert besselsimp(81*x**2*(a*(besselj(Rational(-5, 3), 9*x) - 2*besselj(Rational(1, 3), 9*x) + besselj(Rational(7, 3), 9*x)) + + b*(bessely(Rational(-5, 3), 9*x) - 2*bessely(Rational(1, 3), 9*x) + bessely(Rational(7, 3), 9*x)))/4 + x*(a*(9*besselj(Rational(-2, 3), 9*x)/2 + - 9*besselj(Rational(4, 3), 9*x)/2) + b*(9*bessely(Rational(-2, 3), 9*x)/2 - 9*bessely(Rational(4, 3), 9*x)/2)) + + (81*x**2 - Rational(1, 9))*(a*besselj(Rational(1, 3), 9*x) + b*bessely(Rational(1, 3), 9*x))) == 0 + + assert besselsimp(besselj(a-1,x) + besselj(a+1, x) - 2*a*besselj(a, x)/x) == 0 + + assert besselsimp(besselj(a-1,x) + besselj(a+1, x) + besselj(a, x)) == (2*a + x)*besselj(a, x)/x + + assert besselsimp(x**2* besselj(a,x) + x**3*besselj(a+1, x) + besselj(a+2, x)) == \ + 2*a*x*besselj(a + 1, x) + x**3*besselj(a + 1, x) - x**2*besselj(a + 2, x) + 2*x*besselj(a + 1, x) + besselj(a + 2, x) + +def test_Piecewise(): + e1 = x*(x + y) - y*(x + y) + e2 = sin(x)**2 + cos(x)**2 + e3 = expand((x + y)*y/x) + s1 = simplify(e1) + s2 = simplify(e2) + s3 = simplify(e3) + assert simplify(Piecewise((e1, x < e2), (e3, True))) == \ + Piecewise((s1, x < s2), (s3, True)) + + +def test_polymorphism(): + class A(Basic): + def _eval_simplify(x, **kwargs): + return S.One + + a = A(S(5), S(2)) + assert simplify(a) == 1 + + +def test_issue_from_PR1599(): + n1, n2, n3, n4 = symbols('n1 n2 n3 n4', negative=True) + assert simplify(I*sqrt(n1)) == -sqrt(-n1) + + +def test_issue_6811(): + eq = (x + 2*y)*(2*x + 2) + assert simplify(eq) == (x + 1)*(x + 2*y)*2 + # reject the 2-arg Mul -- these are a headache for test writing + assert simplify(eq.expand()) == \ + 2*x**2 + 4*x*y + 2*x + 4*y + + +def test_issue_6920(): + e = [cos(x) + I*sin(x), cos(x) - I*sin(x), + cosh(x) - sinh(x), cosh(x) + sinh(x)] + ok = [exp(I*x), exp(-I*x), exp(-x), exp(x)] + # wrap in f to show that the change happens wherever ei occurs + f = Function('f') + assert [simplify(f(ei)).args[0] for ei in e] == ok + + +def test_issue_7001(): + from sympy.abc import r, R + assert simplify(-(r*Piecewise((pi*Rational(4, 3), r <= R), + (-8*pi*R**3/(3*r**3), True)) + 2*Piecewise((pi*r*Rational(4, 3), r <= R), + (4*pi*R**3/(3*r**2), True)))/(4*pi*r)) == \ + Piecewise((-1, r <= R), (0, True)) + + +def test_inequality_no_auto_simplify(): + # no simplify on creation but can be simplified + lhs = cos(x)**2 + sin(x)**2 + rhs = 2 + e = Lt(lhs, rhs, evaluate=False) + assert e is not S.true + assert simplify(e) + + +def test_issue_9398(): + from sympy.core.numbers import Number + from sympy.polys.polytools import cancel + assert cancel(1e-14) != 0 + assert cancel(1e-14*I) != 0 + + assert simplify(1e-14) != 0 + assert simplify(1e-14*I) != 0 + + assert (I*Number(1.)*Number(10)**Number(-14)).simplify() != 0 + + assert cancel(1e-20) != 0 + assert cancel(1e-20*I) != 0 + + assert simplify(1e-20) != 0 + assert simplify(1e-20*I) != 0 + + assert cancel(1e-100) != 0 + assert cancel(1e-100*I) != 0 + + assert simplify(1e-100) != 0 + assert simplify(1e-100*I) != 0 + + f = Float("1e-1000") + assert cancel(f) != 0 + assert cancel(f*I) != 0 + + assert simplify(f) != 0 + assert simplify(f*I) != 0 + + +def test_issue_9324_simplify(): + M = MatrixSymbol('M', 10, 10) + e = M[0, 0] + M[5, 4] + 1304 + assert simplify(e) == e + + +def test_issue_9817_simplify(): + # simplify on trace of substituted explicit quadratic form of matrix + # expressions (a scalar) should return without errors (AttributeError) + # See issue #9817 and #9190 for the original bug more discussion on this + from sympy.matrices.expressions import Identity, trace + v = MatrixSymbol('v', 3, 1) + A = MatrixSymbol('A', 3, 3) + x = Matrix([i + 1 for i in range(3)]) + X = Identity(3) + quadratic = v.T * A * v + assert simplify((trace(quadratic.as_explicit())).xreplace({v:x, A:X})) == 14 + + +def test_issue_13474(): + x = Symbol('x') + assert simplify(x + csch(sinc(1))) == x + csch(sinc(1)) + + +@_both_exp_pow +def test_simplify_function_inverse(): + # "inverse" attribute does not guarantee that f(g(x)) is x + # so this simplification should not happen automatically. + # See issue #12140 + x, y = symbols('x, y') + g = Function('g') + + class f(Function): + def inverse(self, argindex=1): + return g + + assert simplify(f(g(x))) == f(g(x)) + assert inversecombine(f(g(x))) == x + assert simplify(f(g(x)), inverse=True) == x + assert simplify(f(g(sin(x)**2 + cos(x)**2)), inverse=True) == 1 + assert simplify(f(g(x, y)), inverse=True) == f(g(x, y)) + assert unchanged(asin, sin(x)) + assert simplify(asin(sin(x))) == asin(sin(x)) + assert simplify(2*asin(sin(3*x)), inverse=True) == 6*x + assert simplify(log(exp(x))) == log(exp(x)) + assert simplify(log(exp(x)), inverse=True) == x + assert simplify(exp(log(x)), inverse=True) == x + assert simplify(log(exp(x), 2), inverse=True) == x/log(2) + assert simplify(log(exp(x), 2, evaluate=False), inverse=True) == x/log(2) + + +def test_clear_coefficients(): + from sympy.simplify.simplify import clear_coefficients + assert clear_coefficients(4*y*(6*x + 3)) == (y*(2*x + 1), 0) + assert clear_coefficients(4*y*(6*x + 3) - 2) == (y*(2*x + 1), Rational(1, 6)) + assert clear_coefficients(4*y*(6*x + 3) - 2, x) == (y*(2*x + 1), x/12 + Rational(1, 6)) + assert clear_coefficients(sqrt(2) - 2) == (sqrt(2), 2) + assert clear_coefficients(4*sqrt(2) - 2) == (sqrt(2), S.Half) + assert clear_coefficients(S(3), x) == (0, x - 3) + assert clear_coefficients(S.Infinity, x) == (S.Infinity, x) + assert clear_coefficients(-S.Pi, x) == (S.Pi, -x) + assert clear_coefficients(2 - S.Pi/3, x) == (pi, -3*x + 6) + +def test_nc_simplify(): + from sympy.simplify.simplify import nc_simplify + from sympy.matrices.expressions import MatPow, Identity + from sympy.core import Pow + from functools import reduce + + a, b, c, d = symbols('a b c d', commutative = False) + x = Symbol('x') + A = MatrixSymbol("A", x, x) + B = MatrixSymbol("B", x, x) + C = MatrixSymbol("C", x, x) + D = MatrixSymbol("D", x, x) + subst = {a: A, b: B, c: C, d:D} + funcs = {Add: lambda x,y: x+y, Mul: lambda x,y: x*y } + + def _to_matrix(expr): + if expr in subst: + return subst[expr] + if isinstance(expr, Pow): + return MatPow(_to_matrix(expr.args[0]), expr.args[1]) + elif isinstance(expr, (Add, Mul)): + return reduce(funcs[expr.func],[_to_matrix(a) for a in expr.args]) + else: + return expr*Identity(x) + + def _check(expr, simplified, deep=True, matrix=True): + assert nc_simplify(expr, deep=deep) == simplified + assert expand(expr) == expand(simplified) + if matrix: + m_simp = _to_matrix(simplified).doit(inv_expand=False) + assert nc_simplify(_to_matrix(expr), deep=deep) == m_simp + + _check(a*b*a*b*a*b*c*(a*b)**3*c, ((a*b)**3*c)**2) + _check(a*b*(a*b)**-2*a*b, 1) + _check(a**2*b*a*b*a*b*(a*b)**-1, a*(a*b)**2, matrix=False) + _check(b*a*b**2*a*b**2*a*b**2, b*(a*b**2)**3) + _check(a*b*a**2*b*a**2*b*a**3, (a*b*a)**3*a**2) + _check(a**2*b*a**4*b*a**4*b*a**2, (a**2*b*a**2)**3) + _check(a**3*b*a**4*b*a**4*b*a, a**3*(b*a**4)**3*a**-3) + _check(a*b*a*b + a*b*c*x*a*b*c, (a*b)**2 + x*(a*b*c)**2) + _check(a*b*a*b*c*a*b*a*b*c, ((a*b)**2*c)**2) + _check(b**-1*a**-1*(a*b)**2, a*b) + _check(a**-1*b*c**-1, (c*b**-1*a)**-1) + expr = a**3*b*a**4*b*a**4*b*a**2*b*a**2*(b*a**2)**2*b*a**2*b*a**2 + for _ in range(10): + expr *= a*b + _check(expr, a**3*(b*a**4)**2*(b*a**2)**6*(a*b)**10) + _check((a*b*a*b)**2, (a*b*a*b)**2, deep=False) + _check(a*b*(c*d)**2, a*b*(c*d)**2) + expr = b**-1*(a**-1*b**-1 - a**-1*c*b**-1)**-1*a**-1 + assert nc_simplify(expr) == (1-c)**-1 + # commutative expressions should be returned without an error + assert nc_simplify(2*x**2) == 2*x**2 + +def test_issue_15965(): + A = Sum(z*x**y, (x, 1, a)) + anew = z*Sum(x**y, (x, 1, a)) + B = Integral(x*y, x) + bdo = x**2*y/2 + assert simplify(A + B) == anew + bdo + assert simplify(A) == anew + assert simplify(B) == bdo + assert simplify(B, doit=False) == y*Integral(x, x) + + +def test_issue_17137(): + assert simplify(cos(x)**I) == cos(x)**I + assert simplify(cos(x)**(2 + 3*I)) == cos(x)**(2 + 3*I) + + +def test_issue_21869(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + expr = And(Eq(x**2, 4), Le(x, y)) + assert expr.simplify() == expr + + expr = And(Eq(x**2, 4), Eq(x, 2)) + assert expr.simplify() == Eq(x, 2) + + expr = And(Eq(x**3, x**2), Eq(x, 1)) + assert expr.simplify() == Eq(x, 1) + + expr = And(Eq(sin(x), x**2), Eq(x, 0)) + assert expr.simplify() == Eq(x, 0) + + expr = And(Eq(x**3, x**2), Eq(x, 2)) + assert expr.simplify() == S.false + + expr = And(Eq(y, x**2), Eq(x, 1)) + assert expr.simplify() == And(Eq(y,1), Eq(x, 1)) + + expr = And(Eq(y**2, 1), Eq(y, x**2), Eq(x, 1)) + assert expr.simplify() == And(Eq(y,1), Eq(x, 1)) + + expr = And(Eq(y**2, 4), Eq(y, 2*x**2), Eq(x, 1)) + assert expr.simplify() == And(Eq(y,2), Eq(x, 1)) + + expr = And(Eq(y**2, 4), Eq(y, x**2), Eq(x, 1)) + assert expr.simplify() == S.false + + +def test_issue_7971_21740(): + z = Integral(x, (x, 1, 1)) + assert z != 0 + assert simplify(z) is S.Zero + assert simplify(S.Zero) is S.Zero + z = simplify(Float(0)) + assert z is not S.Zero and z == 0.0 + + +@slow +def test_issue_17141_slow(): + # Should not give RecursionError + assert simplify((2**acos(I+1)**2).rewrite('log')) == 2**((pi + 2*I*log(-1 + + sqrt(1 - 2*I) + I))**2/4) + + +def test_issue_17141(): + # Check that there is no RecursionError + assert simplify(x**(1 / acos(I))) == x**(2/(pi - 2*I*log(1 + sqrt(2)))) + assert simplify(acos(-I)**2*acos(I)**2) == \ + log(1 + sqrt(2))**4 + pi**2*log(1 + sqrt(2))**2/2 + pi**4/16 + assert simplify(2**acos(I)**2) == 2**((pi - 2*I*log(1 + sqrt(2)))**2/4) + p = 2**acos(I+1)**2 + assert simplify(p) == p + + +def test_simplify_kroneckerdelta(): + i, j = symbols("i j") + K = KroneckerDelta + + assert simplify(K(i, j)) == K(i, j) + assert simplify(K(0, j)) == K(0, j) + assert simplify(K(i, 0)) == K(i, 0) + + assert simplify(K(0, j).rewrite(Piecewise) * K(1, j)) == 0 + assert simplify(K(1, i) + Piecewise((1, Eq(j, 2)), (0, True))) == K(1, i) + K(2, j) + + # issue 17214 + assert simplify(K(0, j) * K(1, j)) == 0 + + n = Symbol('n', integer=True) + assert simplify(K(0, n) * K(1, n)) == 0 + + M = Matrix(4, 4, lambda i, j: K(j - i, n) if i <= j else 0) + assert simplify(M**2) == Matrix([[K(0, n), 0, K(1, n), 0], + [0, K(0, n), 0, K(1, n)], + [0, 0, K(0, n), 0], + [0, 0, 0, K(0, n)]]) + assert simplify(eye(1) * KroneckerDelta(0, n) * + KroneckerDelta(1, n)) == Matrix([[0]]) + + assert simplify(S.Infinity * KroneckerDelta(0, n) * + KroneckerDelta(1, n)) is S.NaN + + +def test_issue_17292(): + assert simplify(abs(x)/abs(x**2)) == 1/abs(x) + # this is bigger than the issue: check that deep processing works + assert simplify(5*abs((x**2 - 1)/(x - 1))) == 5*Abs(x + 1) + + +def test_issue_19822(): + expr = And(Gt(n-2, 1), Gt(n, 1)) + assert simplify(expr) == Gt(n, 3) + + +def test_issue_18645(): + expr = And(Ge(x, 3), Le(x, 3)) + assert simplify(expr) == Eq(x, 3) + expr = And(Eq(x, 3), Le(x, 3)) + assert simplify(expr) == Eq(x, 3) + + +@XFAIL +def test_issue_18642(): + i = Symbol("i", integer=True) + n = Symbol("n", integer=True) + expr = And(Eq(i, 2 * n), Le(i, 2*n -1)) + assert simplify(expr) == S.false + + +@XFAIL +def test_issue_18389(): + n = Symbol("n", integer=True) + expr = Eq(n, 0) | (n >= 1) + assert simplify(expr) == Ge(n, 0) + + +def test_issue_8373(): + x = Symbol('x', real=True) + assert simplify(Or(x < 1, x >= 1)) == S.true + + +def test_issue_7950(): + expr = And(Eq(x, 1), Eq(x, 2)) + assert simplify(expr) == S.false + + +def test_issue_22020(): + expr = I*pi/2 -oo + assert simplify(expr) == expr + # Used to throw an error + + +def test_issue_19484(): + assert simplify(sign(x) * Abs(x)) == x + + e = x + sign(x + x**3) + assert simplify(Abs(x + x**3)*e) == x**3 + x*Abs(x**3 + x) + x + + e = x**2 + sign(x**3 + 1) + assert simplify(Abs(x**3 + 1) * e) == x**3 + x**2*Abs(x**3 + 1) + 1 + + f = Function('f') + e = x + sign(x + f(x)**3) + assert simplify(Abs(x + f(x)**3) * e) == x*Abs(x + f(x)**3) + x + f(x)**3 + + +def test_issue_23543(): + # Used to give an error + x, y, z = symbols("x y z", commutative=False) + assert (x*(y + z/2)).simplify() == x*(2*y + z)/2 + + +def test_issue_11004(): + + def f(n): + return sqrt(2*pi*n) * (n/E)**n + + def m(n, k): + return f(n) / (f(n/k)**k) + + def p(n,k): + return m(n, k) / (k**n) + + N, k = symbols('N k') + half = Float('0.5', 4) + z = log(p(n, k) / p(n, k + 1)).expand(force=True) + r = simplify(z.subs(n, N).n(4)) + assert r == ( + half*k*log(k) + - half*k*log(k + 1) + + half*log(N) + - half*log(k + 1) + + Float(0.9189224, 4) + ) + + +def test_issue_19161(): + polynomial = Poly('x**2').simplify() + assert (polynomial-x**2).simplify() == 0 + + +def test_issue_22210(): + d = Symbol('d', integer=True) + expr = 2*Derivative(sin(x), (x, d)) + assert expr.simplify() == expr + + +def test_reduce_inverses_nc_pow(): + x, y = symbols("x y", commutative=True) + Z = symbols("Z", commutative=False) + assert simplify(2**Z * y**Z) == 2**Z * y**Z + assert simplify(x**Z * y**Z) == x**Z * y**Z + x, y = symbols("x y", positive=True) + assert expand((x*y)**Z) == x**Z * y**Z + assert simplify(x**Z * y**Z) == expand((x*y)**Z) + +def test_nc_recursion_coeff(): + X = symbols("X", commutative = False) + assert (2 * cos(pi/3) * X).simplify() == X + assert (2.0 * cos(pi/3) * X).simplify() == X diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_sqrtdenest.py b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_sqrtdenest.py new file mode 100644 index 0000000000000000000000000000000000000000..41c771bb2055a1199d349ae3649f33927d79313a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_sqrtdenest.py @@ -0,0 +1,204 @@ +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Integer, Rational) +from sympy.core.symbol import Symbol +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import cos +from sympy.integrals.integrals import Integral +from sympy.simplify.sqrtdenest import sqrtdenest +from sympy.simplify.sqrtdenest import ( + _subsets as subsets, _sqrt_numeric_denest) + +r2, r3, r5, r6, r7, r10, r15, r29 = [sqrt(x) for x in (2, 3, 5, 6, 7, 10, + 15, 29)] + + +def test_sqrtdenest(): + d = {sqrt(5 + 2 * r6): r2 + r3, + sqrt(5. + 2 * r6): sqrt(5. + 2 * r6), + sqrt(5. + 4*sqrt(5 + 2 * r6)): sqrt(5.0 + 4*r2 + 4*r3), + sqrt(r2): sqrt(r2), + sqrt(5 + r7): sqrt(5 + r7), + sqrt(3 + sqrt(5 + 2*r7)): + 3*r2*(5 + 2*r7)**Rational(1, 4)/(2*sqrt(6 + 3*r7)) + + r2*sqrt(6 + 3*r7)/(2*(5 + 2*r7)**Rational(1, 4)), + sqrt(3 + 2*r3): 3**Rational(3, 4)*(r6/2 + 3*r2/2)/3} + for i in d: + assert sqrtdenest(i) == d[i], i + + +def test_sqrtdenest2(): + assert sqrtdenest(sqrt(16 - 2*r29 + 2*sqrt(55 - 10*r29))) == \ + r5 + sqrt(11 - 2*r29) + e = sqrt(-r5 + sqrt(-2*r29 + 2*sqrt(-10*r29 + 55) + 16)) + assert sqrtdenest(e) == root(-2*r29 + 11, 4) + r = sqrt(1 + r7) + assert sqrtdenest(sqrt(1 + r)) == sqrt(1 + r) + e = sqrt(((1 + sqrt(1 + 2*sqrt(3 + r2 + r5)))**2).expand()) + assert sqrtdenest(e) == 1 + sqrt(1 + 2*sqrt(r2 + r5 + 3)) + + assert sqrtdenest(sqrt(5*r3 + 6*r2)) == \ + sqrt(2)*root(3, 4) + root(3, 4)**3 + + assert sqrtdenest(sqrt(((1 + r5 + sqrt(1 + r3))**2).expand())) == \ + 1 + r5 + sqrt(1 + r3) + + assert sqrtdenest(sqrt(((1 + r5 + r7 + sqrt(1 + r3))**2).expand())) == \ + 1 + sqrt(1 + r3) + r5 + r7 + + e = sqrt(((1 + cos(2) + cos(3) + sqrt(1 + r3))**2).expand()) + assert sqrtdenest(e) == cos(3) + cos(2) + 1 + sqrt(1 + r3) + + e = sqrt(-2*r10 + 2*r2*sqrt(-2*r10 + 11) + 14) + assert sqrtdenest(e) == sqrt(-2*r10 - 2*r2 + 4*r5 + 14) + + # check that the result is not more complicated than the input + z = sqrt(-2*r29 + cos(2) + 2*sqrt(-10*r29 + 55) + 16) + assert sqrtdenest(z) == z + + assert sqrtdenest(sqrt(r6 + sqrt(15))) == sqrt(r6 + sqrt(15)) + + z = sqrt(15 - 2*sqrt(31) + 2*sqrt(55 - 10*r29)) + assert sqrtdenest(z) == z + + +def test_sqrtdenest_rec(): + assert sqrtdenest(sqrt(-4*sqrt(14) - 2*r6 + 4*sqrt(21) + 33)) == \ + -r2 + r3 + 2*r7 + assert sqrtdenest(sqrt(-28*r7 - 14*r5 + 4*sqrt(35) + 82)) == \ + -7 + r5 + 2*r7 + assert sqrtdenest(sqrt(6*r2/11 + 2*sqrt(22)/11 + 6*sqrt(11)/11 + 2)) == \ + sqrt(11)*(r2 + 3 + sqrt(11))/11 + assert sqrtdenest(sqrt(468*r3 + 3024*r2 + 2912*r6 + 19735)) == \ + 9*r3 + 26 + 56*r6 + z = sqrt(-490*r3 - 98*sqrt(115) - 98*sqrt(345) - 2107) + assert sqrtdenest(z) == sqrt(-1)*(7*r5 + 7*r15 + 7*sqrt(23)) + z = sqrt(-4*sqrt(14) - 2*r6 + 4*sqrt(21) + 34) + assert sqrtdenest(z) == z + assert sqrtdenest(sqrt(-8*r2 - 2*r5 + 18)) == -r10 + 1 + r2 + r5 + assert sqrtdenest(sqrt(8*r2 + 2*r5 - 18)) == \ + sqrt(-1)*(-r10 + 1 + r2 + r5) + assert sqrtdenest(sqrt(8*r2/3 + 14*r5/3 + Rational(154, 9))) == \ + -r10/3 + r2 + r5 + 3 + assert sqrtdenest(sqrt(sqrt(2*r6 + 5) + sqrt(2*r7 + 8))) == \ + sqrt(1 + r2 + r3 + r7) + assert sqrtdenest(sqrt(4*r15 + 8*r5 + 12*r3 + 24)) == 1 + r3 + r5 + r15 + + w = 1 + r2 + r3 + r5 + r7 + assert sqrtdenest(sqrt((w**2).expand())) == w + z = sqrt((w**2).expand() + 1) + assert sqrtdenest(z) == z + + z = sqrt(2*r10 + 6*r2 + 4*r5 + 12 + 10*r15 + 30*r3) + assert sqrtdenest(z) == z + + +def test_issue_6241(): + z = sqrt( -320 + 32*sqrt(5) + 64*r15) + assert sqrtdenest(z) == z + + +def test_sqrtdenest3(): + z = sqrt(13 - 2*r10 + 2*r2*sqrt(-2*r10 + 11)) + assert sqrtdenest(z) == -1 + r2 + r10 + assert sqrtdenest(z, max_iter=1) == -1 + sqrt(2) + sqrt(10) + z = sqrt(sqrt(r2 + 2) + 2) + assert sqrtdenest(z) == z + assert sqrtdenest(sqrt(-2*r10 + 4*r2*sqrt(-2*r10 + 11) + 20)) == \ + sqrt(-2*r10 - 4*r2 + 8*r5 + 20) + assert sqrtdenest(sqrt((112 + 70*r2) + (46 + 34*r2)*r5)) == \ + r10 + 5 + 4*r2 + 3*r5 + z = sqrt(5 + sqrt(2*r6 + 5)*sqrt(-2*r29 + 2*sqrt(-10*r29 + 55) + 16)) + r = sqrt(-2*r29 + 11) + assert sqrtdenest(z) == sqrt(r2*r + r3*r + r10 + r15 + 5) + + n = sqrt(2*r6/7 + 2*r7/7 + 2*sqrt(42)/7 + 2) + d = sqrt(16 - 2*r29 + 2*sqrt(55 - 10*r29)) + assert sqrtdenest(n/d) == r7*(1 + r6 + r7)/(Mul(7, (sqrt(-2*r29 + 11) + r5), + evaluate=False)) + + +def test_sqrtdenest4(): + # see Denest_en.pdf in https://github.com/sympy/sympy/issues/3192 + z = sqrt(8 - r2*sqrt(5 - r5) - sqrt(3)*(1 + r5)) + z1 = sqrtdenest(z) + c = sqrt(-r5 + 5) + z1 = ((-r15*c - r3*c + c + r5*c - r6 - r2 + r10 + sqrt(30))/4).expand() + assert sqrtdenest(z) == z1 + + z = sqrt(2*r2*sqrt(r2 + 2) + 5*r2 + 4*sqrt(r2 + 2) + 8) + assert sqrtdenest(z) == r2 + sqrt(r2 + 2) + 2 + + w = 2 + r2 + r3 + (1 + r3)*sqrt(2 + r2 + 5*r3) + z = sqrt((w**2).expand()) + assert sqrtdenest(z) == w.expand() + + +def test_sqrt_symbolic_denest(): + x = Symbol('x') + z = sqrt(((1 + sqrt(sqrt(2 + x) + 3))**2).expand()) + assert sqrtdenest(z) == sqrt((1 + sqrt(sqrt(2 + x) + 3))**2) + z = sqrt(((1 + sqrt(sqrt(2 + cos(1)) + 3))**2).expand()) + assert sqrtdenest(z) == 1 + sqrt(sqrt(2 + cos(1)) + 3) + z = ((1 + cos(2))**4 + 1).expand() + assert sqrtdenest(z) == z + z = sqrt(((1 + sqrt(sqrt(2 + cos(3*x)) + 3))**2 + 1).expand()) + assert sqrtdenest(z) == z + c = cos(3) + c2 = c**2 + assert sqrtdenest(sqrt(2*sqrt(1 + r3)*c + c2 + 1 + r3*c2)) == \ + -1 - sqrt(1 + r3)*c + ra = sqrt(1 + r3) + z = sqrt(20*ra*sqrt(3 + 3*r3) + 12*r3*ra*sqrt(3 + 3*r3) + 64*r3 + 112) + assert sqrtdenest(z) == z + + +def test_issue_5857(): + from sympy.abc import x, y + z = sqrt(1/(4*r3 + 7) + 1) + ans = (r2 + r6)/(r3 + 2) + assert sqrtdenest(z) == ans + assert sqrtdenest(1 + z) == 1 + ans + assert sqrtdenest(Integral(z + 1, (x, 1, 2))) == \ + Integral(1 + ans, (x, 1, 2)) + assert sqrtdenest(x + sqrt(y)) == x + sqrt(y) + ans = (r2 + r6)/(r3 + 2) + assert sqrtdenest(z) == ans + assert sqrtdenest(1 + z) == 1 + ans + assert sqrtdenest(Integral(z + 1, (x, 1, 2))) == \ + Integral(1 + ans, (x, 1, 2)) + assert sqrtdenest(x + sqrt(y)) == x + sqrt(y) + + +def test_subsets(): + assert subsets(1) == [[1]] + assert subsets(4) == [ + [1, 0, 0, 0], [0, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 0], [1, 0, 1, 0], + [0, 1, 1, 0], [1, 1, 1, 0], [0, 0, 0, 1], [1, 0, 0, 1], [0, 1, 0, 1], + [1, 1, 0, 1], [0, 0, 1, 1], [1, 0, 1, 1], [0, 1, 1, 1], [1, 1, 1, 1]] + + +def test_issue_5653(): + assert sqrtdenest( + sqrt(2 + sqrt(2 + sqrt(2)))) == sqrt(2 + sqrt(2 + sqrt(2))) + +def test_issue_12420(): + assert sqrtdenest((3 - sqrt(2)*sqrt(4 + 3*I) + 3*I)/2) == I + e = 3 - sqrt(2)*sqrt(4 + I) + 3*I + assert sqrtdenest(e) == e + +def test_sqrt_ratcomb(): + assert sqrtdenest(sqrt(1 + r3) + sqrt(3 + 3*r3) - sqrt(10 + 6*r3)) == 0 + +def test_issue_18041(): + e = -sqrt(-2 + 2*sqrt(3)*I) + assert sqrtdenest(e) == -1 - sqrt(3)*I + +def test_issue_19914(): + a = Integer(-8) + b = Integer(-1) + r = Integer(63) + d2 = a*a - b*b*r + + assert _sqrt_numeric_denest(a, b, r, d2) == \ + sqrt(14)*I/2 + 3*sqrt(2)*I/2 + assert sqrtdenest(sqrt(-8-sqrt(63))) == sqrt(14)*I/2 + 3*sqrt(2)*I/2 diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_trigsimp.py b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_trigsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..ea091ec8a6c7d654405968e3d035c2bbe02ccdf7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/tests/test_trigsimp.py @@ -0,0 +1,520 @@ +from itertools import product +from sympy.core.function import (Subs, count_ops, diff, expand) +from sympy.core.numbers import (E, I, Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import (cosh, coth, sinh, tanh) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (cos, cot, sin, tan) +from sympy.functions.elementary.trigonometric import (acos, asin, atan2) +from sympy.functions.elementary.trigonometric import (asec, acsc) +from sympy.functions.elementary.trigonometric import (acot, atan) +from sympy.integrals.integrals import integrate +from sympy.matrices.dense import Matrix +from sympy.simplify.simplify import simplify +from sympy.simplify.trigsimp import (exptrigsimp, trigsimp) + +from sympy.testing.pytest import XFAIL + +from sympy.abc import x, y + + + +def test_trigsimp1(): + x, y = symbols('x,y') + + assert trigsimp(1 - sin(x)**2) == cos(x)**2 + assert trigsimp(1 - cos(x)**2) == sin(x)**2 + assert trigsimp(sin(x)**2 + cos(x)**2) == 1 + assert trigsimp(1 + tan(x)**2) == 1/cos(x)**2 + assert trigsimp(1/cos(x)**2 - 1) == tan(x)**2 + assert trigsimp(1/cos(x)**2 - tan(x)**2) == 1 + assert trigsimp(1 + cot(x)**2) == 1/sin(x)**2 + assert trigsimp(1/sin(x)**2 - 1) == 1/tan(x)**2 + assert trigsimp(1/sin(x)**2 - cot(x)**2) == 1 + + assert trigsimp(5*cos(x)**2 + 5*sin(x)**2) == 5 + assert trigsimp(5*cos(x/2)**2 + 2*sin(x/2)**2) == 3*cos(x)/2 + Rational(7, 2) + + assert trigsimp(sin(x)/cos(x)) == tan(x) + assert trigsimp(2*tan(x)*cos(x)) == 2*sin(x) + assert trigsimp(cot(x)**3*sin(x)**3) == cos(x)**3 + assert trigsimp(y*tan(x)**2/sin(x)**2) == y/cos(x)**2 + assert trigsimp(cot(x)/cos(x)) == 1/sin(x) + + assert trigsimp(sin(x + y) + sin(x - y)) == 2*sin(x)*cos(y) + assert trigsimp(sin(x + y) - sin(x - y)) == 2*sin(y)*cos(x) + assert trigsimp(cos(x + y) + cos(x - y)) == 2*cos(x)*cos(y) + assert trigsimp(cos(x + y) - cos(x - y)) == -2*sin(x)*sin(y) + assert trigsimp(tan(x + y) - tan(x)/(1 - tan(x)*tan(y))) == \ + sin(y)/(-sin(y)*tan(x) + cos(y)) # -tan(y)/(tan(x)*tan(y) - 1) + + assert trigsimp(sinh(x + y) + sinh(x - y)) == 2*sinh(x)*cosh(y) + assert trigsimp(sinh(x + y) - sinh(x - y)) == 2*sinh(y)*cosh(x) + assert trigsimp(cosh(x + y) + cosh(x - y)) == 2*cosh(x)*cosh(y) + assert trigsimp(cosh(x + y) - cosh(x - y)) == 2*sinh(x)*sinh(y) + assert trigsimp(tanh(x + y) - tanh(x)/(1 + tanh(x)*tanh(y))) == \ + sinh(y)/(sinh(y)*tanh(x) + cosh(y)) + + assert trigsimp(cos(0.12345)**2 + sin(0.12345)**2) == 1.0 + e = 2*sin(x)**2 + 2*cos(x)**2 + assert trigsimp(log(e)) == log(2) + + +def test_trigsimp1a(): + assert trigsimp(sin(2)**2*cos(3)*exp(2)/cos(2)**2) == tan(2)**2*cos(3)*exp(2) + assert trigsimp(tan(2)**2*cos(3)*exp(2)*cos(2)**2) == sin(2)**2*cos(3)*exp(2) + assert trigsimp(cot(2)*cos(3)*exp(2)*sin(2)) == cos(3)*exp(2)*cos(2) + assert trigsimp(tan(2)*cos(3)*exp(2)/sin(2)) == cos(3)*exp(2)/cos(2) + assert trigsimp(cot(2)*cos(3)*exp(2)/cos(2)) == cos(3)*exp(2)/sin(2) + assert trigsimp(cot(2)*cos(3)*exp(2)*tan(2)) == cos(3)*exp(2) + assert trigsimp(sinh(2)*cos(3)*exp(2)/cosh(2)) == tanh(2)*cos(3)*exp(2) + assert trigsimp(tanh(2)*cos(3)*exp(2)*cosh(2)) == sinh(2)*cos(3)*exp(2) + assert trigsimp(coth(2)*cos(3)*exp(2)*sinh(2)) == cosh(2)*cos(3)*exp(2) + assert trigsimp(tanh(2)*cos(3)*exp(2)/sinh(2)) == cos(3)*exp(2)/cosh(2) + assert trigsimp(coth(2)*cos(3)*exp(2)/cosh(2)) == cos(3)*exp(2)/sinh(2) + assert trigsimp(coth(2)*cos(3)*exp(2)*tanh(2)) == cos(3)*exp(2) + + +def test_trigsimp2(): + x, y = symbols('x,y') + assert trigsimp(cos(x)**2*sin(y)**2 + cos(x)**2*cos(y)**2 + sin(x)**2, + recursive=True) == 1 + assert trigsimp(sin(x)**2*sin(y)**2 + sin(x)**2*cos(y)**2 + cos(x)**2, + recursive=True) == 1 + assert trigsimp( + Subs(x, x, sin(y)**2 + cos(y)**2)) == Subs(x, x, 1) + + +def test_issue_4373(): + x = Symbol("x") + assert abs(trigsimp(2.0*sin(x)**2 + 2.0*cos(x)**2) - 2.0) < 1e-10 + + +def test_trigsimp3(): + x, y = symbols('x,y') + assert trigsimp(sin(x)/cos(x)) == tan(x) + assert trigsimp(sin(x)**2/cos(x)**2) == tan(x)**2 + assert trigsimp(sin(x)**3/cos(x)**3) == tan(x)**3 + assert trigsimp(sin(x)**10/cos(x)**10) == tan(x)**10 + + assert trigsimp(cos(x)/sin(x)) == 1/tan(x) + assert trigsimp(cos(x)**2/sin(x)**2) == 1/tan(x)**2 + assert trigsimp(cos(x)**10/sin(x)**10) == 1/tan(x)**10 + + assert trigsimp(tan(x)) == trigsimp(sin(x)/cos(x)) + + +def test_issue_4661(): + a, x, y = symbols('a x y') + eq = -4*sin(x)**4 + 4*cos(x)**4 - 8*cos(x)**2 + assert trigsimp(eq) == -4 + n = sin(x)**6 + 4*sin(x)**4*cos(x)**2 + 5*sin(x)**2*cos(x)**4 + 2*cos(x)**6 + d = -sin(x)**2 - 2*cos(x)**2 + assert simplify(n/d) == -1 + assert trigsimp(-2*cos(x)**2 + cos(x)**4 - sin(x)**4) == -1 + eq = (- sin(x)**3/4)*cos(x) + (cos(x)**3/4)*sin(x) - sin(2*x)*cos(2*x)/8 + assert trigsimp(eq) == 0 + + +def test_issue_4494(): + a, b = symbols('a b') + eq = sin(a)**2*sin(b)**2 + cos(a)**2*cos(b)**2*tan(a)**2 + cos(a)**2 + assert trigsimp(eq) == 1 + + +def test_issue_5948(): + a, x, y = symbols('a x y') + assert trigsimp(diff(integrate(cos(x)/sin(x)**7, x), x)) == \ + cos(x)/sin(x)**7 + + +def test_issue_4775(): + a, x, y = symbols('a x y') + assert trigsimp(sin(x)*cos(y)+cos(x)*sin(y)) == sin(x + y) + assert trigsimp(sin(x)*cos(y)+cos(x)*sin(y)+3) == sin(x + y) + 3 + + +def test_issue_4280(): + a, x, y = symbols('a x y') + assert trigsimp(cos(x)**2 + cos(y)**2*sin(x)**2 + sin(y)**2*sin(x)**2) == 1 + assert trigsimp(a**2*sin(x)**2 + a**2*cos(y)**2*cos(x)**2 + a**2*cos(x)**2*sin(y)**2) == a**2 + assert trigsimp(a**2*cos(y)**2*sin(x)**2 + a**2*sin(y)**2*sin(x)**2) == a**2*sin(x)**2 + + +def test_issue_3210(): + eqs = (sin(2)*cos(3) + sin(3)*cos(2), + -sin(2)*sin(3) + cos(2)*cos(3), + sin(2)*cos(3) - sin(3)*cos(2), + sin(2)*sin(3) + cos(2)*cos(3), + sin(2)*sin(3) + cos(2)*cos(3) + cos(2), + sinh(2)*cosh(3) + sinh(3)*cosh(2), + sinh(2)*sinh(3) + cosh(2)*cosh(3), + ) + assert [trigsimp(e) for e in eqs] == [ + sin(5), + cos(5), + -sin(1), + cos(1), + cos(1) + cos(2), + sinh(5), + cosh(5), + ] + + +def test_trigsimp_issues(): + a, x, y = symbols('a x y') + + # issue 4625 - factor_terms works, too + assert trigsimp(sin(x)**3 + cos(x)**2*sin(x)) == sin(x) + + # issue 5948 + assert trigsimp(diff(integrate(cos(x)/sin(x)**3, x), x)) == \ + cos(x)/sin(x)**3 + assert trigsimp(diff(integrate(sin(x)/cos(x)**3, x), x)) == \ + sin(x)/cos(x)**3 + + # check integer exponents + e = sin(x)**y/cos(x)**y + assert trigsimp(e) == e + assert trigsimp(e.subs(y, 2)) == tan(x)**2 + assert trigsimp(e.subs(x, 1)) == tan(1)**y + + # check for multiple patterns + assert (cos(x)**2/sin(x)**2*cos(y)**2/sin(y)**2).trigsimp() == \ + 1/tan(x)**2/tan(y)**2 + assert trigsimp(cos(x)/sin(x)*cos(x+y)/sin(x+y)) == \ + 1/(tan(x)*tan(x + y)) + + eq = cos(2)*(cos(3) + 1)**2/(cos(3) - 1)**2 + assert trigsimp(eq) == eq.factor() # factor makes denom (-1 + cos(3))**2 + assert trigsimp(cos(2)*(cos(3) + 1)**2*(cos(3) - 1)**2) == \ + cos(2)*sin(3)**4 + + # issue 6789; this generates an expression that formerly caused + # trigsimp to hang + assert cot(x).equals(tan(x)) is False + + # nan or the unchanged expression is ok, but not sin(1) + z = cos(x)**2 + sin(x)**2 - 1 + z1 = tan(x)**2 - 1/cot(x)**2 + n = (1 + z1/z) + assert trigsimp(sin(n)) != sin(1) + eq = x*(n - 1) - x*n + assert trigsimp(eq) is S.NaN + assert trigsimp(eq, recursive=True) is S.NaN + assert trigsimp(1).is_Integer + + assert trigsimp(-sin(x)**4 - 2*sin(x)**2*cos(x)**2 - cos(x)**4) == -1 + + +def test_trigsimp_issue_2515(): + x = Symbol('x') + assert trigsimp(x*cos(x)*tan(x)) == x*sin(x) + assert trigsimp(-sin(x) + cos(x)*tan(x)) == 0 + + +def test_trigsimp_issue_3826(): + assert trigsimp(tan(2*x).expand(trig=True)) == tan(2*x) + + +def test_trigsimp_issue_4032(): + n = Symbol('n', integer=True, positive=True) + assert trigsimp(2**(n/2)*cos(pi*n/4)/2 + 2**(n - 1)/2) == \ + 2**(n/2)*cos(pi*n/4)/2 + 2**n/4 + + +def test_trigsimp_issue_7761(): + assert trigsimp(cosh(pi/4)) == cosh(pi/4) + + +def test_trigsimp_noncommutative(): + x, y = symbols('x,y') + A, B = symbols('A,B', commutative=False) + + assert trigsimp(A - A*sin(x)**2) == A*cos(x)**2 + assert trigsimp(A - A*cos(x)**2) == A*sin(x)**2 + assert trigsimp(A*sin(x)**2 + A*cos(x)**2) == A + assert trigsimp(A + A*tan(x)**2) == A/cos(x)**2 + assert trigsimp(A/cos(x)**2 - A) == A*tan(x)**2 + assert trigsimp(A/cos(x)**2 - A*tan(x)**2) == A + assert trigsimp(A + A*cot(x)**2) == A/sin(x)**2 + assert trigsimp(A/sin(x)**2 - A) == A/tan(x)**2 + assert trigsimp(A/sin(x)**2 - A*cot(x)**2) == A + + assert trigsimp(y*A*cos(x)**2 + y*A*sin(x)**2) == y*A + + assert trigsimp(A*sin(x)/cos(x)) == A*tan(x) + assert trigsimp(A*tan(x)*cos(x)) == A*sin(x) + assert trigsimp(A*cot(x)**3*sin(x)**3) == A*cos(x)**3 + assert trigsimp(y*A*tan(x)**2/sin(x)**2) == y*A/cos(x)**2 + assert trigsimp(A*cot(x)/cos(x)) == A/sin(x) + + assert trigsimp(A*sin(x + y) + A*sin(x - y)) == 2*A*sin(x)*cos(y) + assert trigsimp(A*sin(x + y) - A*sin(x - y)) == 2*A*sin(y)*cos(x) + assert trigsimp(A*cos(x + y) + A*cos(x - y)) == 2*A*cos(x)*cos(y) + assert trigsimp(A*cos(x + y) - A*cos(x - y)) == -2*A*sin(x)*sin(y) + + assert trigsimp(A*sinh(x + y) + A*sinh(x - y)) == 2*A*sinh(x)*cosh(y) + assert trigsimp(A*sinh(x + y) - A*sinh(x - y)) == 2*A*sinh(y)*cosh(x) + assert trigsimp(A*cosh(x + y) + A*cosh(x - y)) == 2*A*cosh(x)*cosh(y) + assert trigsimp(A*cosh(x + y) - A*cosh(x - y)) == 2*A*sinh(x)*sinh(y) + + assert trigsimp(A*cos(0.12345)**2 + A*sin(0.12345)**2) == 1.0*A + + +def test_hyperbolic_simp(): + x, y = symbols('x,y') + + assert trigsimp(sinh(x)**2 + 1) == cosh(x)**2 + assert trigsimp(cosh(x)**2 - 1) == sinh(x)**2 + assert trigsimp(cosh(x)**2 - sinh(x)**2) == 1 + assert trigsimp(1 - tanh(x)**2) == 1/cosh(x)**2 + assert trigsimp(1 - 1/cosh(x)**2) == tanh(x)**2 + assert trigsimp(tanh(x)**2 + 1/cosh(x)**2) == 1 + assert trigsimp(coth(x)**2 - 1) == 1/sinh(x)**2 + assert trigsimp(1/sinh(x)**2 + 1) == 1/tanh(x)**2 + assert trigsimp(coth(x)**2 - 1/sinh(x)**2) == 1 + + assert trigsimp(5*cosh(x)**2 - 5*sinh(x)**2) == 5 + assert trigsimp(5*cosh(x/2)**2 - 2*sinh(x/2)**2) == 3*cosh(x)/2 + Rational(7, 2) + + assert trigsimp(sinh(x)/cosh(x)) == tanh(x) + assert trigsimp(tanh(x)) == trigsimp(sinh(x)/cosh(x)) + assert trigsimp(cosh(x)/sinh(x)) == 1/tanh(x) + assert trigsimp(2*tanh(x)*cosh(x)) == 2*sinh(x) + assert trigsimp(coth(x)**3*sinh(x)**3) == cosh(x)**3 + assert trigsimp(y*tanh(x)**2/sinh(x)**2) == y/cosh(x)**2 + assert trigsimp(coth(x)/cosh(x)) == 1/sinh(x) + + for a in (pi/6*I, pi/4*I, pi/3*I): + assert trigsimp(sinh(a)*cosh(x) + cosh(a)*sinh(x)) == sinh(x + a) + assert trigsimp(-sinh(a)*cosh(x) + cosh(a)*sinh(x)) == sinh(x - a) + + e = 2*cosh(x)**2 - 2*sinh(x)**2 + assert trigsimp(log(e)) == log(2) + + # issue 19535: + assert trigsimp(sqrt(cosh(x)**2 - 1)) == sqrt(sinh(x)**2) + + assert trigsimp(cosh(x)**2*cosh(y)**2 - cosh(x)**2*sinh(y)**2 - sinh(x)**2, + recursive=True) == 1 + assert trigsimp(sinh(x)**2*sinh(y)**2 - sinh(x)**2*cosh(y)**2 + cosh(x)**2, + recursive=True) == 1 + + assert abs(trigsimp(2.0*cosh(x)**2 - 2.0*sinh(x)**2) - 2.0) < 1e-10 + + assert trigsimp(sinh(x)**2/cosh(x)**2) == tanh(x)**2 + assert trigsimp(sinh(x)**3/cosh(x)**3) == tanh(x)**3 + assert trigsimp(sinh(x)**10/cosh(x)**10) == tanh(x)**10 + assert trigsimp(cosh(x)**3/sinh(x)**3) == 1/tanh(x)**3 + + assert trigsimp(cosh(x)/sinh(x)) == 1/tanh(x) + assert trigsimp(cosh(x)**2/sinh(x)**2) == 1/tanh(x)**2 + assert trigsimp(cosh(x)**10/sinh(x)**10) == 1/tanh(x)**10 + + assert trigsimp(x*cosh(x)*tanh(x)) == x*sinh(x) + assert trigsimp(-sinh(x) + cosh(x)*tanh(x)) == 0 + + assert tan(x) != 1/cot(x) # cot doesn't auto-simplify + + assert trigsimp(tan(x) - 1/cot(x)) == 0 + assert trigsimp(3*tanh(x)**7 - 2/coth(x)**7) == tanh(x)**7 + + +def test_trigsimp_groebner(): + from sympy.simplify.trigsimp import trigsimp_groebner + + c = cos(x) + s = sin(x) + ex = (4*s*c + 12*s + 5*c**3 + 21*c**2 + 23*c + 15)/( + -s*c**2 + 2*s*c + 15*s + 7*c**3 + 31*c**2 + 37*c + 21) + resnum = (5*s - 5*c + 1) + resdenom = (8*s - 6*c) + results = [resnum/resdenom, (-resnum)/(-resdenom)] + assert trigsimp_groebner(ex) in results + assert trigsimp_groebner(s/c, hints=[tan]) == tan(x) + assert trigsimp_groebner(c*s) == c*s + assert trigsimp((-s + 1)/c + c/(-s + 1), + method='groebner') == 2/c + assert trigsimp((-s + 1)/c + c/(-s + 1), + method='groebner', polynomial=True) == 2/c + + # Test quick=False works + assert trigsimp_groebner(ex, hints=[2]) in results + assert trigsimp_groebner(ex, hints=[int(2)]) in results + + # test "I" + assert trigsimp_groebner(sin(I*x)/cos(I*x), hints=[tanh]) == I*tanh(x) + + # test hyperbolic / sums + assert trigsimp_groebner((tanh(x)+tanh(y))/(1+tanh(x)*tanh(y)), + hints=[(tanh, x, y)]) == tanh(x + y) + + +def test_issue_2827_trigsimp_methods(): + measure1 = lambda expr: len(str(expr)) + measure2 = lambda expr: -count_ops(expr) + # Return the most complicated result + expr = (x + 1)/(x + sin(x)**2 + cos(x)**2) + ans = Matrix([1]) + M = Matrix([expr]) + assert trigsimp(M, method='fu', measure=measure1) == ans + assert trigsimp(M, method='fu', measure=measure2) != ans + # all methods should work with Basic expressions even if they + # aren't Expr + M = Matrix.eye(1) + assert all(trigsimp(M, method=m) == M for m in + 'fu matching groebner old'.split()) + # watch for E in exptrigsimp, not only exp() + eq = 1/sqrt(E) + E + assert exptrigsimp(eq) == eq + +def test_issue_15129_trigsimp_methods(): + t1 = Matrix([sin(Rational(1, 50)), cos(Rational(1, 50)), 0]) + t2 = Matrix([sin(Rational(1, 25)), cos(Rational(1, 25)), 0]) + t3 = Matrix([cos(Rational(1, 25)), sin(Rational(1, 25)), 0]) + r1 = t1.dot(t2) + r2 = t1.dot(t3) + assert trigsimp(r1) == cos(Rational(1, 50)) + assert trigsimp(r2) == sin(Rational(3, 50)) + +def test_exptrigsimp(): + def valid(a, b): + from sympy.core.random import verify_numerically as tn + if not (tn(a, b) and a == b): + return False + return True + + assert exptrigsimp(exp(x) + exp(-x)) == 2*cosh(x) + assert exptrigsimp(exp(x) - exp(-x)) == 2*sinh(x) + assert exptrigsimp((2*exp(x)-2*exp(-x))/(exp(x)+exp(-x))) == 2*tanh(x) + assert exptrigsimp((2*exp(2*x)-2)/(exp(2*x)+1)) == 2*tanh(x) + e = [cos(x) + I*sin(x), cos(x) - I*sin(x), + cosh(x) - sinh(x), cosh(x) + sinh(x)] + ok = [exp(I*x), exp(-I*x), exp(-x), exp(x)] + assert all(valid(i, j) for i, j in zip( + [exptrigsimp(ei) for ei in e], ok)) + + ue = [cos(x) + sin(x), cos(x) - sin(x), + cosh(x) + I*sinh(x), cosh(x) - I*sinh(x)] + assert [exptrigsimp(ei) == ei for ei in ue] + + res = [] + ok = [y*tanh(1), 1/(y*tanh(1)), I*y*tan(1), -I/(y*tan(1)), + y*tanh(x), 1/(y*tanh(x)), I*y*tan(x), -I/(y*tan(x)), + y*tanh(1 + I), 1/(y*tanh(1 + I))] + for a in (1, I, x, I*x, 1 + I): + w = exp(a) + eq = y*(w - 1/w)/(w + 1/w) + res.append(simplify(eq)) + res.append(simplify(1/eq)) + assert all(valid(i, j) for i, j in zip(res, ok)) + + for a in range(1, 3): + w = exp(a) + e = w + 1/w + s = simplify(e) + assert s == exptrigsimp(e) + assert valid(s, 2*cosh(a)) + e = w - 1/w + s = simplify(e) + assert s == exptrigsimp(e) + assert valid(s, 2*sinh(a)) + +def test_exptrigsimp_noncommutative(): + a,b = symbols('a b', commutative=False) + x = Symbol('x', commutative=True) + assert exp(a + x) == exptrigsimp(exp(a)*exp(x)) + p = exp(a)*exp(b) - exp(b)*exp(a) + assert p == exptrigsimp(p) != 0 + +def test_powsimp_on_numbers(): + assert 2**(Rational(1, 3) - 2) == 2**Rational(1, 3)/4 + + +@XFAIL +def test_issue_6811_fail(): + # from doc/src/modules/physics/mechanics/examples.rst, the current `eq` + # at Line 576 (in different variables) was formerly the equivalent and + # shorter expression given below...it would be nice to get the short one + # back again + xp, y, x, z = symbols('xp, y, x, z') + eq = 4*(-19*sin(x)*y + 5*sin(3*x)*y + 15*cos(2*x)*z - 21*z)*xp/(9*cos(x) - 5*cos(3*x)) + assert trigsimp(eq) == -2*(2*cos(x)*tan(x)*y + 3*z)*xp/cos(x) + + +def test_Piecewise(): + e1 = x*(x + y) - y*(x + y) + e2 = sin(x)**2 + cos(x)**2 + e3 = expand((x + y)*y/x) + # s1 = simplify(e1) + s2 = simplify(e2) + # s3 = simplify(e3) + + # trigsimp tries not to touch non-trig containing args + assert trigsimp(Piecewise((e1, e3 < e2), (e3, True))) == \ + Piecewise((e1, e3 < s2), (e3, True)) + + +def test_issue_21594(): + assert simplify(exp(Rational(1,2)) + exp(Rational(-1,2))) == cosh(S.Half)*2 + + +def test_trigsimp_old(): + x, y = symbols('x,y') + + assert trigsimp(1 - sin(x)**2, old=True) == cos(x)**2 + assert trigsimp(1 - cos(x)**2, old=True) == sin(x)**2 + assert trigsimp(sin(x)**2 + cos(x)**2, old=True) == 1 + assert trigsimp(1 + tan(x)**2, old=True) == 1/cos(x)**2 + assert trigsimp(1/cos(x)**2 - 1, old=True) == tan(x)**2 + assert trigsimp(1/cos(x)**2 - tan(x)**2, old=True) == 1 + assert trigsimp(1 + cot(x)**2, old=True) == 1/sin(x)**2 + assert trigsimp(1/sin(x)**2 - cot(x)**2, old=True) == 1 + + assert trigsimp(5*cos(x)**2 + 5*sin(x)**2, old=True) == 5 + + assert trigsimp(sin(x)/cos(x), old=True) == tan(x) + assert trigsimp(2*tan(x)*cos(x), old=True) == 2*sin(x) + assert trigsimp(cot(x)**3*sin(x)**3, old=True) == cos(x)**3 + assert trigsimp(y*tan(x)**2/sin(x)**2, old=True) == y/cos(x)**2 + assert trigsimp(cot(x)/cos(x), old=True) == 1/sin(x) + + assert trigsimp(sin(x + y) + sin(x - y), old=True) == 2*sin(x)*cos(y) + assert trigsimp(sin(x + y) - sin(x - y), old=True) == 2*sin(y)*cos(x) + assert trigsimp(cos(x + y) + cos(x - y), old=True) == 2*cos(x)*cos(y) + assert trigsimp(cos(x + y) - cos(x - y), old=True) == -2*sin(x)*sin(y) + + assert trigsimp(sinh(x + y) + sinh(x - y), old=True) == 2*sinh(x)*cosh(y) + assert trigsimp(sinh(x + y) - sinh(x - y), old=True) == 2*sinh(y)*cosh(x) + assert trigsimp(cosh(x + y) + cosh(x - y), old=True) == 2*cosh(x)*cosh(y) + assert trigsimp(cosh(x + y) - cosh(x - y), old=True) == 2*sinh(x)*sinh(y) + + assert trigsimp(cos(0.12345)**2 + sin(0.12345)**2, old=True) == 1.0 + + assert trigsimp(sin(x)/cos(x), old=True, method='combined') == tan(x) + assert trigsimp(sin(x)/cos(x), old=True, method='groebner') == sin(x)/cos(x) + assert trigsimp(sin(x)/cos(x), old=True, method='groebner', hints=[tan]) == tan(x) + + assert trigsimp(1-sin(sin(x)**2+cos(x)**2)**2, old=True, deep=True) == cos(1)**2 + + +def test_trigsimp_inverse(): + alpha = symbols('alpha') + s, c = sin(alpha), cos(alpha) + + for finv in [asin, acos, asec, acsc, atan, acot]: + f = finv.inverse(None) + assert alpha == trigsimp(finv(f(alpha)), inverse=True) + + # test atan2(cos, sin), atan2(sin, cos), etc... + for a, b in [[c, s], [s, c]]: + for i, j in product([-1, 1], repeat=2): + angle = atan2(i*b, j*a) + angle_inverted = trigsimp(angle, inverse=True) + assert angle_inverted != angle # assures simplification happened + assert sin(angle_inverted) == trigsimp(sin(angle)) + assert cos(angle_inverted) == trigsimp(cos(angle)) diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/traversaltools.py b/.venv/lib/python3.13/site-packages/sympy/simplify/traversaltools.py new file mode 100644 index 0000000000000000000000000000000000000000..75b0bd0d8fd198cb12640ab8a0fe63a23c81ed8f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/traversaltools.py @@ -0,0 +1,15 @@ +from sympy.core.traversal import use as _use +from sympy.utilities.decorator import deprecated + +use = deprecated( + """ + Using use from the sympy.simplify.traversaltools submodule is + deprecated. + + Instead, use use from the top-level sympy namespace, like + + sympy.use + """, + deprecated_since_version="1.10", + active_deprecations_target="deprecated-traversal-functions-moved" +)(_use) diff --git a/.venv/lib/python3.13/site-packages/sympy/simplify/trigsimp.py b/.venv/lib/python3.13/site-packages/sympy/simplify/trigsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5be1444a4625e4b63b339877e441d12cfbe8de --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/simplify/trigsimp.py @@ -0,0 +1,1252 @@ +from collections import defaultdict +from functools import reduce + +from sympy.core import (sympify, Basic, S, Expr, factor_terms, + Mul, Add, bottom_up) +from sympy.core.cache import cacheit +from sympy.core.function import (count_ops, _mexpand, FunctionClass, expand, + expand_mul, _coeff_isneg, Derivative) +from sympy.core.numbers import I, Integer +from sympy.core.intfunc import igcd +from sympy.core.sorting import _nodes +from sympy.core.symbol import Dummy, symbols, Wild +from sympy.external.gmpy import SYMPY_INTS +from sympy.functions import sin, cos, exp, cosh, tanh, sinh, tan, cot, coth +from sympy.functions import atan2 +from sympy.functions.elementary.hyperbolic import HyperbolicFunction +from sympy.functions.elementary.trigonometric import TrigonometricFunction +from sympy.polys import Poly, factor, cancel, parallel_poly_from_expr +from sympy.polys.domains import ZZ +from sympy.polys.polyerrors import PolificationFailed +from sympy.polys.polytools import groebner +from sympy.simplify.cse_main import cse +from sympy.strategies.core import identity +from sympy.strategies.tree import greedy +from sympy.utilities.iterables import iterable +from sympy.utilities.misc import debug + +def trigsimp_groebner(expr, hints=[], quick=False, order="grlex", + polynomial=False): + """ + Simplify trigonometric expressions using a groebner basis algorithm. + + Explanation + =========== + + This routine takes a fraction involving trigonometric or hyperbolic + expressions, and tries to simplify it. The primary metric is the + total degree. Some attempts are made to choose the simplest possible + expression of the minimal degree, but this is non-rigorous, and also + very slow (see the ``quick=True`` option). + + If ``polynomial`` is set to True, instead of simplifying numerator and + denominator together, this function just brings numerator and denominator + into a canonical form. This is much faster, but has potentially worse + results. However, if the input is a polynomial, then the result is + guaranteed to be an equivalent polynomial of minimal degree. + + The most important option is hints. Its entries can be any of the + following: + + - a natural number + - a function + - an iterable of the form (func, var1, var2, ...) + - anything else, interpreted as a generator + + A number is used to indicate that the search space should be increased. + A function is used to indicate that said function is likely to occur in a + simplified expression. + An iterable is used indicate that func(var1 + var2 + ...) is likely to + occur in a simplified . + An additional generator also indicates that it is likely to occur. + (See examples below). + + This routine carries out various computationally intensive algorithms. + The option ``quick=True`` can be used to suppress one particularly slow + step (at the expense of potentially more complicated results, but never at + the expense of increased total degree). + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import sin, tan, cos, sinh, cosh, tanh + >>> from sympy.simplify.trigsimp import trigsimp_groebner + + Suppose you want to simplify ``sin(x)*cos(x)``. Naively, nothing happens: + + >>> ex = sin(x)*cos(x) + >>> trigsimp_groebner(ex) + sin(x)*cos(x) + + This is because ``trigsimp_groebner`` only looks for a simplification + involving just ``sin(x)`` and ``cos(x)``. You can tell it to also try + ``2*x`` by passing ``hints=[2]``: + + >>> trigsimp_groebner(ex, hints=[2]) + sin(2*x)/2 + >>> trigsimp_groebner(sin(x)**2 - cos(x)**2, hints=[2]) + -cos(2*x) + + Increasing the search space this way can quickly become expensive. A much + faster way is to give a specific expression that is likely to occur: + + >>> trigsimp_groebner(ex, hints=[sin(2*x)]) + sin(2*x)/2 + + Hyperbolic expressions are similarly supported: + + >>> trigsimp_groebner(sinh(2*x)/sinh(x)) + 2*cosh(x) + + Note how no hints had to be passed, since the expression already involved + ``2*x``. + + The tangent function is also supported. You can either pass ``tan`` in the + hints, to indicate that tan should be tried whenever cosine or sine are, + or you can pass a specific generator: + + >>> trigsimp_groebner(sin(x)/cos(x), hints=[tan]) + tan(x) + >>> trigsimp_groebner(sinh(x)/cosh(x), hints=[tanh(x)]) + tanh(x) + + Finally, you can use the iterable form to suggest that angle sum formulae + should be tried: + + >>> ex = (tan(x) + tan(y))/(1 - tan(x)*tan(y)) + >>> trigsimp_groebner(ex, hints=[(tan, x, y)]) + tan(x + y) + """ + # TODO + # - preprocess by replacing everything by funcs we can handle + # - optionally use cot instead of tan + # - more intelligent hinting. + # For example, if the ideal is small, and we have sin(x), sin(y), + # add sin(x + y) automatically... ? + # - algebraic numbers ... + # - expressions of lowest degree are not distinguished properly + # e.g. 1 - sin(x)**2 + # - we could try to order the generators intelligently, so as to influence + # which monomials appear in the quotient basis + + # THEORY + # ------ + # Ratsimpmodprime above can be used to "simplify" a rational function + # modulo a prime ideal. "Simplify" mainly means finding an equivalent + # expression of lower total degree. + # + # We intend to use this to simplify trigonometric functions. To do that, + # we need to decide (a) which ring to use, and (b) modulo which ideal to + # simplify. In practice, (a) means settling on a list of "generators" + # a, b, c, ..., such that the fraction we want to simplify is a rational + # function in a, b, c, ..., with coefficients in ZZ (integers). + # (2) means that we have to decide what relations to impose on the + # generators. There are two practical problems: + # (1) The ideal has to be *prime* (a technical term). + # (2) The relations have to be polynomials in the generators. + # + # We typically have two kinds of generators: + # - trigonometric expressions, like sin(x), cos(5*x), etc + # - "everything else", like gamma(x), pi, etc. + # + # Since this function is trigsimp, we will concentrate on what to do with + # trigonometric expressions. We can also simplify hyperbolic expressions, + # but the extensions should be clear. + # + # One crucial point is that all *other* generators really should behave + # like indeterminates. In particular if (say) "I" is one of them, then + # in fact I**2 + 1 = 0 and we may and will compute non-sensical + # expressions. However, we can work with a dummy and add the relation + # I**2 + 1 = 0 to our ideal, then substitute back in the end. + # + # Now regarding trigonometric generators. We split them into groups, + # according to the argument of the trigonometric functions. We want to + # organise this in such a way that most trigonometric identities apply in + # the same group. For example, given sin(x), cos(2*x) and cos(y), we would + # group as [sin(x), cos(2*x)] and [cos(y)]. + # + # Our prime ideal will be built in three steps: + # (1) For each group, compute a "geometrically prime" ideal of relations. + # Geometrically prime means that it generates a prime ideal in + # CC[gens], not just ZZ[gens]. + # (2) Take the union of all the generators of the ideals for all groups. + # By the geometric primality condition, this is still prime. + # (3) Add further inter-group relations which preserve primality. + # + # Step (1) works as follows. We will isolate common factors in the + # argument, so that all our generators are of the form sin(n*x), cos(n*x) + # or tan(n*x), with n an integer. Suppose first there are no tan terms. + # The ideal [sin(x)**2 + cos(x)**2 - 1] is geometrically prime, since + # X**2 + Y**2 - 1 is irreducible over CC. + # Now, if we have a generator sin(n*x), than we can, using trig identities, + # express sin(n*x) as a polynomial in sin(x) and cos(x). We can add this + # relation to the ideal, preserving geometric primality, since the quotient + # ring is unchanged. + # Thus we have treated all sin and cos terms. + # For tan(n*x), we add a relation tan(n*x)*cos(n*x) - sin(n*x) = 0. + # (This requires of course that we already have relations for cos(n*x) and + # sin(n*x).) It is not obvious, but it seems that this preserves geometric + # primality. + # XXX A real proof would be nice. HELP! + # Sketch that is a prime ideal of + # CC[S, C, T]: + # - it suffices to show that the projective closure in CP**3 is + # irreducible + # - using the half-angle substitutions, we can express sin(x), tan(x), + # cos(x) as rational functions in tan(x/2) + # - from this, we get a rational map from CP**1 to our curve + # - this is a morphism, hence the curve is prime + # + # Step (2) is trivial. + # + # Step (3) works by adding selected relations of the form + # sin(x + y) - sin(x)*cos(y) - sin(y)*cos(x), etc. Geometric primality is + # preserved by the same argument as before. + + def parse_hints(hints): + """Split hints into (n, funcs, iterables, gens).""" + n = 1 + funcs, iterables, gens = [], [], [] + for e in hints: + if isinstance(e, (SYMPY_INTS, Integer)): + n = e + elif isinstance(e, FunctionClass): + funcs.append(e) + elif iterable(e): + iterables.append((e[0], e[1:])) + # XXX sin(x+2y)? + # Note: we go through polys so e.g. + # sin(-x) -> -sin(x) -> sin(x) + gens.extend(parallel_poly_from_expr( + [e[0](x) for x in e[1:]] + [e[0](Add(*e[1:]))])[1].gens) + else: + gens.append(e) + return n, funcs, iterables, gens + + def build_ideal(x, terms): + """ + Build generators for our ideal. ``Terms`` is an iterable with elements of + the form (fn, coeff), indicating that we have a generator fn(coeff*x). + + If any of the terms is trigonometric, sin(x) and cos(x) are guaranteed + to appear in terms. Similarly for hyperbolic functions. For tan(n*x), + sin(n*x) and cos(n*x) are guaranteed. + """ + I = [] + y = Dummy('y') + for fn, coeff in terms: + for c, s, t, rel in ( + [cos, sin, tan, cos(x)**2 + sin(x)**2 - 1], + [cosh, sinh, tanh, cosh(x)**2 - sinh(x)**2 - 1]): + if coeff == 1 and fn in [c, s]: + I.append(rel) + elif fn == t: + I.append(t(coeff*x)*c(coeff*x) - s(coeff*x)) + elif fn in [c, s]: + cn = fn(coeff*y).expand(trig=True).subs(y, x) + I.append(fn(coeff*x) - cn) + return list(set(I)) + + def analyse_gens(gens, hints): + """ + Analyse the generators ``gens``, using the hints ``hints``. + + The meaning of ``hints`` is described in the main docstring. + Return a new list of generators, and also the ideal we should + work with. + """ + # First parse the hints + n, funcs, iterables, extragens = parse_hints(hints) + debug('n=%s funcs: %s iterables: %s extragens: %s', + (funcs, iterables, extragens)) + + # We just add the extragens to gens and analyse them as before + gens = list(gens) + gens.extend(extragens) + + # remove duplicates + funcs = list(set(funcs)) + iterables = list(set(iterables)) + gens = list(set(gens)) + + # all the functions we can do anything with + allfuncs = {sin, cos, tan, sinh, cosh, tanh} + # sin(3*x) -> ((3, x), sin) + trigterms = [(g.args[0].as_coeff_mul(), g.func) for g in gens + if g.func in allfuncs] + # Our list of new generators - start with anything that we cannot + # work with (i.e. is not a trigonometric term) + freegens = [g for g in gens if g.func not in allfuncs] + newgens = [] + trigdict = {} + for (coeff, var), fn in trigterms: + trigdict.setdefault(var, []).append((coeff, fn)) + res = [] # the ideal + + for key, val in trigdict.items(): + # We have now assembeled a dictionary. Its keys are common + # arguments in trigonometric expressions, and values are lists of + # pairs (fn, coeff). x0, (fn, coeff) in trigdict means that we + # need to deal with fn(coeff*x0). We take the rational gcd of the + # coeffs, call it ``gcd``. We then use x = x0/gcd as "base symbol", + # all other arguments are integral multiples thereof. + # We will build an ideal which works with sin(x), cos(x). + # If hint tan is provided, also work with tan(x). Moreover, if + # n > 1, also work with sin(k*x) for k <= n, and similarly for cos + # (and tan if the hint is provided). Finally, any generators which + # the ideal does not work with but we need to accommodate (either + # because it was in expr or because it was provided as a hint) + # we also build into the ideal. + # This selection process is expressed in the list ``terms``. + # build_ideal then generates the actual relations in our ideal, + # from this list. + fns = [x[1] for x in val] + val = [x[0] for x in val] + gcd = reduce(igcd, val) + terms = [(fn, v/gcd) for (fn, v) in zip(fns, val)] + fs = set(funcs + fns) + for c, s, t in ([cos, sin, tan], [cosh, sinh, tanh]): + if any(x in fs for x in (c, s, t)): + fs.add(c) + fs.add(s) + for fn in fs: + terms.extend((fn, k) for k in range(1, n + 1)) + extra = [] + for fn, v in terms: + if fn == tan: + extra.append((sin, v)) + extra.append((cos, v)) + if fn in [sin, cos] and tan in fs: + extra.append((tan, v)) + if fn == tanh: + extra.append((sinh, v)) + extra.append((cosh, v)) + if fn in [sinh, cosh] and tanh in fs: + extra.append((tanh, v)) + terms.extend(extra) + x = gcd*Mul(*key) + r = build_ideal(x, terms) + res.extend(r) + newgens.extend({fn(v*x) for fn, v in terms}) + + # Add generators for compound expressions from iterables + for fn, args in iterables: + if fn == tan: + # Tan expressions are recovered from sin and cos. + iterables.extend([(sin, args), (cos, args)]) + elif fn == tanh: + # Tanh expressions are recovered from sihn and cosh. + iterables.extend([(sinh, args), (cosh, args)]) + else: + dummys = symbols('d:%i' % len(args), cls=Dummy) + expr = fn( Add(*dummys)).expand(trig=True).subs(list(zip(dummys, args))) + res.append(fn(Add(*args)) - expr) + + if myI in gens: + res.append(myI**2 + 1) + freegens.remove(myI) + newgens.append(myI) + + return res, freegens, newgens + + myI = Dummy('I') + expr = expr.subs(S.ImaginaryUnit, myI) + subs = [(myI, S.ImaginaryUnit)] + + num, denom = cancel(expr).as_numer_denom() + try: + (pnum, pdenom), opt = parallel_poly_from_expr([num, denom]) + except PolificationFailed: + return expr + debug('initial gens:', opt.gens) + ideal, freegens, gens = analyse_gens(opt.gens, hints) + debug('ideal:', ideal) + debug('new gens:', gens, " -- len", len(gens)) + debug('free gens:', freegens, " -- len", len(gens)) + # NOTE we force the domain to be ZZ to stop polys from injecting generators + # (which is usually a sign of a bug in the way we build the ideal) + if not gens: + return expr + G = groebner(ideal, order=order, gens=gens, domain=ZZ) + debug('groebner basis:', list(G), " -- len", len(G)) + + # If our fraction is a polynomial in the free generators, simplify all + # coefficients separately: + + from sympy.simplify.ratsimp import ratsimpmodprime + + if freegens and pdenom.has_only_gens(*set(gens).intersection(pdenom.gens)): + num = Poly(num, gens=gens+freegens).eject(*gens) + res = [] + for monom, coeff in num.terms(): + ourgens = set(parallel_poly_from_expr([coeff, denom])[1].gens) + # We compute the transitive closure of all generators that can + # be reached from our generators through relations in the ideal. + changed = True + while changed: + changed = False + for p in ideal: + p = Poly(p) + if not ourgens.issuperset(p.gens) and \ + not p.has_only_gens(*set(p.gens).difference(ourgens)): + changed = True + ourgens.update(p.exclude().gens) + # NOTE preserve order! + realgens = [x for x in gens if x in ourgens] + # The generators of the ideal have now been (implicitly) split + # into two groups: those involving ourgens and those that don't. + # Since we took the transitive closure above, these two groups + # live in subgrings generated by a *disjoint* set of variables. + # Any sensible groebner basis algorithm will preserve this disjoint + # structure (i.e. the elements of the groebner basis can be split + # similarly), and and the two subsets of the groebner basis then + # form groebner bases by themselves. (For the smaller generating + # sets, of course.) + ourG = [g.as_expr() for g in G.polys if + g.has_only_gens(*ourgens.intersection(g.gens))] + res.append(Mul(*[a**b for a, b in zip(freegens, monom)]) * \ + ratsimpmodprime(coeff/denom, ourG, order=order, + gens=realgens, quick=quick, domain=ZZ, + polynomial=polynomial).subs(subs)) + return Add(*res) + # NOTE The following is simpler and has less assumptions on the + # groebner basis algorithm. If the above turns out to be broken, + # use this. + return Add(*[Mul(*[a**b for a, b in zip(freegens, monom)]) * \ + ratsimpmodprime(coeff/denom, list(G), order=order, + gens=gens, quick=quick, domain=ZZ) + for monom, coeff in num.terms()]) + else: + return ratsimpmodprime( + expr, list(G), order=order, gens=freegens+gens, + quick=quick, domain=ZZ, polynomial=polynomial).subs(subs) + + +_trigs = (TrigonometricFunction, HyperbolicFunction) + + +def _trigsimp_inverse(rv): + + def check_args(x, y): + try: + return x.args[0] == y.args[0] + except IndexError: + return False + + def f(rv): + # for simple functions + g = getattr(rv, 'inverse', None) + if (g is not None and isinstance(rv.args[0], g()) and + isinstance(g()(1), TrigonometricFunction)): + return rv.args[0].args[0] + + # for atan2 simplifications, harder because atan2 has 2 args + if isinstance(rv, atan2): + y, x = rv.args + if _coeff_isneg(y): + return -f(atan2(-y, x)) + elif _coeff_isneg(x): + return S.Pi - f(atan2(y, -x)) + + if check_args(x, y): + if isinstance(y, sin) and isinstance(x, cos): + return x.args[0] + if isinstance(y, cos) and isinstance(x, sin): + return S.Pi / 2 - x.args[0] + + return rv + + return bottom_up(rv, f) + + +def trigsimp(expr, inverse=False, **opts): + """Returns a reduced expression by using known trig identities. + + Parameters + ========== + + inverse : bool, optional + If ``inverse=True``, it will be assumed that a composition of inverse + functions, such as sin and asin, can be cancelled in any order. + For example, ``asin(sin(x))`` will yield ``x`` without checking whether + x belongs to the set where this relation is true. The default is False. + Default : True + + method : string, optional + Specifies the method to use. Valid choices are: + + - ``'matching'``, default + - ``'groebner'`` + - ``'combined'`` + - ``'fu'`` + - ``'old'`` + + If ``'matching'``, simplify the expression recursively by targeting + common patterns. If ``'groebner'``, apply an experimental groebner + basis algorithm. In this case further options are forwarded to + ``trigsimp_groebner``, please refer to + its docstring. If ``'combined'``, it first runs the groebner basis + algorithm with small default parameters, then runs the ``'matching'`` + algorithm. If ``'fu'``, run the collection of trigonometric + transformations described by Fu, et al. (see the + :py:func:`~sympy.simplify.fu.fu` docstring). If ``'old'``, the original + SymPy trig simplification function is run. + opts : + Optional keyword arguments passed to the method. See each method's + function docstring for details. + + Examples + ======== + + >>> from sympy import trigsimp, sin, cos, log + >>> from sympy.abc import x + >>> e = 2*sin(x)**2 + 2*cos(x)**2 + >>> trigsimp(e) + 2 + + Simplification occurs wherever trigonometric functions are located. + + >>> trigsimp(log(e)) + log(2) + + Using ``method='groebner'`` (or ``method='combined'``) might lead to + greater simplification. + + The old trigsimp routine can be accessed as with method ``method='old'``. + + >>> from sympy import coth, tanh + >>> t = 3*tanh(x)**7 - 2/coth(x)**7 + >>> trigsimp(t, method='old') == t + True + >>> trigsimp(t) + tanh(x)**7 + + """ + from sympy.simplify.fu import fu + + expr = sympify(expr) + + _eval_trigsimp = getattr(expr, '_eval_trigsimp', None) + if _eval_trigsimp is not None: + return _eval_trigsimp(**opts) + + old = opts.pop('old', False) + if not old: + opts.pop('deep', None) + opts.pop('recursive', None) + method = opts.pop('method', 'matching') + else: + method = 'old' + + def groebnersimp(ex, **opts): + def traverse(e): + if e.is_Atom: + return e + args = [traverse(x) for x in e.args] + if e.is_Function or e.is_Pow: + args = [trigsimp_groebner(x, **opts) for x in args] + return e.func(*args) + new = traverse(ex) + if not isinstance(new, Expr): + return new + return trigsimp_groebner(new, **opts) + + trigsimpfunc = { + 'fu': (lambda x: fu(x, **opts)), + 'matching': (lambda x: futrig(x)), + 'groebner': (lambda x: groebnersimp(x, **opts)), + 'combined': (lambda x: futrig(groebnersimp(x, + polynomial=True, hints=[2, tan]))), + 'old': lambda x: trigsimp_old(x, **opts), + }[method] + + expr_simplified = trigsimpfunc(expr) + if inverse: + expr_simplified = _trigsimp_inverse(expr_simplified) + + return expr_simplified + + +def exptrigsimp(expr): + """ + Simplifies exponential / trigonometric / hyperbolic functions. + + Examples + ======== + + >>> from sympy import exptrigsimp, exp, cosh, sinh + >>> from sympy.abc import z + + >>> exptrigsimp(exp(z) + exp(-z)) + 2*cosh(z) + >>> exptrigsimp(cosh(z) - sinh(z)) + exp(-z) + """ + from sympy.simplify.fu import hyper_as_trig, TR2i + + def exp_trig(e): + # select the better of e, and e rewritten in terms of exp or trig + # functions + choices = [e] + if e.has(*_trigs): + choices.append(e.rewrite(exp)) + choices.append(e.rewrite(cos)) + return min(*choices, key=count_ops) + newexpr = bottom_up(expr, exp_trig) + + def f(rv): + if not rv.is_Mul: + return rv + commutative_part, noncommutative_part = rv.args_cnc() + # Since as_powers_dict loses order information, + # if there is more than one noncommutative factor, + # it should only be used to simplify the commutative part. + if (len(noncommutative_part) > 1): + return f(Mul(*commutative_part))*Mul(*noncommutative_part) + rvd = rv.as_powers_dict() + newd = rvd.copy() + + def signlog(expr, sign=S.One): + if expr is S.Exp1: + return sign, S.One + elif isinstance(expr, exp) or (expr.is_Pow and expr.base == S.Exp1): + return sign, expr.exp + elif sign is S.One: + return signlog(-expr, sign=-S.One) + else: + return None, None + + ee = rvd[S.Exp1] + for k in rvd: + if k.is_Add and len(k.args) == 2: + # k == c*(1 + sign*E**x) + c = k.args[0] + sign, x = signlog(k.args[1]/c) + if not x: + continue + m = rvd[k] + newd[k] -= m + if ee == -x*m/2: + # sinh and cosh + newd[S.Exp1] -= ee + ee = 0 + if sign == 1: + newd[2*c*cosh(x/2)] += m + else: + newd[-2*c*sinh(x/2)] += m + elif newd[1 - sign*S.Exp1**x] == -m: + # tanh + del newd[1 - sign*S.Exp1**x] + if sign == 1: + newd[-c/tanh(x/2)] += m + else: + newd[-c*tanh(x/2)] += m + else: + newd[1 + sign*S.Exp1**x] += m + newd[c] += m + + return Mul(*[k**newd[k] for k in newd]) + newexpr = bottom_up(newexpr, f) + + # sin/cos and sinh/cosh ratios to tan and tanh, respectively + if newexpr.has(HyperbolicFunction): + e, f = hyper_as_trig(newexpr) + newexpr = f(TR2i(e)) + if newexpr.has(TrigonometricFunction): + newexpr = TR2i(newexpr) + + # can we ever generate an I where there was none previously? + if not (newexpr.has(I) and not expr.has(I)): + expr = newexpr + return expr + +#-------------------- the old trigsimp routines --------------------- + +def trigsimp_old(expr, *, first=True, **opts): + """ + Reduces expression by using known trig identities. + + Notes + ===== + + deep: + - Apply trigsimp inside all objects with arguments + + recursive: + - Use common subexpression elimination (cse()) and apply + trigsimp recursively (this is quite expensive if the + expression is large) + + method: + - Determine the method to use. Valid choices are 'matching' (default), + 'groebner', 'combined', 'fu' and 'futrig'. If 'matching', simplify the + expression recursively by pattern matching. If 'groebner', apply an + experimental groebner basis algorithm. In this case further options + are forwarded to ``trigsimp_groebner``, please refer to its docstring. + If 'combined', first run the groebner basis algorithm with small + default parameters, then run the 'matching' algorithm. 'fu' runs the + collection of trigonometric transformations described by Fu, et al. + (see the `fu` docstring) while `futrig` runs a subset of Fu-transforms + that mimic the behavior of `trigsimp`. + + compare: + - show input and output from `trigsimp` and `futrig` when different, + but returns the `trigsimp` value. + + Examples + ======== + + >>> from sympy import trigsimp, sin, cos, log, cot + >>> from sympy.abc import x + >>> e = 2*sin(x)**2 + 2*cos(x)**2 + >>> trigsimp(e, old=True) + 2 + >>> trigsimp(log(e), old=True) + log(2*sin(x)**2 + 2*cos(x)**2) + >>> trigsimp(log(e), deep=True, old=True) + log(2) + + Using `method="groebner"` (or `"combined"`) can sometimes lead to a lot + more simplification: + + >>> e = (-sin(x) + 1)/cos(x) + cos(x)/(-sin(x) + 1) + >>> trigsimp(e, old=True) + (1 - sin(x))/cos(x) + cos(x)/(1 - sin(x)) + >>> trigsimp(e, method="groebner", old=True) + 2/cos(x) + + >>> trigsimp(1/cot(x)**2, compare=True, old=True) + futrig: tan(x)**2 + cot(x)**(-2) + + """ + old = expr + if first: + if not expr.has(*_trigs): + return expr + + trigsyms = set().union(*[t.free_symbols for t in expr.atoms(*_trigs)]) + if len(trigsyms) > 1: + from sympy.simplify.simplify import separatevars + + d = separatevars(expr) + if d.is_Mul: + d = separatevars(d, dict=True) or d + if isinstance(d, dict): + expr = 1 + for v in d.values(): + # remove hollow factoring + was = v + v = expand_mul(v) + opts['first'] = False + vnew = trigsimp(v, **opts) + if vnew == v: + vnew = was + expr *= vnew + old = expr + else: + if d.is_Add: + for s in trigsyms: + r, e = expr.as_independent(s) + if r: + opts['first'] = False + expr = r + trigsimp(e, **opts) + if not expr.is_Add: + break + old = expr + + recursive = opts.pop('recursive', False) + deep = opts.pop('deep', False) + method = opts.pop('method', 'matching') + + def groebnersimp(ex, deep, **opts): + def traverse(e): + if e.is_Atom: + return e + args = [traverse(x) for x in e.args] + if e.is_Function or e.is_Pow: + args = [trigsimp_groebner(x, **opts) for x in args] + return e.func(*args) + if deep: + ex = traverse(ex) + return trigsimp_groebner(ex, **opts) + + trigsimpfunc = { + 'matching': (lambda x, d: _trigsimp(x, d)), + 'groebner': (lambda x, d: groebnersimp(x, d, **opts)), + 'combined': (lambda x, d: _trigsimp(groebnersimp(x, + d, polynomial=True, hints=[2, tan]), + d)) + }[method] + + if recursive: + w, g = cse(expr) + g = trigsimpfunc(g[0], deep) + + for sub in reversed(w): + g = g.subs(sub[0], sub[1]) + g = trigsimpfunc(g, deep) + result = g + else: + result = trigsimpfunc(expr, deep) + + if opts.get('compare', False): + f = futrig(old) + if f != result: + print('\tfutrig:', f) + + return result + + +def _dotrig(a, b): + """Helper to tell whether ``a`` and ``b`` have the same sorts + of symbols in them -- no need to test hyperbolic patterns against + expressions that have no hyperbolics in them.""" + return a.func == b.func and ( + a.has(TrigonometricFunction) and b.has(TrigonometricFunction) or + a.has(HyperbolicFunction) and b.has(HyperbolicFunction)) + + +_trigpat = None +def _trigpats(): + global _trigpat + a, b, c = symbols('a b c', cls=Wild) + d = Wild('d', commutative=False) + + # for the simplifications like sinh/cosh -> tanh: + # DO NOT REORDER THE FIRST 14 since these are assumed to be in this + # order in _match_div_rewrite. + matchers_division = ( + (a*sin(b)**c/cos(b)**c, a*tan(b)**c, sin(b), cos(b)), + (a*tan(b)**c*cos(b)**c, a*sin(b)**c, sin(b), cos(b)), + (a*cot(b)**c*sin(b)**c, a*cos(b)**c, sin(b), cos(b)), + (a*tan(b)**c/sin(b)**c, a/cos(b)**c, sin(b), cos(b)), + (a*cot(b)**c/cos(b)**c, a/sin(b)**c, sin(b), cos(b)), + (a*cot(b)**c*tan(b)**c, a, sin(b), cos(b)), + (a*(cos(b) + 1)**c*(cos(b) - 1)**c, + a*(-sin(b)**2)**c, cos(b) + 1, cos(b) - 1), + (a*(sin(b) + 1)**c*(sin(b) - 1)**c, + a*(-cos(b)**2)**c, sin(b) + 1, sin(b) - 1), + + (a*sinh(b)**c/cosh(b)**c, a*tanh(b)**c, S.One, S.One), + (a*tanh(b)**c*cosh(b)**c, a*sinh(b)**c, S.One, S.One), + (a*coth(b)**c*sinh(b)**c, a*cosh(b)**c, S.One, S.One), + (a*tanh(b)**c/sinh(b)**c, a/cosh(b)**c, S.One, S.One), + (a*coth(b)**c/cosh(b)**c, a/sinh(b)**c, S.One, S.One), + (a*coth(b)**c*tanh(b)**c, a, S.One, S.One), + + (c*(tanh(a) + tanh(b))/(1 + tanh(a)*tanh(b)), + tanh(a + b)*c, S.One, S.One), + ) + + matchers_add = ( + (c*sin(a)*cos(b) + c*cos(a)*sin(b) + d, sin(a + b)*c + d), + (c*cos(a)*cos(b) - c*sin(a)*sin(b) + d, cos(a + b)*c + d), + (c*sin(a)*cos(b) - c*cos(a)*sin(b) + d, sin(a - b)*c + d), + (c*cos(a)*cos(b) + c*sin(a)*sin(b) + d, cos(a - b)*c + d), + (c*sinh(a)*cosh(b) + c*sinh(b)*cosh(a) + d, sinh(a + b)*c + d), + (c*cosh(a)*cosh(b) + c*sinh(a)*sinh(b) + d, cosh(a + b)*c + d), + ) + + # for cos(x)**2 + sin(x)**2 -> 1 + matchers_identity = ( + (a*sin(b)**2, a - a*cos(b)**2), + (a*tan(b)**2, a*(1/cos(b))**2 - a), + (a*cot(b)**2, a*(1/sin(b))**2 - a), + (a*sin(b + c), a*(sin(b)*cos(c) + sin(c)*cos(b))), + (a*cos(b + c), a*(cos(b)*cos(c) - sin(b)*sin(c))), + (a*tan(b + c), a*((tan(b) + tan(c))/(1 - tan(b)*tan(c)))), + + (a*sinh(b)**2, a*cosh(b)**2 - a), + (a*tanh(b)**2, a - a*(1/cosh(b))**2), + (a*coth(b)**2, a + a*(1/sinh(b))**2), + (a*sinh(b + c), a*(sinh(b)*cosh(c) + sinh(c)*cosh(b))), + (a*cosh(b + c), a*(cosh(b)*cosh(c) + sinh(b)*sinh(c))), + (a*tanh(b + c), a*((tanh(b) + tanh(c))/(1 + tanh(b)*tanh(c)))), + + ) + + # Reduce any lingering artifacts, such as sin(x)**2 changing + # to 1-cos(x)**2 when sin(x)**2 was "simpler" + artifacts = ( + (a - a*cos(b)**2 + c, a*sin(b)**2 + c, cos), + (a - a*(1/cos(b))**2 + c, -a*tan(b)**2 + c, cos), + (a - a*(1/sin(b))**2 + c, -a*cot(b)**2 + c, sin), + + (a - a*cosh(b)**2 + c, -a*sinh(b)**2 + c, cosh), + (a - a*(1/cosh(b))**2 + c, a*tanh(b)**2 + c, cosh), + (a + a*(1/sinh(b))**2 + c, a*coth(b)**2 + c, sinh), + + # same as above but with noncommutative prefactor + (a*d - a*d*cos(b)**2 + c, a*d*sin(b)**2 + c, cos), + (a*d - a*d*(1/cos(b))**2 + c, -a*d*tan(b)**2 + c, cos), + (a*d - a*d*(1/sin(b))**2 + c, -a*d*cot(b)**2 + c, sin), + + (a*d - a*d*cosh(b)**2 + c, -a*d*sinh(b)**2 + c, cosh), + (a*d - a*d*(1/cosh(b))**2 + c, a*d*tanh(b)**2 + c, cosh), + (a*d + a*d*(1/sinh(b))**2 + c, a*d*coth(b)**2 + c, sinh), + ) + + _trigpat = (a, b, c, d, matchers_division, matchers_add, + matchers_identity, artifacts) + return _trigpat + + +def _replace_mul_fpowxgpow(expr, f, g, rexp, h, rexph): + """Helper for _match_div_rewrite. + + Replace f(b_)**c_*g(b_)**(rexp(c_)) with h(b)**rexph(c) if f(b_) + and g(b_) are both positive or if c_ is an integer. + """ + # assert expr.is_Mul and expr.is_commutative and f != g + fargs = defaultdict(int) + gargs = defaultdict(int) + args = [] + for x in expr.args: + if x.is_Pow or x.func in (f, g): + b, e = x.as_base_exp() + if b.is_positive or e.is_integer: + if b.func == f: + fargs[b.args[0]] += e + continue + elif b.func == g: + gargs[b.args[0]] += e + continue + args.append(x) + common = set(fargs) & set(gargs) + hit = False + while common: + key = common.pop() + fe = fargs.pop(key) + ge = gargs.pop(key) + if fe == rexp(ge): + args.append(h(key)**rexph(fe)) + hit = True + else: + fargs[key] = fe + gargs[key] = ge + if not hit: + return expr + while fargs: + key, e = fargs.popitem() + args.append(f(key)**e) + while gargs: + key, e = gargs.popitem() + args.append(g(key)**e) + return Mul(*args) + + +_idn = lambda x: x +_midn = lambda x: -x +_one = lambda x: S.One + +def _match_div_rewrite(expr, i): + """helper for __trigsimp""" + if i == 0: + expr = _replace_mul_fpowxgpow(expr, sin, cos, + _midn, tan, _idn) + elif i == 1: + expr = _replace_mul_fpowxgpow(expr, tan, cos, + _idn, sin, _idn) + elif i == 2: + expr = _replace_mul_fpowxgpow(expr, cot, sin, + _idn, cos, _idn) + elif i == 3: + expr = _replace_mul_fpowxgpow(expr, tan, sin, + _midn, cos, _midn) + elif i == 4: + expr = _replace_mul_fpowxgpow(expr, cot, cos, + _midn, sin, _midn) + elif i == 5: + expr = _replace_mul_fpowxgpow(expr, cot, tan, + _idn, _one, _idn) + # i in (6, 7) is skipped + elif i == 8: + expr = _replace_mul_fpowxgpow(expr, sinh, cosh, + _midn, tanh, _idn) + elif i == 9: + expr = _replace_mul_fpowxgpow(expr, tanh, cosh, + _idn, sinh, _idn) + elif i == 10: + expr = _replace_mul_fpowxgpow(expr, coth, sinh, + _idn, cosh, _idn) + elif i == 11: + expr = _replace_mul_fpowxgpow(expr, tanh, sinh, + _midn, cosh, _midn) + elif i == 12: + expr = _replace_mul_fpowxgpow(expr, coth, cosh, + _midn, sinh, _midn) + elif i == 13: + expr = _replace_mul_fpowxgpow(expr, coth, tanh, + _idn, _one, _idn) + else: + return None + return expr + + +def _trigsimp(expr, deep=False): + # protect the cache from non-trig patterns; we only allow + # trig patterns to enter the cache + if expr.has(*_trigs): + return __trigsimp(expr, deep) + return expr + + +@cacheit +def __trigsimp(expr, deep=False): + """recursive helper for trigsimp""" + from sympy.simplify.fu import TR10i + + if _trigpat is None: + _trigpats() + a, b, c, d, matchers_division, matchers_add, \ + matchers_identity, artifacts = _trigpat + + if expr.is_Mul: + # do some simplifications like sin/cos -> tan: + if not expr.is_commutative: + com, nc = expr.args_cnc() + expr = _trigsimp(Mul._from_args(com), deep)*Mul._from_args(nc) + else: + for i, (pattern, simp, ok1, ok2) in enumerate(matchers_division): + if not _dotrig(expr, pattern): + continue + + newexpr = _match_div_rewrite(expr, i) + if newexpr is not None: + if newexpr != expr: + expr = newexpr + break + else: + continue + + # use SymPy matching instead + res = expr.match(pattern) + if res and res.get(c, 0): + if not res[c].is_integer: + ok = ok1.subs(res) + if not ok.is_positive: + continue + ok = ok2.subs(res) + if not ok.is_positive: + continue + # if "a" contains any of trig or hyperbolic funcs with + # argument "b" then skip the simplification + if any(w.args[0] == res[b] for w in res[a].atoms( + TrigonometricFunction, HyperbolicFunction)): + continue + # simplify and finish: + expr = simp.subs(res) + break # process below + + if expr.is_Add: + args = [] + for term in expr.args: + if not term.is_commutative: + com, nc = term.args_cnc() + nc = Mul._from_args(nc) + term = Mul._from_args(com) + else: + nc = S.One + term = _trigsimp(term, deep) + for pattern, result in matchers_identity: + res = term.match(pattern) + if res is not None: + term = result.subs(res) + break + args.append(term*nc) + if args != expr.args: + expr = Add(*args) + expr = min(expr, expand(expr), key=count_ops) + if expr.is_Add: + for pattern, result in matchers_add: + if not _dotrig(expr, pattern): + continue + expr = TR10i(expr) + if expr.has(HyperbolicFunction): + res = expr.match(pattern) + # if "d" contains any trig or hyperbolic funcs with + # argument "a" or "b" then skip the simplification; + # this isn't perfect -- see tests + if res is None or not (a in res and b in res) or any( + w.args[0] in (res[a], res[b]) for w in res[d].atoms( + TrigonometricFunction, HyperbolicFunction)): + continue + expr = result.subs(res) + break + + # Reduce any lingering artifacts, such as sin(x)**2 changing + # to 1 - cos(x)**2 when sin(x)**2 was "simpler" + for pattern, result, ex in artifacts: + if not _dotrig(expr, pattern): + continue + # Substitute a new wild that excludes some function(s) + # to help influence a better match. This is because + # sometimes, for example, 'a' would match sec(x)**2 + a_t = Wild('a', exclude=[ex]) + pattern = pattern.subs(a, a_t) + result = result.subs(a, a_t) + + m = expr.match(pattern) + was = None + while m and was != expr: + was = expr + if m[a_t] == 0 or \ + -m[a_t] in m[c].args or m[a_t] + m[c] == 0: + break + if d in m and m[a_t]*m[d] + m[c] == 0: + break + expr = result.subs(m) + m = expr.match(pattern) + m.setdefault(c, S.Zero) + + elif expr.is_Mul or expr.is_Pow or deep and expr.args: + expr = expr.func(*[_trigsimp(a, deep) for a in expr.args]) + + try: + if not expr.has(*_trigs): + raise TypeError + e = expr.atoms(exp) + new = expr.rewrite(exp, deep=deep) + if new == e: + raise TypeError + fnew = factor(new) + if fnew != new: + new = min([new, factor(new)], key=count_ops) + # if all exp that were introduced disappeared then accept it + if not (new.atoms(exp) - e): + expr = new + except TypeError: + pass + + return expr +#------------------- end of old trigsimp routines -------------------- + + +def futrig(e, *, hyper=True, **kwargs): + """Return simplified ``e`` using Fu-like transformations. + This is not the "Fu" algorithm. This is called by default + from ``trigsimp``. By default, hyperbolics subexpressions + will be simplified, but this can be disabled by setting + ``hyper=False``. + + Examples + ======== + + >>> from sympy import trigsimp, tan, sinh, tanh + >>> from sympy.simplify.trigsimp import futrig + >>> from sympy.abc import x + >>> trigsimp(1/tan(x)**2) + tan(x)**(-2) + + >>> futrig(sinh(x)/tanh(x)) + cosh(x) + + """ + from sympy.simplify.fu import hyper_as_trig + + e = sympify(e) + + if not isinstance(e, Basic): + return e + + if not e.args: + return e + + old = e + e = bottom_up(e, _futrig) + + if hyper and e.has(HyperbolicFunction): + e, f = hyper_as_trig(e) + e = f(bottom_up(e, _futrig)) + + if e != old and e.is_Mul and e.args[0].is_Rational: + # redistribute leading coeff on 2-arg Add + e = Mul(*e.as_coeff_Mul()) + return e + + +def _futrig(e): + """Helper for futrig.""" + from sympy.simplify.fu import ( + TR1, TR2, TR3, TR2i, TR10, L, TR10i, + TR8, TR6, TR15, TR16, TR111, TR5, TRmorrie, TR11, _TR11, TR14, TR22, + TR12) + + if not e.has(TrigonometricFunction): + return e + + if e.is_Mul: + coeff, e = e.as_independent(TrigonometricFunction) + else: + coeff = None + + Lops = lambda x: (L(x), x.count_ops(), _nodes(x), len(x.args), x.is_Add) + trigs = lambda x: x.has(TrigonometricFunction) + + tree = [identity, + ( + TR3, # canonical angles + TR1, # sec-csc -> cos-sin + TR12, # expand tan of sum + lambda x: _eapply(factor, x, trigs), + TR2, # tan-cot -> sin-cos + [identity, lambda x: _eapply(_mexpand, x, trigs)], + TR2i, # sin-cos ratio -> tan + lambda x: _eapply(lambda i: factor(i.normal()), x, trigs), + TR14, # factored identities + TR5, # sin-pow -> cos_pow + TR10, # sin-cos of sums -> sin-cos prod + TR11, _TR11, TR6, # reduce double angles and rewrite cos pows + lambda x: _eapply(factor, x, trigs), + TR14, # factored powers of identities + [identity, lambda x: _eapply(_mexpand, x, trigs)], + TR10i, # sin-cos products > sin-cos of sums + TRmorrie, + [identity, TR8], # sin-cos products -> sin-cos of sums + [identity, lambda x: TR2i(TR2(x))], # tan -> sin-cos -> tan + [ + lambda x: _eapply(expand_mul, TR5(x), trigs), + lambda x: _eapply( + expand_mul, TR15(x), trigs)], # pos/neg powers of sin + [ + lambda x: _eapply(expand_mul, TR6(x), trigs), + lambda x: _eapply( + expand_mul, TR16(x), trigs)], # pos/neg powers of cos + TR111, # tan, sin, cos to neg power -> cot, csc, sec + [identity, TR2i], # sin-cos ratio to tan + [identity, lambda x: _eapply( + expand_mul, TR22(x), trigs)], # tan-cot to sec-csc + TR1, TR2, TR2i, + [identity, lambda x: _eapply( + factor_terms, TR12(x), trigs)], # expand tan of sum + )] + e = greedy(tree, objective=Lops)(e) + + if coeff is not None: + e = coeff * e + + return e + + +def _is_Expr(e): + """_eapply helper to tell whether ``e`` and all its args + are Exprs.""" + if isinstance(e, Derivative): + return _is_Expr(e.expr) + if not isinstance(e, Expr): + return False + return all(_is_Expr(i) for i in e.args) + + +def _eapply(func, e, cond=None): + """Apply ``func`` to ``e`` if all args are Exprs else only + apply it to those args that *are* Exprs.""" + if not isinstance(e, Expr): + return e + if _is_Expr(e) or not e.args: + return func(e) + return e.func(*[ + _eapply(func, ei) if (cond is None or cond(ei)) else ei + for ei in e.args]) diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/__init__.py b/.venv/lib/python3.13/site-packages/sympy/stats/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..adb79261954924305c1837555d7d47cd53b8430b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/__init__.py @@ -0,0 +1,202 @@ +""" +SymPy statistics module + +Introduces a random variable type into the SymPy language. + +Random variables may be declared using prebuilt functions such as +Normal, Exponential, Coin, Die, etc... or built with functions like FiniteRV. + +Queries on random expressions can be made using the functions + +========================= ============================= + Expression Meaning +------------------------- ----------------------------- + ``P(condition)`` Probability + ``E(expression)`` Expected value + ``H(expression)`` Entropy + ``variance(expression)`` Variance + ``density(expression)`` Probability Density Function + ``sample(expression)`` Produce a realization + ``where(condition)`` Where the condition is true +========================= ============================= + +Examples +======== + +>>> from sympy.stats import P, E, variance, Die, Normal +>>> from sympy import simplify +>>> X, Y = Die('X', 6), Die('Y', 6) # Define two six sided dice +>>> Z = Normal('Z', 0, 1) # Declare a Normal random variable with mean 0, std 1 +>>> P(X>3) # Probability X is greater than 3 +1/2 +>>> E(X+Y) # Expectation of the sum of two dice +7 +>>> variance(X+Y) # Variance of the sum of two dice +35/6 +>>> simplify(P(Z>1)) # Probability of Z being greater than 1 +1/2 - erf(sqrt(2)/2)/2 + + +One could also create custom distribution and define custom random variables +as follows: + +1. If you want to create a Continuous Random Variable: + +>>> from sympy.stats import ContinuousRV, P, E +>>> from sympy import exp, Symbol, Interval, oo +>>> x = Symbol('x') +>>> pdf = exp(-x) # pdf of the Continuous Distribution +>>> Z = ContinuousRV(x, pdf, set=Interval(0, oo)) +>>> E(Z) +1 +>>> P(Z > 5) +exp(-5) + +1.1 To create an instance of Continuous Distribution: + +>>> from sympy.stats import ContinuousDistributionHandmade +>>> from sympy import Lambda +>>> dist = ContinuousDistributionHandmade(Lambda(x, pdf), set=Interval(0, oo)) +>>> dist.pdf(x) +exp(-x) + +2. If you want to create a Discrete Random Variable: + +>>> from sympy.stats import DiscreteRV, P, E +>>> from sympy import Symbol, S +>>> p = S(1)/2 +>>> x = Symbol('x', integer=True, positive=True) +>>> pdf = p*(1 - p)**(x - 1) +>>> D = DiscreteRV(x, pdf, set=S.Naturals) +>>> E(D) +2 +>>> P(D > 3) +1/8 + +2.1 To create an instance of Discrete Distribution: + +>>> from sympy.stats import DiscreteDistributionHandmade +>>> from sympy import Lambda +>>> dist = DiscreteDistributionHandmade(Lambda(x, pdf), set=S.Naturals) +>>> dist.pdf(x) +2**(1 - x)/2 + +3. If you want to create a Finite Random Variable: + +>>> from sympy.stats import FiniteRV, P, E +>>> from sympy import Rational, Eq +>>> pmf = {1: Rational(1, 3), 2: Rational(1, 6), 3: Rational(1, 4), 4: Rational(1, 4)} +>>> X = FiniteRV('X', pmf) +>>> E(X) +29/12 +>>> P(X > 3) +1/4 + +3.1 To create an instance of Finite Distribution: + +>>> from sympy.stats import FiniteDistributionHandmade +>>> dist = FiniteDistributionHandmade(pmf) +>>> dist.pmf(x) +Lambda(x, Piecewise((1/3, Eq(x, 1)), (1/6, Eq(x, 2)), (1/4, Eq(x, 3) | Eq(x, 4)), (0, True))) +""" + +__all__ = [ + 'P', 'E', 'H', 'density', 'where', 'given', 'sample', 'cdf','median', + 'characteristic_function', 'pspace', 'sample_iter', 'variance', 'std', + 'skewness', 'kurtosis', 'covariance', 'dependent', 'entropy', 'independent', + 'random_symbols', 'correlation', 'factorial_moment', 'moment', 'cmoment', + 'sampling_density', 'moment_generating_function', 'smoment', 'quantile', + 'coskewness', 'sample_stochastic_process', + + 'FiniteRV', 'DiscreteUniform', 'Die', 'Bernoulli', 'Coin', 'Binomial', + 'BetaBinomial', 'Hypergeometric', 'Rademacher', 'IdealSoliton', 'RobustSoliton', + 'FiniteDistributionHandmade', + + 'ContinuousRV', 'Arcsin', 'Benini', 'Beta', 'BetaNoncentral', 'BetaPrime', + 'BoundedPareto', 'Cauchy', 'Chi', 'ChiNoncentral', 'ChiSquared', 'Dagum', 'Davis', 'Erlang', + 'ExGaussian', 'Exponential', 'ExponentialPower', 'FDistribution', + 'FisherZ', 'Frechet', 'Gamma', 'GammaInverse', 'Gompertz', 'Gumbel', + 'Kumaraswamy', 'Laplace', 'Levy', 'Logistic','LogCauchy', 'LogLogistic', 'LogitNormal', 'LogNormal', 'Lomax', + 'Moyal', 'Maxwell', 'Nakagami', 'Normal', 'GaussianInverse', 'Pareto', 'PowerFunction', + 'QuadraticU', 'RaisedCosine', 'Rayleigh','Reciprocal', 'StudentT', 'ShiftedGompertz', + 'Trapezoidal', 'Triangular', 'Uniform', 'UniformSum', 'VonMises', 'Wald', + 'Weibull', 'WignerSemicircle', 'ContinuousDistributionHandmade', + + 'FlorySchulz', 'Geometric','Hermite', 'Logarithmic', 'NegativeBinomial', 'Poisson', 'Skellam', + 'YuleSimon', 'Zeta', 'DiscreteRV', 'DiscreteDistributionHandmade', + + 'JointRV', 'Dirichlet', 'GeneralizedMultivariateLogGamma', + 'GeneralizedMultivariateLogGammaOmega', 'Multinomial', 'MultivariateBeta', + 'MultivariateEwens', 'MultivariateT', 'NegativeMultinomial', + 'NormalGamma', 'MultivariateNormal', 'MultivariateLaplace', 'marginal_distribution', + + 'StochasticProcess', 'DiscreteTimeStochasticProcess', + 'DiscreteMarkovChain', 'TransitionMatrixOf', 'StochasticStateSpaceOf', + 'GeneratorMatrixOf', 'ContinuousMarkovChain', 'BernoulliProcess', + 'PoissonProcess', 'WienerProcess', 'GammaProcess', + + 'CircularEnsemble', 'CircularUnitaryEnsemble', + 'CircularOrthogonalEnsemble', 'CircularSymplecticEnsemble', + 'GaussianEnsemble', 'GaussianUnitaryEnsemble', + 'GaussianOrthogonalEnsemble', 'GaussianSymplecticEnsemble', + 'joint_eigen_distribution', 'JointEigenDistribution', + 'level_spacing_distribution', + + 'MatrixGamma', 'Wishart', 'MatrixNormal', 'MatrixStudentT', + + 'Probability', 'Expectation', 'Variance', 'Covariance', 'Moment', + 'CentralMoment', + + 'ExpectationMatrix', 'VarianceMatrix', 'CrossCovarianceMatrix' + +] +from .rv_interface import (P, E, H, density, where, given, sample, cdf, median, + characteristic_function, pspace, sample_iter, variance, std, skewness, + kurtosis, covariance, dependent, entropy, independent, random_symbols, + correlation, factorial_moment, moment, cmoment, sampling_density, + moment_generating_function, smoment, quantile, coskewness, + sample_stochastic_process) + +from .frv_types import (FiniteRV, DiscreteUniform, Die, Bernoulli, Coin, + Binomial, BetaBinomial, Hypergeometric, Rademacher, + FiniteDistributionHandmade, IdealSoliton, RobustSoliton) + +from .crv_types import (ContinuousRV, Arcsin, Benini, Beta, BetaNoncentral, + BetaPrime, BoundedPareto, Cauchy, Chi, ChiNoncentral, ChiSquared, + Dagum, Davis, Erlang, ExGaussian, Exponential, ExponentialPower, + FDistribution, FisherZ, Frechet, Gamma, GammaInverse, GaussianInverse, + Gompertz, Gumbel, Kumaraswamy, Laplace, Levy, Logistic, LogCauchy, + LogLogistic, LogitNormal, LogNormal, Lomax, Maxwell, Moyal, Nakagami, + Normal, Pareto, QuadraticU, RaisedCosine, Rayleigh, Reciprocal, + StudentT, PowerFunction, ShiftedGompertz, Trapezoidal, Triangular, + Uniform, UniformSum, VonMises, Wald, Weibull, WignerSemicircle, + ContinuousDistributionHandmade) + +from .drv_types import (FlorySchulz, Geometric, Hermite, Logarithmic, NegativeBinomial, Poisson, + Skellam, YuleSimon, Zeta, DiscreteRV, DiscreteDistributionHandmade) + +from .joint_rv_types import (JointRV, Dirichlet, + GeneralizedMultivariateLogGamma, GeneralizedMultivariateLogGammaOmega, + Multinomial, MultivariateBeta, MultivariateEwens, MultivariateT, + NegativeMultinomial, NormalGamma, MultivariateNormal, MultivariateLaplace, + marginal_distribution) + +from .stochastic_process_types import (StochasticProcess, + DiscreteTimeStochasticProcess, DiscreteMarkovChain, + TransitionMatrixOf, StochasticStateSpaceOf, GeneratorMatrixOf, + ContinuousMarkovChain, BernoulliProcess, PoissonProcess, WienerProcess, + GammaProcess) + +from .random_matrix_models import (CircularEnsemble, CircularUnitaryEnsemble, + CircularOrthogonalEnsemble, CircularSymplecticEnsemble, + GaussianEnsemble, GaussianUnitaryEnsemble, GaussianOrthogonalEnsemble, + GaussianSymplecticEnsemble, joint_eigen_distribution, + JointEigenDistribution, level_spacing_distribution) + +from .matrix_distributions import MatrixGamma, Wishart, MatrixNormal, MatrixStudentT + +from .symbolic_probability import (Probability, Expectation, Variance, + Covariance, Moment, CentralMoment) + +from .symbolic_multivariate_probability import (ExpectationMatrix, VarianceMatrix, + CrossCovarianceMatrix) diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/compound_rv.py b/.venv/lib/python3.13/site-packages/sympy/stats/compound_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..27555f4233fe691bac303800a87736205acbdee6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/compound_rv.py @@ -0,0 +1,223 @@ +from sympy.concrete.summations import Sum +from sympy.core.basic import Basic +from sympy.core.function import Lambda +from sympy.core.symbol import Dummy +from sympy.integrals.integrals import Integral +from sympy.stats.rv import (NamedArgsMixin, random_symbols, _symbol_converter, + PSpace, RandomSymbol, is_random, Distribution) +from sympy.stats.crv import ContinuousDistribution, SingleContinuousPSpace +from sympy.stats.drv import DiscreteDistribution, SingleDiscretePSpace +from sympy.stats.frv import SingleFiniteDistribution, SingleFinitePSpace +from sympy.stats.crv_types import ContinuousDistributionHandmade +from sympy.stats.drv_types import DiscreteDistributionHandmade +from sympy.stats.frv_types import FiniteDistributionHandmade + + +class CompoundPSpace(PSpace): + """ + A temporary Probability Space for the Compound Distribution. After + Marginalization, this returns the corresponding Probability Space of the + parent distribution. + """ + + def __new__(cls, s, distribution): + s = _symbol_converter(s) + if isinstance(distribution, ContinuousDistribution): + return SingleContinuousPSpace(s, distribution) + if isinstance(distribution, DiscreteDistribution): + return SingleDiscretePSpace(s, distribution) + if isinstance(distribution, SingleFiniteDistribution): + return SingleFinitePSpace(s, distribution) + if not isinstance(distribution, CompoundDistribution): + raise ValueError("%s should be an isinstance of " + "CompoundDistribution"%(distribution)) + return Basic.__new__(cls, s, distribution) + + @property + def value(self): + return RandomSymbol(self.symbol, self) + + @property + def symbol(self): + return self.args[0] + + @property + def is_Continuous(self): + return self.distribution.is_Continuous + + @property + def is_Finite(self): + return self.distribution.is_Finite + + @property + def is_Discrete(self): + return self.distribution.is_Discrete + + @property + def distribution(self): + return self.args[1] + + @property + def pdf(self): + return self.distribution.pdf(self.symbol) + + @property + def set(self): + return self.distribution.set + + @property + def domain(self): + return self._get_newpspace().domain + + def _get_newpspace(self, evaluate=False): + x = Dummy('x') + parent_dist = self.distribution.args[0] + func = Lambda(x, self.distribution.pdf(x, evaluate)) + new_pspace = self._transform_pspace(self.symbol, parent_dist, func) + if new_pspace is not None: + return new_pspace + message = ("Compound Distribution for %s is not implemented yet" % str(parent_dist)) + raise NotImplementedError(message) + + def _transform_pspace(self, sym, dist, pdf): + """ + This function returns the new pspace of the distribution using handmade + Distributions and their corresponding pspace. + """ + pdf = Lambda(sym, pdf(sym)) + _set = dist.set + if isinstance(dist, ContinuousDistribution): + return SingleContinuousPSpace(sym, ContinuousDistributionHandmade(pdf, _set)) + elif isinstance(dist, DiscreteDistribution): + return SingleDiscretePSpace(sym, DiscreteDistributionHandmade(pdf, _set)) + elif isinstance(dist, SingleFiniteDistribution): + dens = {k: pdf(k) for k in _set} + return SingleFinitePSpace(sym, FiniteDistributionHandmade(dens)) + + def compute_density(self, expr, *, compound_evaluate=True, **kwargs): + new_pspace = self._get_newpspace(compound_evaluate) + expr = expr.subs({self.value: new_pspace.value}) + return new_pspace.compute_density(expr, **kwargs) + + def compute_cdf(self, expr, *, compound_evaluate=True, **kwargs): + new_pspace = self._get_newpspace(compound_evaluate) + expr = expr.subs({self.value: new_pspace.value}) + return new_pspace.compute_cdf(expr, **kwargs) + + def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs): + new_pspace = self._get_newpspace(evaluate) + expr = expr.subs({self.value: new_pspace.value}) + if rvs: + rvs = rvs.subs({self.value: new_pspace.value}) + if isinstance(new_pspace, SingleFinitePSpace): + return new_pspace.compute_expectation(expr, rvs, **kwargs) + return new_pspace.compute_expectation(expr, rvs, evaluate, **kwargs) + + def probability(self, condition, *, compound_evaluate=True, **kwargs): + new_pspace = self._get_newpspace(compound_evaluate) + condition = condition.subs({self.value: new_pspace.value}) + return new_pspace.probability(condition) + + def conditional_space(self, condition, *, compound_evaluate=True, **kwargs): + new_pspace = self._get_newpspace(compound_evaluate) + condition = condition.subs({self.value: new_pspace.value}) + return new_pspace.conditional_space(condition) + + +class CompoundDistribution(Distribution, NamedArgsMixin): + """ + Class for Compound Distributions. + + Parameters + ========== + + dist : Distribution + Distribution must contain a random parameter + + Examples + ======== + + >>> from sympy.stats.compound_rv import CompoundDistribution + >>> from sympy.stats.crv_types import NormalDistribution + >>> from sympy.stats import Normal + >>> from sympy.abc import x + >>> X = Normal('X', 2, 4) + >>> N = NormalDistribution(X, 4) + >>> C = CompoundDistribution(N) + >>> C.set + Interval(-oo, oo) + >>> C.pdf(x, evaluate=True).simplify() + exp(-x**2/64 + x/16 - 1/16)/(8*sqrt(pi)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Compound_probability_distribution + + """ + + def __new__(cls, dist): + if not isinstance(dist, (ContinuousDistribution, + SingleFiniteDistribution, DiscreteDistribution)): + message = "Compound Distribution for %s is not implemented yet" % str(dist) + raise NotImplementedError(message) + if not cls._compound_check(dist): + return dist + return Basic.__new__(cls, dist) + + @property + def set(self): + return self.args[0].set + + @property + def is_Continuous(self): + return isinstance(self.args[0], ContinuousDistribution) + + @property + def is_Finite(self): + return isinstance(self.args[0], SingleFiniteDistribution) + + @property + def is_Discrete(self): + return isinstance(self.args[0], DiscreteDistribution) + + def pdf(self, x, evaluate=False): + dist = self.args[0] + randoms = [rv for rv in dist.args if is_random(rv)] + if isinstance(dist, SingleFiniteDistribution): + y = Dummy('y', integer=True, negative=False) + expr = dist.pmf(y) + else: + y = Dummy('y') + expr = dist.pdf(y) + for rv in randoms: + expr = self._marginalise(expr, rv, evaluate) + return Lambda(y, expr)(x) + + def _marginalise(self, expr, rv, evaluate): + if isinstance(rv.pspace.distribution, SingleFiniteDistribution): + rv_dens = rv.pspace.distribution.pmf(rv) + else: + rv_dens = rv.pspace.distribution.pdf(rv) + rv_dom = rv.pspace.domain.set + if rv.pspace.is_Discrete or rv.pspace.is_Finite: + expr = Sum(expr*rv_dens, (rv, rv_dom._inf, + rv_dom._sup)) + else: + expr = Integral(expr*rv_dens, (rv, rv_dom._inf, + rv_dom._sup)) + if evaluate: + return expr.doit() + return expr + + @classmethod + def _compound_check(self, dist): + """ + Checks if the given distribution contains random parameters. + """ + randoms = [] + for arg in dist.args: + randoms.extend(random_symbols(arg)) + if len(randoms) == 0: + return False + return True diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/crv.py b/.venv/lib/python3.13/site-packages/sympy/stats/crv.py new file mode 100644 index 0000000000000000000000000000000000000000..0a5184029679f663c83d81aa6c1b6ca4d948c70f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/crv.py @@ -0,0 +1,570 @@ +""" +Continuous Random Variables Module + +See Also +======== +sympy.stats.crv_types +sympy.stats.rv +sympy.stats.frv +""" + + +from sympy.core.basic import Basic +from sympy.core.cache import cacheit +from sympy.core.function import Lambda, PoleError +from sympy.core.numbers import (I, nan, oo) +from sympy.core.relational import (Eq, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.core.sympify import _sympify, sympify +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.delta_functions import DiracDelta +from sympy.integrals.integrals import (Integral, integrate) +from sympy.logic.boolalg import (And, Or) +from sympy.polys.polyerrors import PolynomialError +from sympy.polys.polytools import poly +from sympy.series.series import series +from sympy.sets.sets import (FiniteSet, Intersection, Interval, Union) +from sympy.solvers.solveset import solveset +from sympy.solvers.inequalities import reduce_rational_inequalities +from sympy.stats.rv import (RandomDomain, SingleDomain, ConditionalDomain, is_random, + ProductDomain, PSpace, SinglePSpace, random_symbols, NamedArgsMixin, Distribution) + + +class ContinuousDomain(RandomDomain): + """ + A domain with continuous support + + Represented using symbols and Intervals. + """ + is_Continuous = True + + def as_boolean(self): + raise NotImplementedError("Not Implemented for generic Domains") + + +class SingleContinuousDomain(ContinuousDomain, SingleDomain): + """ + A univariate domain with continuous support + + Represented using a single symbol and interval. + """ + def compute_expectation(self, expr, variables=None, **kwargs): + if variables is None: + variables = self.symbols + if not variables: + return expr + if frozenset(variables) != frozenset(self.symbols): + raise ValueError("Values should be equal") + # assumes only intervals + return Integral(expr, (self.symbol, self.set), **kwargs) + + def as_boolean(self): + return self.set.as_relational(self.symbol) + + +class ProductContinuousDomain(ProductDomain, ContinuousDomain): + """ + A collection of independent domains with continuous support + """ + + def compute_expectation(self, expr, variables=None, **kwargs): + if variables is None: + variables = self.symbols + for domain in self.domains: + domain_vars = frozenset(variables) & frozenset(domain.symbols) + if domain_vars: + expr = domain.compute_expectation(expr, domain_vars, **kwargs) + return expr + + def as_boolean(self): + return And(*[domain.as_boolean() for domain in self.domains]) + + +class ConditionalContinuousDomain(ContinuousDomain, ConditionalDomain): + """ + A domain with continuous support that has been further restricted by a + condition such as $x > 3$. + """ + + def compute_expectation(self, expr, variables=None, **kwargs): + if variables is None: + variables = self.symbols + if not variables: + return expr + # Extract the full integral + fullintgrl = self.fulldomain.compute_expectation(expr, variables) + # separate into integrand and limits + integrand, limits = fullintgrl.function, list(fullintgrl.limits) + + conditions = [self.condition] + while conditions: + cond = conditions.pop() + if cond.is_Boolean: + if isinstance(cond, And): + conditions.extend(cond.args) + elif isinstance(cond, Or): + raise NotImplementedError("Or not implemented here") + elif cond.is_Relational: + if cond.is_Equality: + # Add the appropriate Delta to the integrand + integrand *= DiracDelta(cond.lhs - cond.rhs) + else: + symbols = cond.free_symbols & set(self.symbols) + if len(symbols) != 1: # Can't handle x > y + raise NotImplementedError( + "Multivariate Inequalities not yet implemented") + # Can handle x > 0 + symbol = symbols.pop() + # Find the limit with x, such as (x, -oo, oo) + for i, limit in enumerate(limits): + if limit[0] == symbol: + # Make condition into an Interval like [0, oo] + cintvl = reduce_rational_inequalities_wrap( + cond, symbol) + # Make limit into an Interval like [-oo, oo] + lintvl = Interval(limit[1], limit[2]) + # Intersect them to get [0, oo] + intvl = cintvl.intersect(lintvl) + # Put back into limits list + limits[i] = (symbol, intvl.left, intvl.right) + else: + raise TypeError( + "Condition %s is not a relational or Boolean" % cond) + + return Integral(integrand, *limits, **kwargs) + + def as_boolean(self): + return And(self.fulldomain.as_boolean(), self.condition) + + @property + def set(self): + if len(self.symbols) == 1: + return (self.fulldomain.set & reduce_rational_inequalities_wrap( + self.condition, tuple(self.symbols)[0])) + else: + raise NotImplementedError( + "Set of Conditional Domain not Implemented") + + +class ContinuousDistribution(Distribution): + def __call__(self, *args): + return self.pdf(*args) + + +class SingleContinuousDistribution(ContinuousDistribution, NamedArgsMixin): + """ Continuous distribution of a single variable. + + Explanation + =========== + + Serves as superclass for Normal/Exponential/UniformDistribution etc.... + + Represented by parameters for each of the specific classes. E.g + NormalDistribution is represented by a mean and standard deviation. + + Provides methods for pdf, cdf, and sampling. + + See Also + ======== + + sympy.stats.crv_types.* + """ + + set = Interval(-oo, oo) + + def __new__(cls, *args): + args = list(map(sympify, args)) + return Basic.__new__(cls, *args) + + @staticmethod + def check(*args): + pass + + @cacheit + def compute_cdf(self, **kwargs): + """ Compute the CDF from the PDF. + + Returns a Lambda. + """ + x, z = symbols('x, z', real=True, cls=Dummy) + left_bound = self.set.start + + # CDF is integral of PDF from left bound to z + pdf = self.pdf(x) + cdf = integrate(pdf.doit(), (x, left_bound, z), **kwargs) + # CDF Ensure that CDF left of left_bound is zero + cdf = Piecewise((cdf, z >= left_bound), (0, True)) + return Lambda(z, cdf) + + def _cdf(self, x): + return None + + def cdf(self, x, **kwargs): + """ Cumulative density function """ + if len(kwargs) == 0: + cdf = self._cdf(x) + if cdf is not None: + return cdf + return self.compute_cdf(**kwargs)(x) + + @cacheit + def compute_characteristic_function(self, **kwargs): + """ Compute the characteristic function from the PDF. + + Returns a Lambda. + """ + x, t = symbols('x, t', real=True, cls=Dummy) + pdf = self.pdf(x) + cf = integrate(exp(I*t*x)*pdf, (x, self.set)) + return Lambda(t, cf) + + def _characteristic_function(self, t): + return None + + def characteristic_function(self, t, **kwargs): + """ Characteristic function """ + if len(kwargs) == 0: + cf = self._characteristic_function(t) + if cf is not None: + return cf + return self.compute_characteristic_function(**kwargs)(t) + + @cacheit + def compute_moment_generating_function(self, **kwargs): + """ Compute the moment generating function from the PDF. + + Returns a Lambda. + """ + x, t = symbols('x, t', real=True, cls=Dummy) + pdf = self.pdf(x) + mgf = integrate(exp(t * x) * pdf, (x, self.set)) + return Lambda(t, mgf) + + def _moment_generating_function(self, t): + return None + + def moment_generating_function(self, t, **kwargs): + """ Moment generating function """ + if not kwargs: + mgf = self._moment_generating_function(t) + if mgf is not None: + return mgf + return self.compute_moment_generating_function(**kwargs)(t) + + def expectation(self, expr, var, evaluate=True, **kwargs): + """ Expectation of expression over distribution """ + if evaluate: + try: + p = poly(expr, var) + if p.is_zero: + return S.Zero + t = Dummy('t', real=True) + mgf = self._moment_generating_function(t) + if mgf is None: + return integrate(expr * self.pdf(var), (var, self.set), **kwargs) + deg = p.degree() + taylor = poly(series(mgf, t, 0, deg + 1).removeO(), t) + result = 0 + for k in range(deg+1): + result += p.coeff_monomial(var ** k) * taylor.coeff_monomial(t ** k) * factorial(k) + return result + except PolynomialError: + return integrate(expr * self.pdf(var), (var, self.set), **kwargs) + else: + return Integral(expr * self.pdf(var), (var, self.set), **kwargs) + + @cacheit + def compute_quantile(self, **kwargs): + """ Compute the Quantile from the PDF. + + Returns a Lambda. + """ + x, p = symbols('x, p', real=True, cls=Dummy) + left_bound = self.set.start + + pdf = self.pdf(x) + cdf = integrate(pdf, (x, left_bound, x), **kwargs) + quantile = solveset(cdf - p, x, self.set) + return Lambda(p, Piecewise((quantile, (p >= 0) & (p <= 1) ), (nan, True))) + + def _quantile(self, x): + return None + + def quantile(self, x, **kwargs): + """ Cumulative density function """ + if len(kwargs) == 0: + quantile = self._quantile(x) + if quantile is not None: + return quantile + return self.compute_quantile(**kwargs)(x) + + +class ContinuousPSpace(PSpace): + """ Continuous Probability Space + + Represents the likelihood of an event space defined over a continuum. + + Represented with a ContinuousDomain and a PDF (Lambda-Like) + """ + + is_Continuous = True + is_real = True + + @property + def pdf(self): + return self.density(*self.domain.symbols) + + def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs): + if rvs is None: + rvs = self.values + else: + rvs = frozenset(rvs) + + expr = expr.xreplace({rv: rv.symbol for rv in rvs}) + + domain_symbols = frozenset(rv.symbol for rv in rvs) + + return self.domain.compute_expectation(self.pdf * expr, + domain_symbols, **kwargs) + + def compute_density(self, expr, **kwargs): + # Common case Density(X) where X in self.values + if expr in self.values: + # Marginalize all other random symbols out of the density + randomsymbols = tuple(set(self.values) - frozenset([expr])) + symbols = tuple(rs.symbol for rs in randomsymbols) + pdf = self.domain.compute_expectation(self.pdf, symbols, **kwargs) + return Lambda(expr.symbol, pdf) + + z = Dummy('z', real=True) + return Lambda(z, self.compute_expectation(DiracDelta(expr - z), **kwargs)) + + @cacheit + def compute_cdf(self, expr, **kwargs): + if not self.domain.set.is_Interval: + raise ValueError( + "CDF not well defined on multivariate expressions") + + d = self.compute_density(expr, **kwargs) + x, z = symbols('x, z', real=True, cls=Dummy) + left_bound = self.domain.set.start + + # CDF is integral of PDF from left bound to z + cdf = integrate(d(x), (x, left_bound, z), **kwargs) + # CDF Ensure that CDF left of left_bound is zero + cdf = Piecewise((cdf, z >= left_bound), (0, True)) + return Lambda(z, cdf) + + @cacheit + def compute_characteristic_function(self, expr, **kwargs): + if not self.domain.set.is_Interval: + raise NotImplementedError("Characteristic function of multivariate expressions not implemented") + + d = self.compute_density(expr, **kwargs) + x, t = symbols('x, t', real=True, cls=Dummy) + cf = integrate(exp(I*t*x)*d(x), (x, -oo, oo), **kwargs) + return Lambda(t, cf) + + @cacheit + def compute_moment_generating_function(self, expr, **kwargs): + if not self.domain.set.is_Interval: + raise NotImplementedError("Moment generating function of multivariate expressions not implemented") + + d = self.compute_density(expr, **kwargs) + x, t = symbols('x, t', real=True, cls=Dummy) + mgf = integrate(exp(t * x) * d(x), (x, -oo, oo), **kwargs) + return Lambda(t, mgf) + + @cacheit + def compute_quantile(self, expr, **kwargs): + if not self.domain.set.is_Interval: + raise ValueError( + "Quantile not well defined on multivariate expressions") + + d = self.compute_cdf(expr, **kwargs) + x = Dummy('x', real=True) + p = Dummy('p', positive=True) + + quantile = solveset(d(x) - p, x, self.set) + + return Lambda(p, quantile) + + def probability(self, condition, **kwargs): + z = Dummy('z', real=True) + cond_inv = False + if isinstance(condition, Ne): + condition = Eq(condition.args[0], condition.args[1]) + cond_inv = True + # Univariate case can be handled by where + try: + domain = self.where(condition) + rv = [rv for rv in self.values if rv.symbol == domain.symbol][0] + # Integrate out all other random variables + pdf = self.compute_density(rv, **kwargs) + # return S.Zero if `domain` is empty set + if domain.set is S.EmptySet or isinstance(domain.set, FiniteSet): + return S.Zero if not cond_inv else S.One + if isinstance(domain.set, Union): + return sum( + Integral(pdf(z), (z, subset), **kwargs) for subset in + domain.set.args if isinstance(subset, Interval)) + # Integrate out the last variable over the special domain + return Integral(pdf(z), (z, domain.set), **kwargs) + + # Other cases can be turned into univariate case + # by computing a density handled by density computation + except NotImplementedError: + from sympy.stats.rv import density + expr = condition.lhs - condition.rhs + if not is_random(expr): + dens = self.density + comp = condition.rhs + else: + dens = density(expr, **kwargs) + comp = 0 + if not isinstance(dens, ContinuousDistribution): + from sympy.stats.crv_types import ContinuousDistributionHandmade + dens = ContinuousDistributionHandmade(dens, set=self.domain.set) + # Turn problem into univariate case + space = SingleContinuousPSpace(z, dens) + result = space.probability(condition.__class__(space.value, comp)) + return result if not cond_inv else S.One - result + + def where(self, condition): + rvs = frozenset(random_symbols(condition)) + if not (len(rvs) == 1 and rvs.issubset(self.values)): + raise NotImplementedError( + "Multiple continuous random variables not supported") + rv = tuple(rvs)[0] + interval = reduce_rational_inequalities_wrap(condition, rv) + interval = interval.intersect(self.domain.set) + return SingleContinuousDomain(rv.symbol, interval) + + def conditional_space(self, condition, normalize=True, **kwargs): + condition = condition.xreplace({rv: rv.symbol for rv in self.values}) + domain = ConditionalContinuousDomain(self.domain, condition) + if normalize: + # create a clone of the variable to + # make sure that variables in nested integrals are different + # from the variables outside the integral + # this makes sure that they are evaluated separately + # and in the correct order + replacement = {rv: Dummy(str(rv)) for rv in self.symbols} + norm = domain.compute_expectation(self.pdf, **kwargs) + pdf = self.pdf / norm.xreplace(replacement) + # XXX: Converting set to tuple. The order matters to Lambda though + # so we shouldn't be starting with a set here... + density = Lambda(tuple(domain.symbols), pdf) + + return ContinuousPSpace(domain, density) + + +class SingleContinuousPSpace(ContinuousPSpace, SinglePSpace): + """ + A continuous probability space over a single univariate variable. + + These consist of a Symbol and a SingleContinuousDistribution + + This class is normally accessed through the various random variable + functions, Normal, Exponential, Uniform, etc.... + """ + + @property + def set(self): + return self.distribution.set + + @property + def domain(self): + return SingleContinuousDomain(sympify(self.symbol), self.set) + + def sample(self, size=(), library='scipy', seed=None): + """ + Internal sample method. + + Returns dictionary mapping RandomSymbol to realization value. + """ + return {self.value: self.distribution.sample(size, library=library, seed=seed)} + + def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs): + rvs = rvs or (self.value,) + if self.value not in rvs: + return expr + + expr = _sympify(expr) + expr = expr.xreplace({rv: rv.symbol for rv in rvs}) + + x = self.value.symbol + try: + return self.distribution.expectation(expr, x, evaluate=evaluate, **kwargs) + except PoleError: + return Integral(expr * self.pdf, (x, self.set), **kwargs) + + def compute_cdf(self, expr, **kwargs): + if expr == self.value: + z = Dummy("z", real=True) + return Lambda(z, self.distribution.cdf(z, **kwargs)) + else: + return ContinuousPSpace.compute_cdf(self, expr, **kwargs) + + def compute_characteristic_function(self, expr, **kwargs): + if expr == self.value: + t = Dummy("t", real=True) + return Lambda(t, self.distribution.characteristic_function(t, **kwargs)) + else: + return ContinuousPSpace.compute_characteristic_function(self, expr, **kwargs) + + def compute_moment_generating_function(self, expr, **kwargs): + if expr == self.value: + t = Dummy("t", real=True) + return Lambda(t, self.distribution.moment_generating_function(t, **kwargs)) + else: + return ContinuousPSpace.compute_moment_generating_function(self, expr, **kwargs) + + def compute_density(self, expr, **kwargs): + # https://en.wikipedia.org/wiki/Random_variable#Functions_of_random_variables + if expr == self.value: + return self.density + y = Dummy('y', real=True) + + gs = solveset(expr - y, self.value, S.Reals) + + if isinstance(gs, Intersection): + if len(gs.args) == 2 and gs.args[0] is S.Reals: + gs = gs.args[1] + if not gs.is_FiniteSet: + raise ValueError("Can not solve %s for %s" % (expr, self.value)) + fx = self.compute_density(self.value) + fy = sum(fx(g) * abs(g.diff(y)) for g in gs) + return Lambda(y, fy) + + def compute_quantile(self, expr, **kwargs): + + if expr == self.value: + p = Dummy("p", real=True) + return Lambda(p, self.distribution.quantile(p, **kwargs)) + else: + return ContinuousPSpace.compute_quantile(self, expr, **kwargs) + +def _reduce_inequalities(conditions, var, **kwargs): + try: + return reduce_rational_inequalities(conditions, var, **kwargs) + except PolynomialError: + raise ValueError("Reduction of condition failed %s\n" % conditions[0]) + + +def reduce_rational_inequalities_wrap(condition, var): + if condition.is_Relational: + return _reduce_inequalities([[condition]], var, relational=False) + if isinstance(condition, Or): + return Union(*[_reduce_inequalities([[arg]], var, relational=False) + for arg in condition.args]) + if isinstance(condition, And): + intervals = [_reduce_inequalities([[arg]], var, relational=False) + for arg in condition.args] + I = intervals[0] + for i in intervals: + I = I.intersect(i) + return I diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/crv_types.py b/.venv/lib/python3.13/site-packages/sympy/stats/crv_types.py new file mode 100644 index 0000000000000000000000000000000000000000..073e7350fdf80aac39ecd1dd607488a8b76187e3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/crv_types.py @@ -0,0 +1,4732 @@ +""" +Continuous Random Variables - Prebuilt variables + +Contains +======== +Arcsin +Benini +Beta +BetaNoncentral +BetaPrime +BoundedPareto +Cauchy +Chi +ChiNoncentral +ChiSquared +Dagum +Davis +Erlang +ExGaussian +Exponential +ExponentialPower +FDistribution +FisherZ +Frechet +Gamma +GammaInverse +Gumbel +Gompertz +Kumaraswamy +Laplace +Levy +LogCauchy +Logistic +LogLogistic +LogitNormal +LogNormal +Lomax +Maxwell +Moyal +Nakagami +Normal +Pareto +PowerFunction +QuadraticU +RaisedCosine +Rayleigh +Reciprocal +ShiftedGompertz +StudentT +Trapezoidal +Triangular +Uniform +UniformSum +VonMises +Wald +Weibull +WignerSemicircle +""" + + + +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.trigonometric import (atan, cos, sin, tan) +from sympy.functions.special.bessel import (besseli, besselj, besselk) +from sympy.functions.special.beta_functions import beta as beta_fn +from sympy.concrete.summations import Sum +from sympy.core.basic import Basic +from sympy.core.function import Lambda +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.relational import (Eq, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import (binomial, factorial) +from sympy.functions.elementary.complexes import (Abs, sign) +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.hyperbolic import sinh +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import sqrt, Max, Min +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import asin +from sympy.functions.special.error_functions import (erf, erfc, erfi, erfinv, expint) +from sympy.functions.special.gamma_functions import (gamma, lowergamma, uppergamma) +from sympy.functions.special.zeta_functions import zeta +from sympy.functions.special.hyper import hyper +from sympy.integrals.integrals import integrate +from sympy.logic.boolalg import And +from sympy.sets.sets import Interval +from sympy.matrices import MatrixBase +from sympy.stats.crv import SingleContinuousPSpace, SingleContinuousDistribution +from sympy.stats.rv import _value_check, is_random + +oo = S.Infinity + +__all__ = ['ContinuousRV', +'Arcsin', +'Benini', +'Beta', +'BetaNoncentral', +'BetaPrime', +'BoundedPareto', +'Cauchy', +'Chi', +'ChiNoncentral', +'ChiSquared', +'Dagum', +'Davis', +'Erlang', +'ExGaussian', +'Exponential', +'ExponentialPower', +'FDistribution', +'FisherZ', +'Frechet', +'Gamma', +'GammaInverse', +'Gompertz', +'Gumbel', +'Kumaraswamy', +'Laplace', +'Levy', +'LogCauchy', +'Logistic', +'LogLogistic', +'LogitNormal', +'LogNormal', +'Lomax', +'Maxwell', +'Moyal', +'Nakagami', +'Normal', +'GaussianInverse', +'Pareto', +'PowerFunction', +'QuadraticU', +'RaisedCosine', +'Rayleigh', +'Reciprocal', +'StudentT', +'ShiftedGompertz', +'Trapezoidal', +'Triangular', +'Uniform', +'UniformSum', +'VonMises', +'Wald', +'Weibull', +'WignerSemicircle', +] + + +@is_random.register(MatrixBase) +def _(x): + return any(is_random(i) for i in x) + +def rv(symbol, cls, args, **kwargs): + args = list(map(sympify, args)) + dist = cls(*args) + if kwargs.pop('check', True): + dist.check(*args) + pspace = SingleContinuousPSpace(symbol, dist) + if any(is_random(arg) for arg in args): + from sympy.stats.compound_rv import CompoundPSpace, CompoundDistribution + pspace = CompoundPSpace(symbol, CompoundDistribution(dist)) + return pspace.value + + +class ContinuousDistributionHandmade(SingleContinuousDistribution): + _argnames = ('pdf',) + + def __new__(cls, pdf, set=Interval(-oo, oo)): + return Basic.__new__(cls, pdf, set) + + @property + def set(self): + return self.args[1] + + @staticmethod + def check(pdf, set): + x = Dummy('x') + val = integrate(pdf(x), (x, set)) + _value_check(Eq(val, 1) != S.false, "The pdf on the given set is incorrect.") + + +def ContinuousRV(symbol, density, set=Interval(-oo, oo), **kwargs): + """ + Create a Continuous Random Variable given the following: + + Parameters + ========== + + symbol : Symbol + Represents name of the random variable. + density : Expression containing symbol + Represents probability density function. + set : set/Interval + Represents the region where the pdf is valid, by default is real line. + check : bool + If True, it will check whether the given density + integrates to 1 over the given set. If False, it + will not perform this check. Default is False. + + + Returns + ======= + + RandomSymbol + + Many common continuous random variable types are already implemented. + This function should be necessary only very rarely. + + + Examples + ======== + + >>> from sympy import Symbol, sqrt, exp, pi + >>> from sympy.stats import ContinuousRV, P, E + + >>> x = Symbol("x") + + >>> pdf = sqrt(2)*exp(-x**2/2)/(2*sqrt(pi)) # Normal distribution + >>> X = ContinuousRV(x, pdf) + + >>> E(X) + 0 + >>> P(X>0) + 1/2 + """ + pdf = Piecewise((density, set.as_relational(symbol)), (0, True)) + pdf = Lambda(symbol, pdf) + # have a default of False while `rv` should have a default of True + kwargs['check'] = kwargs.pop('check', False) + return rv(symbol.name, ContinuousDistributionHandmade, (pdf, set), **kwargs) + +######################################## +# Continuous Probability Distributions # +######################################## + +#------------------------------------------------------------------------------- +# Arcsin distribution ---------------------------------------------------------- + + +class ArcsinDistribution(SingleContinuousDistribution): + _argnames = ('a', 'b') + + @property + def set(self): + return Interval(self.a, self.b) + + def pdf(self, x): + a, b = self.a, self.b + return 1/(pi*sqrt((x - a)*(b - x))) + + def _cdf(self, x): + a, b = self.a, self.b + return Piecewise( + (S.Zero, x < a), + (2*asin(sqrt((x - a)/(b - a)))/pi, x <= b), + (S.One, True)) + + +def Arcsin(name, a=0, b=1): + r""" + Create a Continuous Random Variable with an arcsin distribution. + + The density of the arcsin distribution is given by + + .. math:: + f(x) := \frac{1}{\pi\sqrt{(x-a)(b-x)}} + + with :math:`x \in (a,b)`. It must hold that :math:`-\infty < a < b < \infty`. + + Parameters + ========== + + a : Real number, the left interval boundary + b : Real number, the right interval boundary + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Arcsin, density, cdf + >>> from sympy import Symbol + + >>> a = Symbol("a", real=True) + >>> b = Symbol("b", real=True) + >>> z = Symbol("z") + + >>> X = Arcsin("x", a, b) + + >>> density(X)(z) + 1/(pi*sqrt((-a + z)*(b - z))) + + >>> cdf(X)(z) + Piecewise((0, a > z), + (2*asin(sqrt((-a + z)/(-a + b)))/pi, b >= z), + (1, True)) + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Arcsine_distribution + + """ + + return rv(name, ArcsinDistribution, (a, b)) + +#------------------------------------------------------------------------------- +# Benini distribution ---------------------------------------------------------- + + +class BeniniDistribution(SingleContinuousDistribution): + _argnames = ('alpha', 'beta', 'sigma') + + @staticmethod + def check(alpha, beta, sigma): + _value_check(alpha > 0, "Shape parameter Alpha must be positive.") + _value_check(beta > 0, "Shape parameter Beta must be positive.") + _value_check(sigma > 0, "Scale parameter Sigma must be positive.") + + @property + def set(self): + return Interval(self.sigma, oo) + + def pdf(self, x): + alpha, beta, sigma = self.alpha, self.beta, self.sigma + return (exp(-alpha*log(x/sigma) - beta*log(x/sigma)**2) + *(alpha/x + 2*beta*log(x/sigma)/x)) + + def _moment_generating_function(self, t): + raise NotImplementedError('The moment generating function of the ' + 'Benini distribution does not exist.') + +def Benini(name, alpha, beta, sigma): + r""" + Create a Continuous Random Variable with a Benini distribution. + + The density of the Benini distribution is given by + + .. math:: + f(x) := e^{-\alpha\log{\frac{x}{\sigma}} + -\beta\log^2\left[{\frac{x}{\sigma}}\right]} + \left(\frac{\alpha}{x}+\frac{2\beta\log{\frac{x}{\sigma}}}{x}\right) + + This is a heavy-tailed distribution and is also known as the log-Rayleigh + distribution. + + Parameters + ========== + + alpha : Real number, `\alpha > 0`, a shape + beta : Real number, `\beta > 0`, a shape + sigma : Real number, `\sigma > 0`, a scale + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Benini, density, cdf + >>> from sympy import Symbol, pprint + + >>> alpha = Symbol("alpha", positive=True) + >>> beta = Symbol("beta", positive=True) + >>> sigma = Symbol("sigma", positive=True) + >>> z = Symbol("z") + + >>> X = Benini("x", alpha, beta, sigma) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + / / z \\ / z \ 2/ z \ + | 2*beta*log|-----|| - alpha*log|-----| - beta*log |-----| + |alpha \sigma/| \sigma/ \sigma/ + |----- + -----------------|*e + \ z z / + + >>> cdf(X)(z) + Piecewise((1 - exp(-alpha*log(z/sigma) - beta*log(z/sigma)**2), sigma <= z), + (0, True)) + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Benini_distribution + .. [2] https://reference.wolfram.com/legacy/v8/ref/BeniniDistribution.html + + """ + + return rv(name, BeniniDistribution, (alpha, beta, sigma)) + +#------------------------------------------------------------------------------- +# Beta distribution ------------------------------------------------------------ + + +class BetaDistribution(SingleContinuousDistribution): + _argnames = ('alpha', 'beta') + + set = Interval(0, 1) + + @staticmethod + def check(alpha, beta): + _value_check(alpha > 0, "Shape parameter Alpha must be positive.") + _value_check(beta > 0, "Shape parameter Beta must be positive.") + + def pdf(self, x): + alpha, beta = self.alpha, self.beta + return x**(alpha - 1) * (1 - x)**(beta - 1) / beta_fn(alpha, beta) + + def _characteristic_function(self, t): + return hyper((self.alpha,), (self.alpha + self.beta,), I*t) + + def _moment_generating_function(self, t): + return hyper((self.alpha,), (self.alpha + self.beta,), t) + + +def Beta(name, alpha, beta): + r""" + Create a Continuous Random Variable with a Beta distribution. + + The density of the Beta distribution is given by + + .. math:: + f(x) := \frac{x^{\alpha-1}(1-x)^{\beta-1}} {\mathrm{B}(\alpha,\beta)} + + with :math:`x \in [0,1]`. + + Parameters + ========== + + alpha : Real number, `\alpha > 0`, a shape + beta : Real number, `\beta > 0`, a shape + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Beta, density, E, variance + >>> from sympy import Symbol, simplify, pprint, factor + + >>> alpha = Symbol("alpha", positive=True) + >>> beta = Symbol("beta", positive=True) + >>> z = Symbol("z") + + >>> X = Beta("x", alpha, beta) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + alpha - 1 beta - 1 + z *(1 - z) + -------------------------- + B(alpha, beta) + + >>> simplify(E(X)) + alpha/(alpha + beta) + + >>> factor(simplify(variance(X))) + alpha*beta/((alpha + beta)**2*(alpha + beta + 1)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Beta_distribution + .. [2] https://mathworld.wolfram.com/BetaDistribution.html + + """ + + return rv(name, BetaDistribution, (alpha, beta)) + +#------------------------------------------------------------------------------- +# Noncentral Beta distribution ------------------------------------------------------------ + + +class BetaNoncentralDistribution(SingleContinuousDistribution): + _argnames = ('alpha', 'beta', 'lamda') + + set = Interval(0, 1) + + @staticmethod + def check(alpha, beta, lamda): + _value_check(alpha > 0, "Shape parameter Alpha must be positive.") + _value_check(beta > 0, "Shape parameter Beta must be positive.") + _value_check(lamda >= 0, "Noncentrality parameter Lambda must be positive") + + def pdf(self, x): + alpha, beta, lamda = self.alpha, self.beta, self.lamda + k = Dummy("k") + return Sum(exp(-lamda / 2) * (lamda / 2)**k * x**(alpha + k - 1) *( + 1 - x)**(beta - 1) / (factorial(k) * beta_fn(alpha + k, beta)), (k, 0, oo)) + +def BetaNoncentral(name, alpha, beta, lamda): + r""" + Create a Continuous Random Variable with a Type I Noncentral Beta distribution. + + The density of the Noncentral Beta distribution is given by + + .. math:: + f(x) := \sum_{k=0}^\infty e^{-\lambda/2}\frac{(\lambda/2)^k}{k!} + \frac{x^{\alpha+k-1}(1-x)^{\beta-1}}{\mathrm{B}(\alpha+k,\beta)} + + with :math:`x \in [0,1]`. + + Parameters + ========== + + alpha : Real number, `\alpha > 0`, a shape + beta : Real number, `\beta > 0`, a shape + lamda : Real number, `\lambda \geq 0`, noncentrality parameter + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import BetaNoncentral, density, cdf + >>> from sympy import Symbol, pprint + + >>> alpha = Symbol("alpha", positive=True) + >>> beta = Symbol("beta", positive=True) + >>> lamda = Symbol("lamda", nonnegative=True) + >>> z = Symbol("z") + + >>> X = BetaNoncentral("x", alpha, beta, lamda) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + oo + _____ + \ ` + \ -lamda + \ k ------- + \ k + alpha - 1 /lamda\ beta - 1 2 + ) z *|-----| *(1 - z) *e + / \ 2 / + / ------------------------------------------------ + / B(k + alpha, beta)*k! + /____, + k = 0 + + Compute cdf with specific 'x', 'alpha', 'beta' and 'lamda' values as follows: + + >>> cdf(BetaNoncentral("x", 1, 1, 1), evaluate=False)(2).doit() + 2*exp(1/2) + + The argument evaluate=False prevents an attempt at evaluation + of the sum for general x, before the argument 2 is passed. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Noncentral_beta_distribution + .. [2] https://reference.wolfram.com/language/ref/NoncentralBetaDistribution.html + + """ + + return rv(name, BetaNoncentralDistribution, (alpha, beta, lamda)) + + +#------------------------------------------------------------------------------- +# Beta prime distribution ------------------------------------------------------ + + +class BetaPrimeDistribution(SingleContinuousDistribution): + _argnames = ('alpha', 'beta') + + @staticmethod + def check(alpha, beta): + _value_check(alpha > 0, "Shape parameter Alpha must be positive.") + _value_check(beta > 0, "Shape parameter Beta must be positive.") + + set = Interval(0, oo) + + def pdf(self, x): + alpha, beta = self.alpha, self.beta + return x**(alpha - 1)*(1 + x)**(-alpha - beta)/beta_fn(alpha, beta) + +def BetaPrime(name, alpha, beta): + r""" + Create a continuous random variable with a Beta prime distribution. + + The density of the Beta prime distribution is given by + + .. math:: + f(x) := \frac{x^{\alpha-1} (1+x)^{-\alpha -\beta}}{B(\alpha,\beta)} + + with :math:`x > 0`. + + Parameters + ========== + + alpha : Real number, `\alpha > 0`, a shape + beta : Real number, `\beta > 0`, a shape + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import BetaPrime, density + >>> from sympy import Symbol, pprint + + >>> alpha = Symbol("alpha", positive=True) + >>> beta = Symbol("beta", positive=True) + >>> z = Symbol("z") + + >>> X = BetaPrime("x", alpha, beta) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + alpha - 1 -alpha - beta + z *(z + 1) + ------------------------------- + B(alpha, beta) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Beta_prime_distribution + .. [2] https://mathworld.wolfram.com/BetaPrimeDistribution.html + + """ + + return rv(name, BetaPrimeDistribution, (alpha, beta)) + +#------------------------------------------------------------------------------- +# Bounded Pareto Distribution -------------------------------------------------- +class BoundedParetoDistribution(SingleContinuousDistribution): + _argnames = ('alpha', 'left', 'right') + + @property + def set(self): + return Interval(self.left, self.right) + + @staticmethod + def check(alpha, left, right): + _value_check (alpha.is_positive, "Shape must be positive.") + _value_check (left.is_positive, "Left value should be positive.") + _value_check (right > left, "Right should be greater than left.") + + def pdf(self, x): + alpha, left, right = self.alpha, self.left, self.right + num = alpha * (left**alpha) * x**(- alpha -1) + den = 1 - (left/right)**alpha + return num/den + +def BoundedPareto(name, alpha, left, right): + r""" + Create a continuous random variable with a Bounded Pareto distribution. + + The density of the Bounded Pareto distribution is given by + + .. math:: + f(x) := \frac{\alpha L^{\alpha}x^{-\alpha-1}}{1-(\frac{L}{H})^{\alpha}} + + Parameters + ========== + + alpha : Real Number, `\alpha > 0` + Shape parameter + left : Real Number, `left > 0` + Location parameter + right : Real Number, `right > left` + Location parameter + + Examples + ======== + + >>> from sympy.stats import BoundedPareto, density, cdf, E + >>> from sympy import symbols + >>> L, H = symbols('L, H', positive=True) + >>> X = BoundedPareto('X', 2, L, H) + >>> x = symbols('x') + >>> density(X)(x) + 2*L**2/(x**3*(1 - L**2/H**2)) + >>> cdf(X)(x) + Piecewise((-H**2*L**2/(x**2*(H**2 - L**2)) + H**2/(H**2 - L**2), L <= x), (0, True)) + >>> E(X).simplify() + 2*H*L/(H + L) + + Returns + ======= + + RandomSymbol + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Pareto_distribution#Bounded_Pareto_distribution + + """ + return rv (name, BoundedParetoDistribution, (alpha, left, right)) + +# ------------------------------------------------------------------------------ +# Cauchy distribution ---------------------------------------------------------- + + +class CauchyDistribution(SingleContinuousDistribution): + _argnames = ('x0', 'gamma') + + @staticmethod + def check(x0, gamma): + _value_check(gamma > 0, "Scale parameter Gamma must be positive.") + _value_check(x0.is_real, "Location parameter must be real.") + + def pdf(self, x): + return 1/(pi*self.gamma*(1 + ((x - self.x0)/self.gamma)**2)) + + def _cdf(self, x): + x0, gamma = self.x0, self.gamma + return (1/pi)*atan((x - x0)/gamma) + S.Half + + def _characteristic_function(self, t): + return exp(self.x0 * I * t - self.gamma * Abs(t)) + + def _moment_generating_function(self, t): + raise NotImplementedError("The moment generating function for the " + "Cauchy distribution does not exist.") + + def _quantile(self, p): + return self.x0 + self.gamma*tan(pi*(p - S.Half)) + + +def Cauchy(name, x0, gamma): + r""" + Create a continuous random variable with a Cauchy distribution. + + The density of the Cauchy distribution is given by + + .. math:: + f(x) := \frac{1}{\pi \gamma [1 + {(\frac{x-x_0}{\gamma})}^2]} + + Parameters + ========== + + x0 : Real number, the location + gamma : Real number, `\gamma > 0`, a scale + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Cauchy, density + >>> from sympy import Symbol + + >>> x0 = Symbol("x0") + >>> gamma = Symbol("gamma", positive=True) + >>> z = Symbol("z") + + >>> X = Cauchy("x", x0, gamma) + + >>> density(X)(z) + 1/(pi*gamma*(1 + (-x0 + z)**2/gamma**2)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Cauchy_distribution + .. [2] https://mathworld.wolfram.com/CauchyDistribution.html + + """ + + return rv(name, CauchyDistribution, (x0, gamma)) + +#------------------------------------------------------------------------------- +# Chi distribution ------------------------------------------------------------- + + +class ChiDistribution(SingleContinuousDistribution): + _argnames = ('k',) + + @staticmethod + def check(k): + _value_check(k > 0, "Number of degrees of freedom (k) must be positive.") + _value_check(k.is_integer, "Number of degrees of freedom (k) must be an integer.") + + set = Interval(0, oo) + + def pdf(self, x): + return 2**(1 - self.k/2)*x**(self.k - 1)*exp(-x**2/2)/gamma(self.k/2) + + def _characteristic_function(self, t): + k = self.k + + part_1 = hyper((k/2,), (S.Half,), -t**2/2) + part_2 = I*t*sqrt(2)*gamma((k+1)/2)/gamma(k/2) + part_3 = hyper(((k+1)/2,), (Rational(3, 2),), -t**2/2) + return part_1 + part_2*part_3 + + def _moment_generating_function(self, t): + k = self.k + + part_1 = hyper((k / 2,), (S.Half,), t ** 2 / 2) + part_2 = t * sqrt(2) * gamma((k + 1) / 2) / gamma(k / 2) + part_3 = hyper(((k + 1) / 2,), (S(3) / 2,), t ** 2 / 2) + return part_1 + part_2 * part_3 + +def Chi(name, k): + r""" + Create a continuous random variable with a Chi distribution. + + The density of the Chi distribution is given by + + .. math:: + f(x) := \frac{2^{1-k/2}x^{k-1}e^{-x^2/2}}{\Gamma(k/2)} + + with :math:`x \geq 0`. + + Parameters + ========== + + k : Positive integer, The number of degrees of freedom + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Chi, density, E + >>> from sympy import Symbol, simplify + + >>> k = Symbol("k", integer=True) + >>> z = Symbol("z") + + >>> X = Chi("x", k) + + >>> density(X)(z) + 2**(1 - k/2)*z**(k - 1)*exp(-z**2/2)/gamma(k/2) + + >>> simplify(E(X)) + sqrt(2)*gamma(k/2 + 1/2)/gamma(k/2) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Chi_distribution + .. [2] https://mathworld.wolfram.com/ChiDistribution.html + + """ + + return rv(name, ChiDistribution, (k,)) + +#------------------------------------------------------------------------------- +# Non-central Chi distribution ------------------------------------------------- + + +class ChiNoncentralDistribution(SingleContinuousDistribution): + _argnames = ('k', 'l') + + @staticmethod + def check(k, l): + _value_check(k > 0, "Number of degrees of freedom (k) must be positive.") + _value_check(k.is_integer, "Number of degrees of freedom (k) must be an integer.") + _value_check(l > 0, "Shift parameter Lambda must be positive.") + + set = Interval(0, oo) + + def pdf(self, x): + k, l = self.k, self.l + return exp(-(x**2+l**2)/2)*x**k*l / (l*x)**(k/2) * besseli(k/2-1, l*x) + +def ChiNoncentral(name, k, l): + r""" + Create a continuous random variable with a non-central Chi distribution. + + Explanation + =========== + + The density of the non-central Chi distribution is given by + + .. math:: + f(x) := \frac{e^{-(x^2+\lambda^2)/2} x^k\lambda} + {(\lambda x)^{k/2}} I_{k/2-1}(\lambda x) + + with `x \geq 0`. Here, `I_\nu (x)` is the + :ref:`modified Bessel function of the first kind `. + + Parameters + ========== + + k : A positive Integer, $k > 0$ + The number of degrees of freedom. + lambda : Real number, `\lambda > 0` + Shift parameter. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import ChiNoncentral, density + >>> from sympy import Symbol + + >>> k = Symbol("k", integer=True) + >>> l = Symbol("l") + >>> z = Symbol("z") + + >>> X = ChiNoncentral("x", k, l) + + >>> density(X)(z) + l*z**k*exp(-l**2/2 - z**2/2)*besseli(k/2 - 1, l*z)/(l*z)**(k/2) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Noncentral_chi_distribution + """ + + return rv(name, ChiNoncentralDistribution, (k, l)) + +#------------------------------------------------------------------------------- +# Chi squared distribution ----------------------------------------------------- + + +class ChiSquaredDistribution(SingleContinuousDistribution): + _argnames = ('k',) + + @staticmethod + def check(k): + _value_check(k > 0, "Number of degrees of freedom (k) must be positive.") + _value_check(k.is_integer, "Number of degrees of freedom (k) must be an integer.") + + set = Interval(0, oo) + + def pdf(self, x): + k = self.k + return 1/(2**(k/2)*gamma(k/2))*x**(k/2 - 1)*exp(-x/2) + + def _cdf(self, x): + k = self.k + return Piecewise( + (S.One/gamma(k/2)*lowergamma(k/2, x/2), x >= 0), + (0, True) + ) + + def _characteristic_function(self, t): + return (1 - 2*I*t)**(-self.k/2) + + def _moment_generating_function(self, t): + return (1 - 2*t)**(-self.k/2) + + +def ChiSquared(name, k): + r""" + Create a continuous random variable with a Chi-squared distribution. + + Explanation + =========== + + The density of the Chi-squared distribution is given by + + .. math:: + f(x) := \frac{1}{2^{\frac{k}{2}}\Gamma\left(\frac{k}{2}\right)} + x^{\frac{k}{2}-1} e^{-\frac{x}{2}} + + with :math:`x \geq 0`. + + Parameters + ========== + + k : Positive integer + The number of degrees of freedom. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import ChiSquared, density, E, variance, moment + >>> from sympy import Symbol + + >>> k = Symbol("k", integer=True, positive=True) + >>> z = Symbol("z") + + >>> X = ChiSquared("x", k) + + >>> density(X)(z) + z**(k/2 - 1)*exp(-z/2)/(2**(k/2)*gamma(k/2)) + + >>> E(X) + k + + >>> variance(X) + 2*k + + >>> moment(X, 3) + k**3 + 6*k**2 + 8*k + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Chi_squared_distribution + .. [2] https://mathworld.wolfram.com/Chi-SquaredDistribution.html + """ + + return rv(name, ChiSquaredDistribution, (k, )) + +#------------------------------------------------------------------------------- +# Dagum distribution ----------------------------------------------------------- + + +class DagumDistribution(SingleContinuousDistribution): + _argnames = ('p', 'a', 'b') + + set = Interval(0, oo) + + @staticmethod + def check(p, a, b): + _value_check(p > 0, "Shape parameter p must be positive.") + _value_check(a > 0, "Shape parameter a must be positive.") + _value_check(b > 0, "Scale parameter b must be positive.") + + def pdf(self, x): + p, a, b = self.p, self.a, self.b + return a*p/x*((x/b)**(a*p)/(((x/b)**a + 1)**(p + 1))) + + def _cdf(self, x): + p, a, b = self.p, self.a, self.b + return Piecewise(((S.One + (S(x)/b)**-a)**-p, x>=0), + (S.Zero, True)) + +def Dagum(name, p, a, b): + r""" + Create a continuous random variable with a Dagum distribution. + + Explanation + =========== + + The density of the Dagum distribution is given by + + .. math:: + f(x) := \frac{a p}{x} \left( \frac{\left(\tfrac{x}{b}\right)^{a p}} + {\left(\left(\tfrac{x}{b}\right)^a + 1 \right)^{p+1}} \right) + + with :math:`x > 0`. + + Parameters + ========== + + p : Real number + `p > 0`, a shape. + a : Real number + `a > 0`, a shape. + b : Real number + `b > 0`, a scale. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Dagum, density, cdf + >>> from sympy import Symbol + + >>> p = Symbol("p", positive=True) + >>> a = Symbol("a", positive=True) + >>> b = Symbol("b", positive=True) + >>> z = Symbol("z") + + >>> X = Dagum("x", p, a, b) + + >>> density(X)(z) + a*p*(z/b)**(a*p)*((z/b)**a + 1)**(-p - 1)/z + + >>> cdf(X)(z) + Piecewise(((1 + (z/b)**(-a))**(-p), z >= 0), (0, True)) + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Dagum_distribution + + """ + + return rv(name, DagumDistribution, (p, a, b)) + +#------------------------------------------------------------------------------- +# Davis distribution ----------------------------------------------------------- + +class DavisDistribution(SingleContinuousDistribution): + _argnames = ('b', 'n', 'mu') + + set = Interval(0, oo) + + @staticmethod + def check(b, n, mu): + _value_check(b > 0, "Scale parameter b must be positive.") + _value_check(n > 1, "Shape parameter n must be above 1.") + _value_check(mu > 0, "Location parameter mu must be positive.") + + def pdf(self, x): + b, n, mu = self.b, self.n, self.mu + dividend = b**n*(x - mu)**(-1-n) + divisor = (exp(b/(x-mu))-1)*(gamma(n)*zeta(n)) + return dividend/divisor + + +def Davis(name, b, n, mu): + r""" Create a continuous random variable with Davis distribution. + + Explanation + =========== + + The density of Davis distribution is given by + + .. math:: + f(x; \mu; b, n) := \frac{b^{n}(x - \mu)^{1-n}}{ \left( e^{\frac{b}{x-\mu}} - 1 \right) \Gamma(n)\zeta(n)} + + with :math:`x \in [0,\infty]`. + + Davis distribution is a generalization of the Planck's law of radiation from statistical physics. It is used for modeling income distribution. + + Parameters + ========== + b : Real number + `p > 0`, a scale. + n : Real number + `n > 1`, a shape. + mu : Real number + `mu > 0`, a location. + + Returns + ======= + + RandomSymbol + + Examples + ======== + >>> from sympy.stats import Davis, density + >>> from sympy import Symbol + >>> b = Symbol("b", positive=True) + >>> n = Symbol("n", positive=True) + >>> mu = Symbol("mu", positive=True) + >>> z = Symbol("z") + >>> X = Davis("x", b, n, mu) + >>> density(X)(z) + b**n*(-mu + z)**(-n - 1)/((exp(b/(-mu + z)) - 1)*gamma(n)*zeta(n)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Davis_distribution + .. [2] https://reference.wolfram.com/language/ref/DavisDistribution.html + + """ + return rv(name, DavisDistribution, (b, n, mu)) + + +#------------------------------------------------------------------------------- +# Erlang distribution ---------------------------------------------------------- + + +def Erlang(name, k, l): + r""" + Create a continuous random variable with an Erlang distribution. + + Explanation + =========== + + The density of the Erlang distribution is given by + + .. math:: + f(x) := \frac{\lambda^k x^{k-1} e^{-\lambda x}}{(k-1)!} + + with :math:`x \in [0,\infty]`. + + Parameters + ========== + + k : Positive integer + l : Real number, `\lambda > 0`, the rate + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Erlang, density, cdf, E, variance + >>> from sympy import Symbol, simplify, pprint + + >>> k = Symbol("k", integer=True, positive=True) + >>> l = Symbol("l", positive=True) + >>> z = Symbol("z") + + >>> X = Erlang("x", k, l) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + k k - 1 -l*z + l *z *e + --------------- + Gamma(k) + + >>> C = cdf(X)(z) + >>> pprint(C, use_unicode=False) + /lowergamma(k, l*z) + |------------------ for z > 0 + < Gamma(k) + | + \ 0 otherwise + + + >>> E(X) + k/l + + >>> simplify(variance(X)) + k/l**2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Erlang_distribution + .. [2] https://mathworld.wolfram.com/ErlangDistribution.html + + """ + + return rv(name, GammaDistribution, (k, S.One/l)) + +# ------------------------------------------------------------------------------- +# ExGaussian distribution ----------------------------------------------------- + + +class ExGaussianDistribution(SingleContinuousDistribution): + _argnames = ('mean', 'std', 'rate') + + set = Interval(-oo, oo) + + @staticmethod + def check(mean, std, rate): + _value_check( + std > 0, "Standard deviation of ExGaussian must be positive.") + _value_check(rate > 0, "Rate of ExGaussian must be positive.") + + def pdf(self, x): + mean, std, rate = self.mean, self.std, self.rate + term1 = rate/2 + term2 = exp(rate * (2 * mean + rate * std**2 - 2*x)/2) + term3 = erfc((mean + rate*std**2 - x)/(sqrt(2)*std)) + return term1*term2*term3 + + def _cdf(self, x): + from sympy.stats import cdf + mean, std, rate = self.mean, self.std, self.rate + u = rate*(x - mean) + v = rate*std + GaussianCDF1 = cdf(Normal('x', 0, v))(u) + GaussianCDF2 = cdf(Normal('x', v**2, v))(u) + + return GaussianCDF1 - exp(-u + (v**2/2) + log(GaussianCDF2)) + + def _characteristic_function(self, t): + mean, std, rate = self.mean, self.std, self.rate + term1 = (1 - I*t/rate)**(-1) + term2 = exp(I*mean*t - std**2*t**2/2) + return term1 * term2 + + def _moment_generating_function(self, t): + mean, std, rate = self.mean, self.std, self.rate + term1 = (1 - t/rate)**(-1) + term2 = exp(mean*t + std**2*t**2/2) + return term1*term2 + + +def ExGaussian(name, mean, std, rate): + r""" + Create a continuous random variable with an Exponentially modified + Gaussian (EMG) distribution. + + Explanation + =========== + + The density of the exponentially modified Gaussian distribution is given by + + .. math:: + f(x) := \frac{\lambda}{2}e^{\frac{\lambda}{2}(2\mu+\lambda\sigma^2-2x)} + \text{erfc}(\frac{\mu + \lambda\sigma^2 - x}{\sqrt{2}\sigma}) + + with $x > 0$. Note that the expected value is `1/\lambda`. + + Parameters + ========== + + name : A string giving a name for this distribution + mean : A Real number, the mean of Gaussian component + std : A positive Real number, + :math: `\sigma^2 > 0` the variance of Gaussian component + rate : A positive Real number, + :math: `\lambda > 0` the rate of Exponential component + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import ExGaussian, density, cdf, E + >>> from sympy.stats import variance, skewness + >>> from sympy import Symbol, pprint, simplify + + >>> mean = Symbol("mu") + >>> std = Symbol("sigma", positive=True) + >>> rate = Symbol("lamda", positive=True) + >>> z = Symbol("z") + >>> X = ExGaussian("x", mean, std, rate) + + >>> pprint(density(X)(z), use_unicode=False) + / 2 \ + lamda*\lamda*sigma + 2*mu - 2*z/ + --------------------------------- / ___ / 2 \\ + 2 |\/ 2 *\lamda*sigma + mu - z/| + lamda*e *erfc|-----------------------------| + \ 2*sigma / + ---------------------------------------------------------------------------- + 2 + + >>> cdf(X)(z) + -(erf(sqrt(2)*(-lamda**2*sigma**2 + lamda*(-mu + z))/(2*lamda*sigma))/2 + 1/2)*exp(lamda**2*sigma**2/2 - lamda*(-mu + z)) + erf(sqrt(2)*(-mu + z)/(2*sigma))/2 + 1/2 + + >>> E(X) + (lamda*mu + 1)/lamda + + >>> simplify(variance(X)) + sigma**2 + lamda**(-2) + + >>> simplify(skewness(X)) + 2/(lamda**2*sigma**2 + 1)**(3/2) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Exponentially_modified_Gaussian_distribution + """ + return rv(name, ExGaussianDistribution, (mean, std, rate)) + +#------------------------------------------------------------------------------- +# Exponential distribution ----------------------------------------------------- + + +class ExponentialDistribution(SingleContinuousDistribution): + _argnames = ('rate',) + + set = Interval(0, oo) + + @staticmethod + def check(rate): + _value_check(rate > 0, "Rate must be positive.") + + def pdf(self, x): + return self.rate * exp(-self.rate*x) + + def _cdf(self, x): + return Piecewise( + (S.One - exp(-self.rate*x), x >= 0), + (0, True), + ) + + def _characteristic_function(self, t): + rate = self.rate + return rate / (rate - I*t) + + def _moment_generating_function(self, t): + rate = self.rate + return rate / (rate - t) + + def _quantile(self, p): + return -log(1-p)/self.rate + + +def Exponential(name, rate): + r""" + Create a continuous random variable with an Exponential distribution. + + Explanation + =========== + + The density of the exponential distribution is given by + + .. math:: + f(x) := \lambda \exp(-\lambda x) + + with $x > 0$. Note that the expected value is `1/\lambda`. + + Parameters + ========== + + rate : A positive Real number, `\lambda > 0`, the rate (or inverse scale/inverse mean) + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Exponential, density, cdf, E + >>> from sympy.stats import variance, std, skewness, quantile + >>> from sympy import Symbol + + >>> l = Symbol("lambda", positive=True) + >>> z = Symbol("z") + >>> p = Symbol("p") + >>> X = Exponential("x", l) + + >>> density(X)(z) + lambda*exp(-lambda*z) + + >>> cdf(X)(z) + Piecewise((1 - exp(-lambda*z), z >= 0), (0, True)) + + >>> quantile(X)(p) + -log(1 - p)/lambda + + >>> E(X) + 1/lambda + + >>> variance(X) + lambda**(-2) + + >>> skewness(X) + 2 + + >>> X = Exponential('x', 10) + + >>> density(X)(z) + 10*exp(-10*z) + + >>> E(X) + 1/10 + + >>> std(X) + 1/10 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Exponential_distribution + .. [2] https://mathworld.wolfram.com/ExponentialDistribution.html + + """ + + return rv(name, ExponentialDistribution, (rate, )) + + +# ------------------------------------------------------------------------------- +# Exponential Power distribution ----------------------------------------------------- + +class ExponentialPowerDistribution(SingleContinuousDistribution): + _argnames = ('mu', 'alpha', 'beta') + + set = Interval(-oo, oo) + + @staticmethod + def check(mu, alpha, beta): + _value_check(alpha > 0, "Scale parameter alpha must be positive.") + _value_check(beta > 0, "Shape parameter beta must be positive.") + + def pdf(self, x): + mu, alpha, beta = self.mu, self.alpha, self.beta + num = beta*exp(-(Abs(x - mu)/alpha)**beta) + den = 2*alpha*gamma(1/beta) + return num/den + + def _cdf(self, x): + mu, alpha, beta = self.mu, self.alpha, self.beta + num = lowergamma(1/beta, (Abs(x - mu) / alpha)**beta) + den = 2*gamma(1/beta) + return sign(x - mu)*num/den + S.Half + + +def ExponentialPower(name, mu, alpha, beta): + r""" + Create a Continuous Random Variable with Exponential Power distribution. + This distribution is known also as Generalized Normal + distribution version 1. + + Explanation + =========== + + The density of the Exponential Power distribution is given by + + .. math:: + f(x) := \frac{\beta}{2\alpha\Gamma(\frac{1}{\beta})} + e^{{-(\frac{|x - \mu|}{\alpha})^{\beta}}} + + with :math:`x \in [ - \infty, \infty ]`. + + Parameters + ========== + + mu : Real number + A location. + alpha : Real number,`\alpha > 0` + A scale. + beta : Real number, `\beta > 0` + A shape. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import ExponentialPower, density, cdf + >>> from sympy import Symbol, pprint + >>> z = Symbol("z") + >>> mu = Symbol("mu") + >>> alpha = Symbol("alpha", positive=True) + >>> beta = Symbol("beta", positive=True) + >>> X = ExponentialPower("x", mu, alpha, beta) + >>> pprint(density(X)(z), use_unicode=False) + beta + /|mu - z|\ + -|--------| + \ alpha / + beta*e + --------------------- + / 1 \ + 2*alpha*Gamma|----| + \beta/ + >>> cdf(X)(z) + 1/2 + lowergamma(1/beta, (Abs(mu - z)/alpha)**beta)*sign(-mu + z)/(2*gamma(1/beta)) + + References + ========== + + .. [1] https://reference.wolfram.com/language/ref/ExponentialPowerDistribution.html + .. [2] https://en.wikipedia.org/wiki/Generalized_normal_distribution#Version_1 + + """ + return rv(name, ExponentialPowerDistribution, (mu, alpha, beta)) + + +#------------------------------------------------------------------------------- +# F distribution --------------------------------------------------------------- + + +class FDistributionDistribution(SingleContinuousDistribution): + _argnames = ('d1', 'd2') + + set = Interval(0, oo) + + @staticmethod + def check(d1, d2): + _value_check((d1 > 0, d1.is_integer), + "Degrees of freedom d1 must be positive integer.") + _value_check((d2 > 0, d2.is_integer), + "Degrees of freedom d2 must be positive integer.") + + def pdf(self, x): + d1, d2 = self.d1, self.d2 + return (sqrt((d1*x)**d1*d2**d2 / (d1*x+d2)**(d1+d2)) + / (x * beta_fn(d1/2, d2/2))) + + def _moment_generating_function(self, t): + raise NotImplementedError('The moment generating function for the ' + 'F-distribution does not exist.') + +def FDistribution(name, d1, d2): + r""" + Create a continuous random variable with a F distribution. + + Explanation + =========== + + The density of the F distribution is given by + + .. math:: + f(x) := \frac{\sqrt{\frac{(d_1 x)^{d_1} d_2^{d_2}} + {(d_1 x + d_2)^{d_1 + d_2}}}} + {x \mathrm{B} \left(\frac{d_1}{2}, \frac{d_2}{2}\right)} + + with :math:`x > 0`. + + Parameters + ========== + + d1 : `d_1 > 0`, where `d_1` is the degrees of freedom (`n_1 - 1`) + d2 : `d_2 > 0`, where `d_2` is the degrees of freedom (`n_2 - 1`) + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import FDistribution, density + >>> from sympy import Symbol, pprint + + >>> d1 = Symbol("d1", positive=True) + >>> d2 = Symbol("d2", positive=True) + >>> z = Symbol("z") + + >>> X = FDistribution("x", d1, d2) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + d2 + -- ______________________________ + 2 / d1 -d1 - d2 + d2 *\/ (d1*z) *(d1*z + d2) + -------------------------------------- + /d1 d2\ + z*B|--, --| + \2 2 / + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/F-distribution + .. [2] https://mathworld.wolfram.com/F-Distribution.html + + """ + + return rv(name, FDistributionDistribution, (d1, d2)) + +#------------------------------------------------------------------------------- +# Fisher Z distribution -------------------------------------------------------- + +class FisherZDistribution(SingleContinuousDistribution): + _argnames = ('d1', 'd2') + + set = Interval(-oo, oo) + + @staticmethod + def check(d1, d2): + _value_check(d1 > 0, "Degree of freedom d1 must be positive.") + _value_check(d2 > 0, "Degree of freedom d2 must be positive.") + + def pdf(self, x): + d1, d2 = self.d1, self.d2 + return (2*d1**(d1/2)*d2**(d2/2) / beta_fn(d1/2, d2/2) * + exp(d1*x) / (d1*exp(2*x)+d2)**((d1+d2)/2)) + +def FisherZ(name, d1, d2): + r""" + Create a Continuous Random Variable with an Fisher's Z distribution. + + Explanation + =========== + + The density of the Fisher's Z distribution is given by + + .. math:: + f(x) := \frac{2d_1^{d_1/2} d_2^{d_2/2}} {\mathrm{B}(d_1/2, d_2/2)} + \frac{e^{d_1z}}{\left(d_1e^{2z}+d_2\right)^{\left(d_1+d_2\right)/2}} + + + .. TODO - What is the difference between these degrees of freedom? + + Parameters + ========== + + d1 : `d_1 > 0` + Degree of freedom. + d2 : `d_2 > 0` + Degree of freedom. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import FisherZ, density + >>> from sympy import Symbol, pprint + + >>> d1 = Symbol("d1", positive=True) + >>> d2 = Symbol("d2", positive=True) + >>> z = Symbol("z") + + >>> X = FisherZ("x", d1, d2) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + d1 d2 + d1 d2 - -- - -- + -- -- 2 2 + 2 2 / 2*z \ d1*z + 2*d1 *d2 *\d1*e + d2/ *e + ----------------------------------------- + /d1 d2\ + B|--, --| + \2 2 / + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Fisher%27s_z-distribution + .. [2] https://mathworld.wolfram.com/Fishersz-Distribution.html + + """ + + return rv(name, FisherZDistribution, (d1, d2)) + +#------------------------------------------------------------------------------- +# Frechet distribution --------------------------------------------------------- + +class FrechetDistribution(SingleContinuousDistribution): + _argnames = ('a', 's', 'm') + + set = Interval(0, oo) + + @staticmethod + def check(a, s, m): + _value_check(a > 0, "Shape parameter alpha must be positive.") + _value_check(s > 0, "Scale parameter s must be positive.") + + def __new__(cls, a, s=1, m=0): + a, s, m = list(map(sympify, (a, s, m))) + return Basic.__new__(cls, a, s, m) + + def pdf(self, x): + a, s, m = self.a, self.s, self.m + return a/s * ((x-m)/s)**(-1-a) * exp(-((x-m)/s)**(-a)) + + def _cdf(self, x): + a, s, m = self.a, self.s, self.m + return Piecewise((exp(-((x-m)/s)**(-a)), x >= m), + (S.Zero, True)) + +def Frechet(name, a, s=1, m=0): + r""" + Create a continuous random variable with a Frechet distribution. + + Explanation + =========== + + The density of the Frechet distribution is given by + + .. math:: + f(x) := \frac{\alpha}{s} \left(\frac{x-m}{s}\right)^{-1-\alpha} + e^{-(\frac{x-m}{s})^{-\alpha}} + + with :math:`x \geq m`. + + Parameters + ========== + + a : Real number, :math:`a \in \left(0, \infty\right)` the shape + s : Real number, :math:`s \in \left(0, \infty\right)` the scale + m : Real number, :math:`m \in \left(-\infty, \infty\right)` the minimum + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Frechet, density, cdf + >>> from sympy import Symbol + + >>> a = Symbol("a", positive=True) + >>> s = Symbol("s", positive=True) + >>> m = Symbol("m", real=True) + >>> z = Symbol("z") + + >>> X = Frechet("x", a, s, m) + + >>> density(X)(z) + a*((-m + z)/s)**(-a - 1)*exp(-1/((-m + z)/s)**a)/s + + >>> cdf(X)(z) + Piecewise((exp(-1/((-m + z)/s)**a), m <= z), (0, True)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Fr%C3%A9chet_distribution + + """ + + return rv(name, FrechetDistribution, (a, s, m)) + +#------------------------------------------------------------------------------- +# Gamma distribution ----------------------------------------------------------- + + +class GammaDistribution(SingleContinuousDistribution): + _argnames = ('k', 'theta') + + set = Interval(0, oo) + + @staticmethod + def check(k, theta): + _value_check(k > 0, "k must be positive") + _value_check(theta > 0, "Theta must be positive") + + def pdf(self, x): + k, theta = self.k, self.theta + return x**(k - 1) * exp(-x/theta) / (gamma(k)*theta**k) + + def _cdf(self, x): + k, theta = self.k, self.theta + return Piecewise( + (lowergamma(k, S(x)/theta)/gamma(k), x > 0), + (S.Zero, True)) + + def _characteristic_function(self, t): + return (1 - self.theta*I*t)**(-self.k) + + def _moment_generating_function(self, t): + return (1- self.theta*t)**(-self.k) + + +def Gamma(name, k, theta): + r""" + Create a continuous random variable with a Gamma distribution. + + Explanation + =========== + + The density of the Gamma distribution is given by + + .. math:: + f(x) := \frac{1}{\Gamma(k) \theta^k} x^{k - 1} e^{-\frac{x}{\theta}} + + with :math:`x \in [0,1]`. + + Parameters + ========== + + k : Real number, `k > 0`, a shape + theta : Real number, `\theta > 0`, a scale + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Gamma, density, cdf, E, variance + >>> from sympy import Symbol, pprint, simplify + + >>> k = Symbol("k", positive=True) + >>> theta = Symbol("theta", positive=True) + >>> z = Symbol("z") + + >>> X = Gamma("x", k, theta) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + -z + ----- + -k k - 1 theta + theta *z *e + --------------------- + Gamma(k) + + >>> C = cdf(X, meijerg=True)(z) + >>> pprint(C, use_unicode=False) + / / z \ + |k*lowergamma|k, -----| + | \ theta/ + <---------------------- for z >= 0 + | Gamma(k + 1) + | + \ 0 otherwise + + >>> E(X) + k*theta + + >>> V = simplify(variance(X)) + >>> pprint(V, use_unicode=False) + 2 + k*theta + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gamma_distribution + .. [2] https://mathworld.wolfram.com/GammaDistribution.html + + """ + + return rv(name, GammaDistribution, (k, theta)) + +#------------------------------------------------------------------------------- +# Inverse Gamma distribution --------------------------------------------------- + + +class GammaInverseDistribution(SingleContinuousDistribution): + _argnames = ('a', 'b') + + set = Interval(0, oo) + + @staticmethod + def check(a, b): + _value_check(a > 0, "alpha must be positive") + _value_check(b > 0, "beta must be positive") + + def pdf(self, x): + a, b = self.a, self.b + return b**a/gamma(a) * x**(-a-1) * exp(-b/x) + + def _cdf(self, x): + a, b = self.a, self.b + return Piecewise((uppergamma(a,b/x)/gamma(a), x > 0), + (S.Zero, True)) + + def _characteristic_function(self, t): + a, b = self.a, self.b + return 2 * (-I*b*t)**(a/2) * besselk(a, sqrt(-4*I*b*t)) / gamma(a) + + def _moment_generating_function(self, t): + raise NotImplementedError('The moment generating function for the ' + 'gamma inverse distribution does not exist.') + +def GammaInverse(name, a, b): + r""" + Create a continuous random variable with an inverse Gamma distribution. + + Explanation + =========== + + The density of the inverse Gamma distribution is given by + + .. math:: + f(x) := \frac{\beta^\alpha}{\Gamma(\alpha)} x^{-\alpha - 1} + \exp\left(\frac{-\beta}{x}\right) + + with :math:`x > 0`. + + Parameters + ========== + + a : Real number, `a > 0`, a shape + b : Real number, `b > 0`, a scale + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import GammaInverse, density, cdf + >>> from sympy import Symbol, pprint + + >>> a = Symbol("a", positive=True) + >>> b = Symbol("b", positive=True) + >>> z = Symbol("z") + + >>> X = GammaInverse("x", a, b) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + -b + --- + a -a - 1 z + b *z *e + --------------- + Gamma(a) + + >>> cdf(X)(z) + Piecewise((uppergamma(a, b/z)/gamma(a), z > 0), (0, True)) + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Inverse-gamma_distribution + + """ + + return rv(name, GammaInverseDistribution, (a, b)) + + +#------------------------------------------------------------------------------- +# Gumbel distribution (Maximum and Minimum) -------------------------------------------------------- + + +class GumbelDistribution(SingleContinuousDistribution): + _argnames = ('beta', 'mu', 'minimum') + + set = Interval(-oo, oo) + + @staticmethod + def check(beta, mu, minimum): + _value_check(beta > 0, "Scale parameter beta must be positive.") + + def pdf(self, x): + beta, mu = self.beta, self.mu + z = (x - mu)/beta + f_max = (1/beta)*exp(-z - exp(-z)) + f_min = (1/beta)*exp(z - exp(z)) + return Piecewise((f_min, self.minimum), (f_max, not self.minimum)) + + def _cdf(self, x): + beta, mu = self.beta, self.mu + z = (x - mu)/beta + F_max = exp(-exp(-z)) + F_min = 1 - exp(-exp(z)) + return Piecewise((F_min, self.minimum), (F_max, not self.minimum)) + + def _characteristic_function(self, t): + cf_max = gamma(1 - I*self.beta*t) * exp(I*self.mu*t) + cf_min = gamma(1 + I*self.beta*t) * exp(I*self.mu*t) + return Piecewise((cf_min, self.minimum), (cf_max, not self.minimum)) + + def _moment_generating_function(self, t): + mgf_max = gamma(1 - self.beta*t) * exp(self.mu*t) + mgf_min = gamma(1 + self.beta*t) * exp(self.mu*t) + return Piecewise((mgf_min, self.minimum), (mgf_max, not self.minimum)) + +def Gumbel(name, beta, mu, minimum=False): + r""" + Create a Continuous Random Variable with Gumbel distribution. + + Explanation + =========== + + The density of the Gumbel distribution is given by + + For Maximum + + .. math:: + f(x) := \dfrac{1}{\beta} \exp \left( -\dfrac{x-\mu}{\beta} + - \exp \left( -\dfrac{x - \mu}{\beta} \right) \right) + + with :math:`x \in [ - \infty, \infty ]`. + + For Minimum + + .. math:: + f(x) := \frac{e^{- e^{\frac{- \mu + x}{\beta}} + \frac{- \mu + x}{\beta}}}{\beta} + + with :math:`x \in [ - \infty, \infty ]`. + + Parameters + ========== + + mu : Real number, `\mu`, a location + beta : Real number, `\beta > 0`, a scale + minimum : Boolean, by default ``False``, set to ``True`` for enabling minimum distribution + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Gumbel, density, cdf + >>> from sympy import Symbol + >>> x = Symbol("x") + >>> mu = Symbol("mu") + >>> beta = Symbol("beta", positive=True) + >>> X = Gumbel("x", beta, mu) + >>> density(X)(x) + exp(-exp(-(-mu + x)/beta) - (-mu + x)/beta)/beta + >>> cdf(X)(x) + exp(-exp(-(-mu + x)/beta)) + + References + ========== + + .. [1] https://mathworld.wolfram.com/GumbelDistribution.html + .. [2] https://en.wikipedia.org/wiki/Gumbel_distribution + .. [3] https://web.archive.org/web/20200628222206/http://www.mathwave.com/help/easyfit/html/analyses/distributions/gumbel_max.html + .. [4] https://web.archive.org/web/20200628222212/http://www.mathwave.com/help/easyfit/html/analyses/distributions/gumbel_min.html + + """ + return rv(name, GumbelDistribution, (beta, mu, minimum)) + +#------------------------------------------------------------------------------- +# Gompertz distribution -------------------------------------------------------- + +class GompertzDistribution(SingleContinuousDistribution): + _argnames = ('b', 'eta') + + set = Interval(0, oo) + + @staticmethod + def check(b, eta): + _value_check(b > 0, "b must be positive") + _value_check(eta > 0, "eta must be positive") + + def pdf(self, x): + eta, b = self.eta, self.b + return b*eta*exp(b*x)*exp(eta)*exp(-eta*exp(b*x)) + + def _cdf(self, x): + eta, b = self.eta, self.b + return 1 - exp(eta)*exp(-eta*exp(b*x)) + + def _moment_generating_function(self, t): + eta, b = self.eta, self.b + return eta * exp(eta) * expint(t/b, eta) + +def Gompertz(name, b, eta): + r""" + Create a Continuous Random Variable with Gompertz distribution. + + Explanation + =========== + + The density of the Gompertz distribution is given by + + .. math:: + f(x) := b \eta e^{b x} e^{\eta} \exp \left(-\eta e^{bx} \right) + + with :math:`x \in [0, \infty)`. + + Parameters + ========== + + b : Real number, `b > 0`, a scale + eta : Real number, `\eta > 0`, a shape + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Gompertz, density + >>> from sympy import Symbol + + >>> b = Symbol("b", positive=True) + >>> eta = Symbol("eta", positive=True) + >>> z = Symbol("z") + + >>> X = Gompertz("x", b, eta) + + >>> density(X)(z) + b*eta*exp(eta)*exp(b*z)*exp(-eta*exp(b*z)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gompertz_distribution + + """ + return rv(name, GompertzDistribution, (b, eta)) + +#------------------------------------------------------------------------------- +# Kumaraswamy distribution ----------------------------------------------------- + + +class KumaraswamyDistribution(SingleContinuousDistribution): + _argnames = ('a', 'b') + + set = Interval(0, oo) + + @staticmethod + def check(a, b): + _value_check(a > 0, "a must be positive") + _value_check(b > 0, "b must be positive") + + def pdf(self, x): + a, b = self.a, self.b + return a * b * x**(a-1) * (1-x**a)**(b-1) + + def _cdf(self, x): + a, b = self.a, self.b + return Piecewise( + (S.Zero, x < S.Zero), + (1 - (1 - x**a)**b, x <= S.One), + (S.One, True)) + +def Kumaraswamy(name, a, b): + r""" + Create a Continuous Random Variable with a Kumaraswamy distribution. + + Explanation + =========== + + The density of the Kumaraswamy distribution is given by + + .. math:: + f(x) := a b x^{a-1} (1-x^a)^{b-1} + + with :math:`x \in [0,1]`. + + Parameters + ========== + + a : Real number, `a > 0`, a shape + b : Real number, `b > 0`, a shape + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Kumaraswamy, density, cdf + >>> from sympy import Symbol, pprint + + >>> a = Symbol("a", positive=True) + >>> b = Symbol("b", positive=True) + >>> z = Symbol("z") + + >>> X = Kumaraswamy("x", a, b) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + b - 1 + a - 1 / a\ + a*b*z *\1 - z / + + >>> cdf(X)(z) + Piecewise((0, z < 0), (1 - (1 - z**a)**b, z <= 1), (1, True)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Kumaraswamy_distribution + + """ + + return rv(name, KumaraswamyDistribution, (a, b)) + +#------------------------------------------------------------------------------- +# Laplace distribution --------------------------------------------------------- + + +class LaplaceDistribution(SingleContinuousDistribution): + _argnames = ('mu', 'b') + + set = Interval(-oo, oo) + + @staticmethod + def check(mu, b): + _value_check(b > 0, "Scale parameter b must be positive.") + _value_check(mu.is_real, "Location parameter mu should be real") + + def pdf(self, x): + mu, b = self.mu, self.b + return 1/(2*b)*exp(-Abs(x - mu)/b) + + def _cdf(self, x): + mu, b = self.mu, self.b + return Piecewise( + (S.Half*exp((x - mu)/b), x < mu), + (S.One - S.Half*exp(-(x - mu)/b), x >= mu) + ) + + def _characteristic_function(self, t): + return exp(self.mu*I*t) / (1 + self.b**2*t**2) + + def _moment_generating_function(self, t): + return exp(self.mu*t) / (1 - self.b**2*t**2) + +def Laplace(name, mu, b): + r""" + Create a continuous random variable with a Laplace distribution. + + Explanation + =========== + + The density of the Laplace distribution is given by + + .. math:: + f(x) := \frac{1}{2 b} \exp \left(-\frac{|x-\mu|}b \right) + + Parameters + ========== + + mu : Real number or a list/matrix, the location (mean) or the + location vector + b : Real number or a positive definite matrix, representing a scale + or the covariance matrix. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Laplace, density, cdf + >>> from sympy import Symbol, pprint + + >>> mu = Symbol("mu") + >>> b = Symbol("b", positive=True) + >>> z = Symbol("z") + + >>> X = Laplace("x", mu, b) + + >>> density(X)(z) + exp(-Abs(mu - z)/b)/(2*b) + + >>> cdf(X)(z) + Piecewise((exp((-mu + z)/b)/2, mu > z), (1 - exp((mu - z)/b)/2, True)) + + >>> L = Laplace('L', [1, 2], [[1, 0], [0, 1]]) + >>> pprint(density(L)(1, 2), use_unicode=False) + 5 / ____\ + e *besselk\0, \/ 35 / + --------------------- + pi + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Laplace_distribution + .. [2] https://mathworld.wolfram.com/LaplaceDistribution.html + + """ + + if isinstance(mu, (list, MatrixBase)) and\ + isinstance(b, (list, MatrixBase)): + from sympy.stats.joint_rv_types import MultivariateLaplace + return MultivariateLaplace(name, mu, b) + + return rv(name, LaplaceDistribution, (mu, b)) + +#------------------------------------------------------------------------------- +# Levy distribution --------------------------------------------------------- + + +class LevyDistribution(SingleContinuousDistribution): + _argnames = ('mu', 'c') + + @property + def set(self): + return Interval(self.mu, oo) + + @staticmethod + def check(mu, c): + _value_check(c > 0, "c (scale parameter) must be positive") + _value_check(mu.is_real, "mu (location parameter) must be real") + + def pdf(self, x): + mu, c = self.mu, self.c + return sqrt(c/(2*pi))*exp(-c/(2*(x - mu)))/((x - mu)**(S.One + S.Half)) + + def _cdf(self, x): + mu, c = self.mu, self.c + return erfc(sqrt(c/(2*(x - mu)))) + + def _characteristic_function(self, t): + mu, c = self.mu, self.c + return exp(I * mu * t - sqrt(-2 * I * c * t)) + + def _moment_generating_function(self, t): + raise NotImplementedError('The moment generating function of Levy distribution does not exist.') + +def Levy(name, mu, c): + r""" + Create a continuous random variable with a Levy distribution. + + The density of the Levy distribution is given by + + .. math:: + f(x) := \sqrt(\frac{c}{2 \pi}) \frac{\exp -\frac{c}{2 (x - \mu)}}{(x - \mu)^{3/2}} + + Parameters + ========== + + mu : Real number + The location parameter. + c : Real number, `c > 0` + A scale parameter. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Levy, density, cdf + >>> from sympy import Symbol + + >>> mu = Symbol("mu", real=True) + >>> c = Symbol("c", positive=True) + >>> z = Symbol("z") + + >>> X = Levy("x", mu, c) + + >>> density(X)(z) + sqrt(2)*sqrt(c)*exp(-c/(-2*mu + 2*z))/(2*sqrt(pi)*(-mu + z)**(3/2)) + + >>> cdf(X)(z) + erfc(sqrt(c)*sqrt(1/(-2*mu + 2*z))) + + References + ========== + .. [1] https://en.wikipedia.org/wiki/L%C3%A9vy_distribution + .. [2] https://mathworld.wolfram.com/LevyDistribution.html + """ + + return rv(name, LevyDistribution, (mu, c)) + +#------------------------------------------------------------------------------- +# Log-Cauchy distribution -------------------------------------------------------- + + +class LogCauchyDistribution(SingleContinuousDistribution): + _argnames = ('mu', 'sigma') + + set = Interval.open(0, oo) + + @staticmethod + def check(mu, sigma): + _value_check((sigma > 0) != False, "Scale parameter Gamma must be positive.") + _value_check(mu.is_real != False, "Location parameter must be real.") + + def pdf(self, x): + mu, sigma = self.mu, self.sigma + return 1/(x*pi)*(sigma/((log(x) - mu)**2 + sigma**2)) + + def _cdf(self, x): + mu, sigma = self.mu, self.sigma + return (1/pi)*atan((log(x) - mu)/sigma) + S.Half + + def _characteristic_function(self, t): + raise NotImplementedError("The characteristic function for the " + "Log-Cauchy distribution does not exist.") + + def _moment_generating_function(self, t): + raise NotImplementedError("The moment generating function for the " + "Log-Cauchy distribution does not exist.") + +def LogCauchy(name, mu, sigma): + r""" + Create a continuous random variable with a Log-Cauchy distribution. + The density of the Log-Cauchy distribution is given by + + .. math:: + f(x) := \frac{1}{\pi x} \frac{\sigma}{(log(x)-\mu^2) + \sigma^2} + + Parameters + ========== + + mu : Real number, the location + + sigma : Real number, `\sigma > 0`, a scale + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import LogCauchy, density, cdf + >>> from sympy import Symbol, S + + >>> mu = 2 + >>> sigma = S.One / 5 + >>> z = Symbol("z") + + >>> X = LogCauchy("x", mu, sigma) + + >>> density(X)(z) + 1/(5*pi*z*((log(z) - 2)**2 + 1/25)) + + >>> cdf(X)(z) + atan(5*log(z) - 10)/pi + 1/2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Log-Cauchy_distribution + """ + + return rv(name, LogCauchyDistribution, (mu, sigma)) + + +#------------------------------------------------------------------------------- +# Logistic distribution -------------------------------------------------------- + + +class LogisticDistribution(SingleContinuousDistribution): + _argnames = ('mu', 's') + + set = Interval(-oo, oo) + + @staticmethod + def check(mu, s): + _value_check(s > 0, "Scale parameter s must be positive.") + + def pdf(self, x): + mu, s = self.mu, self.s + return exp(-(x - mu)/s)/(s*(1 + exp(-(x - mu)/s))**2) + + def _cdf(self, x): + mu, s = self.mu, self.s + return S.One/(1 + exp(-(x - mu)/s)) + + def _characteristic_function(self, t): + return Piecewise((exp(I*t*self.mu) * pi*self.s*t / sinh(pi*self.s*t), Ne(t, 0)), (S.One, True)) + + def _moment_generating_function(self, t): + return exp(self.mu*t) * beta_fn(1 - self.s*t, 1 + self.s*t) + + def _quantile(self, p): + return self.mu - self.s*log(-S.One + S.One/p) + +def Logistic(name, mu, s): + r""" + Create a continuous random variable with a logistic distribution. + + Explanation + =========== + + The density of the logistic distribution is given by + + .. math:: + f(x) := \frac{e^{-(x-\mu)/s}} {s\left(1+e^{-(x-\mu)/s}\right)^2} + + Parameters + ========== + + mu : Real number, the location (mean) + s : Real number, `s > 0`, a scale + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Logistic, density, cdf + >>> from sympy import Symbol + + >>> mu = Symbol("mu", real=True) + >>> s = Symbol("s", positive=True) + >>> z = Symbol("z") + + >>> X = Logistic("x", mu, s) + + >>> density(X)(z) + exp((mu - z)/s)/(s*(exp((mu - z)/s) + 1)**2) + + >>> cdf(X)(z) + 1/(exp((mu - z)/s) + 1) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Logistic_distribution + .. [2] https://mathworld.wolfram.com/LogisticDistribution.html + + """ + + return rv(name, LogisticDistribution, (mu, s)) + +#------------------------------------------------------------------------------- +# Log-logistic distribution -------------------------------------------------------- + + +class LogLogisticDistribution(SingleContinuousDistribution): + _argnames = ('alpha', 'beta') + + set = Interval(0, oo) + + @staticmethod + def check(alpha, beta): + _value_check(alpha > 0, "Scale parameter Alpha must be positive.") + _value_check(beta > 0, "Shape parameter Beta must be positive.") + + def pdf(self, x): + a, b = self.alpha, self.beta + return ((b/a)*(x/a)**(b - 1))/(1 + (x/a)**b)**2 + + def _cdf(self, x): + a, b = self.alpha, self.beta + return 1/(1 + (x/a)**(-b)) + + def _quantile(self, p): + a, b = self.alpha, self.beta + return a*((p/(1 - p))**(1/b)) + + def expectation(self, expr, var, **kwargs): + a, b = self.args + return Piecewise((S.NaN, b <= 1), (pi*a/(b*sin(pi/b)), True)) + +def LogLogistic(name, alpha, beta): + r""" + Create a continuous random variable with a log-logistic distribution. + The distribution is unimodal when ``beta > 1``. + + Explanation + =========== + + The density of the log-logistic distribution is given by + + .. math:: + f(x) := \frac{(\frac{\beta}{\alpha})(\frac{x}{\alpha})^{\beta - 1}} + {(1 + (\frac{x}{\alpha})^{\beta})^2} + + Parameters + ========== + + alpha : Real number, `\alpha > 0`, scale parameter and median of distribution + beta : Real number, `\beta > 0`, a shape parameter + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import LogLogistic, density, cdf, quantile + >>> from sympy import Symbol, pprint + + >>> alpha = Symbol("alpha", positive=True) + >>> beta = Symbol("beta", positive=True) + >>> p = Symbol("p") + >>> z = Symbol("z", positive=True) + + >>> X = LogLogistic("x", alpha, beta) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + beta - 1 + / z \ + beta*|-----| + \alpha/ + ------------------------ + 2 + / beta \ + |/ z \ | + alpha*||-----| + 1| + \\alpha/ / + + >>> cdf(X)(z) + 1/(1 + (z/alpha)**(-beta)) + + >>> quantile(X)(p) + alpha*(p/(1 - p))**(1/beta) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Log-logistic_distribution + + """ + + return rv(name, LogLogisticDistribution, (alpha, beta)) + +#------------------------------------------------------------------------------- +#Logit-Normal distribution------------------------------------------------------ + +class LogitNormalDistribution(SingleContinuousDistribution): + _argnames = ('mu', 's') + set = Interval.open(0, 1) + + @staticmethod + def check(mu, s): + _value_check((s ** 2).is_real is not False and s ** 2 > 0, "Squared scale parameter s must be positive.") + _value_check(mu.is_real is not False, "Location parameter must be real") + + def _logit(self, x): + return log(x / (1 - x)) + + def pdf(self, x): + mu, s = self.mu, self.s + return exp(-(self._logit(x) - mu)**2/(2*s**2))*(S.One/sqrt(2*pi*(s**2)))*(1/(x*(1 - x))) + + def _cdf(self, x): + mu, s = self.mu, self.s + return (S.One/2)*(1 + erf((self._logit(x) - mu)/(sqrt(2*s**2)))) + + +def LogitNormal(name, mu, s): + r""" + Create a continuous random variable with a Logit-Normal distribution. + + The density of the logistic distribution is given by + + .. math:: + f(x) := \frac{1}{s \sqrt{2 \pi}} \frac{1}{x(1 - x)} e^{- \frac{(logit(x) - \mu)^2}{s^2}} + where logit(x) = \log(\frac{x}{1 - x}) + Parameters + ========== + + mu : Real number, the location (mean) + s : Real number, `s > 0`, a scale + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import LogitNormal, density, cdf + >>> from sympy import Symbol,pprint + + >>> mu = Symbol("mu", real=True) + >>> s = Symbol("s", positive=True) + >>> z = Symbol("z") + >>> X = LogitNormal("x",mu,s) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + 2 + / / z \\ + -|-mu + log|-----|| + \ \1 - z// + --------------------- + 2 + ___ 2*s + \/ 2 *e + ---------------------------- + ____ + 2*\/ pi *s*z*(1 - z) + + >>> density(X)(z) + sqrt(2)*exp(-(-mu + log(z/(1 - z)))**2/(2*s**2))/(2*sqrt(pi)*s*z*(1 - z)) + + >>> cdf(X)(z) + erf(sqrt(2)*(-mu + log(z/(1 - z)))/(2*s))/2 + 1/2 + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Logit-normal_distribution + + """ + + return rv(name, LogitNormalDistribution, (mu, s)) + +#------------------------------------------------------------------------------- +# Log Normal distribution ------------------------------------------------------ + + +class LogNormalDistribution(SingleContinuousDistribution): + _argnames = ('mean', 'std') + + set = Interval(0, oo) + + @staticmethod + def check(mean, std): + _value_check(std > 0, "Parameter std must be positive.") + + def pdf(self, x): + mean, std = self.mean, self.std + return exp(-(log(x) - mean)**2 / (2*std**2)) / (x*sqrt(2*pi)*std) + + def _cdf(self, x): + mean, std = self.mean, self.std + return Piecewise( + (S.Half + S.Half*erf((log(x) - mean)/sqrt(2)/std), x > 0), + (S.Zero, True) + ) + + def _moment_generating_function(self, t): + raise NotImplementedError('Moment generating function of the log-normal distribution is not defined.') + + +def LogNormal(name, mean, std): + r""" + Create a continuous random variable with a log-normal distribution. + + Explanation + =========== + + The density of the log-normal distribution is given by + + .. math:: + f(x) := \frac{1}{x\sqrt{2\pi\sigma^2}} + e^{-\frac{\left(\ln x-\mu\right)^2}{2\sigma^2}} + + with :math:`x \geq 0`. + + Parameters + ========== + + mu : Real number + The log-scale. + sigma : Real number + A shape. ($\sigma^2 > 0$) + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import LogNormal, density + >>> from sympy import Symbol, pprint + + >>> mu = Symbol("mu", real=True) + >>> sigma = Symbol("sigma", positive=True) + >>> z = Symbol("z") + + >>> X = LogNormal("x", mu, sigma) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + 2 + -(-mu + log(z)) + ----------------- + 2 + ___ 2*sigma + \/ 2 *e + ------------------------ + ____ + 2*\/ pi *sigma*z + + + >>> X = LogNormal('x', 0, 1) # Mean 0, standard deviation 1 + + >>> density(X)(z) + sqrt(2)*exp(-log(z)**2/2)/(2*sqrt(pi)*z) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Lognormal + .. [2] https://mathworld.wolfram.com/LogNormalDistribution.html + + """ + + return rv(name, LogNormalDistribution, (mean, std)) + +#------------------------------------------------------------------------------- +# Lomax Distribution ----------------------------------------------------------- + +class LomaxDistribution(SingleContinuousDistribution): + _argnames = ('alpha', 'lamda',) + set = Interval(0, oo) + + @staticmethod + def check(alpha, lamda): + _value_check(alpha.is_real, "Shape parameter should be real.") + _value_check(lamda.is_real, "Scale parameter should be real.") + _value_check(alpha.is_positive, "Shape parameter should be positive.") + _value_check(lamda.is_positive, "Scale parameter should be positive.") + + def pdf(self, x): + lamba, alpha = self.lamda, self.alpha + return (alpha/lamba) * (S.One + x/lamba)**(-alpha-1) + +def Lomax(name, alpha, lamda): + r""" + Create a continuous random variable with a Lomax distribution. + + Explanation + =========== + + The density of the Lomax distribution is given by + + .. math:: + f(x) := \frac{\alpha}{\lambda}\left[1+\frac{x}{\lambda}\right]^{-(\alpha+1)} + + Parameters + ========== + + alpha : Real Number, `\alpha > 0` + Shape parameter + lamda : Real Number, `\lambda > 0` + Scale parameter + + Examples + ======== + + >>> from sympy.stats import Lomax, density, cdf, E + >>> from sympy import symbols + >>> a, l = symbols('a, l', positive=True) + >>> X = Lomax('X', a, l) + >>> x = symbols('x') + >>> density(X)(x) + a*(1 + x/l)**(-a - 1)/l + >>> cdf(X)(x) + Piecewise((1 - 1/(1 + x/l)**a, x >= 0), (0, True)) + >>> a = 2 + >>> X = Lomax('X', a, l) + >>> E(X) + l + + Returns + ======= + + RandomSymbol + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Lomax_distribution + + """ + return rv(name, LomaxDistribution, (alpha, lamda)) + +#------------------------------------------------------------------------------- +# Maxwell distribution --------------------------------------------------------- + + +class MaxwellDistribution(SingleContinuousDistribution): + _argnames = ('a',) + + set = Interval(0, oo) + + @staticmethod + def check(a): + _value_check(a > 0, "Parameter a must be positive.") + + def pdf(self, x): + a = self.a + return sqrt(2/pi)*x**2*exp(-x**2/(2*a**2))/a**3 + + def _cdf(self, x): + a = self.a + return erf(sqrt(2)*x/(2*a)) - sqrt(2)*x*exp(-x**2/(2*a**2))/(sqrt(pi)*a) + +def Maxwell(name, a): + r""" + Create a continuous random variable with a Maxwell distribution. + + Explanation + =========== + + The density of the Maxwell distribution is given by + + .. math:: + f(x) := \sqrt{\frac{2}{\pi}} \frac{x^2 e^{-x^2/(2a^2)}}{a^3} + + with :math:`x \geq 0`. + + .. TODO - what does the parameter mean? + + Parameters + ========== + + a : Real number, `a > 0` + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Maxwell, density, E, variance + >>> from sympy import Symbol, simplify + + >>> a = Symbol("a", positive=True) + >>> z = Symbol("z") + + >>> X = Maxwell("x", a) + + >>> density(X)(z) + sqrt(2)*z**2*exp(-z**2/(2*a**2))/(sqrt(pi)*a**3) + + >>> E(X) + 2*sqrt(2)*a/sqrt(pi) + + >>> simplify(variance(X)) + a**2*(-8 + 3*pi)/pi + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Maxwell_distribution + .. [2] https://mathworld.wolfram.com/MaxwellDistribution.html + + """ + + return rv(name, MaxwellDistribution, (a, )) + +#------------------------------------------------------------------------------- +# Moyal Distribution ----------------------------------------------------------- +class MoyalDistribution(SingleContinuousDistribution): + _argnames = ('mu', 'sigma') + + @staticmethod + def check(mu, sigma): + _value_check(mu.is_real, "Location parameter must be real.") + _value_check(sigma.is_real and sigma > 0, "Scale parameter must be real\ + and positive.") + + def pdf(self, x): + mu, sigma = self.mu, self.sigma + num = exp(-(exp(-(x - mu)/sigma) + (x - mu)/(sigma))/2) + den = (sqrt(2*pi) * sigma) + return num/den + + def _characteristic_function(self, t): + mu, sigma = self.mu, self.sigma + term1 = exp(I*t*mu) + term2 = (2**(-I*sigma*t) * gamma(Rational(1, 2) - I*t*sigma)) + return (term1 * term2)/sqrt(pi) + + def _moment_generating_function(self, t): + mu, sigma = self.mu, self.sigma + term1 = exp(t*mu) + term2 = (2**(-1*sigma*t) * gamma(Rational(1, 2) - t*sigma)) + return (term1 * term2)/sqrt(pi) + +def Moyal(name, mu, sigma): + r""" + Create a continuous random variable with a Moyal distribution. + + Explanation + =========== + + The density of the Moyal distribution is given by + + .. math:: + f(x) := \frac{\exp-\frac{1}{2}\exp-\frac{x-\mu}{\sigma}-\frac{x-\mu}{2\sigma}}{\sqrt{2\pi}\sigma} + + with :math:`x \in \mathbb{R}`. + + Parameters + ========== + + mu : Real number + Location parameter + sigma : Real positive number + Scale parameter + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Moyal, density, cdf + >>> from sympy import Symbol, simplify + >>> mu = Symbol("mu", real=True) + >>> sigma = Symbol("sigma", positive=True, real=True) + >>> z = Symbol("z") + >>> X = Moyal("x", mu, sigma) + >>> density(X)(z) + sqrt(2)*exp(-exp((mu - z)/sigma)/2 - (-mu + z)/(2*sigma))/(2*sqrt(pi)*sigma) + >>> simplify(cdf(X)(z)) + 1 - erf(sqrt(2)*exp((mu - z)/(2*sigma))/2) + + References + ========== + + .. [1] https://reference.wolfram.com/language/ref/MoyalDistribution.html + .. [2] https://www.stat.rice.edu/~dobelman/textfiles/DistributionsHandbook.pdf + + """ + + return rv(name, MoyalDistribution, (mu, sigma)) + +#------------------------------------------------------------------------------- +# Nakagami distribution -------------------------------------------------------- + + +class NakagamiDistribution(SingleContinuousDistribution): + _argnames = ('mu', 'omega') + + set = Interval(0, oo) + + @staticmethod + def check(mu, omega): + _value_check(mu >= S.Half, "Shape parameter mu must be greater than equal to 1/2.") + _value_check(omega > 0, "Spread parameter omega must be positive.") + + def pdf(self, x): + mu, omega = self.mu, self.omega + return 2*mu**mu/(gamma(mu)*omega**mu)*x**(2*mu - 1)*exp(-mu/omega*x**2) + + def _cdf(self, x): + mu, omega = self.mu, self.omega + return Piecewise( + (lowergamma(mu, (mu/omega)*x**2)/gamma(mu), x > 0), + (S.Zero, True)) + +def Nakagami(name, mu, omega): + r""" + Create a continuous random variable with a Nakagami distribution. + + Explanation + =========== + + The density of the Nakagami distribution is given by + + .. math:: + f(x) := \frac{2\mu^\mu}{\Gamma(\mu)\omega^\mu} x^{2\mu-1} + \exp\left(-\frac{\mu}{\omega}x^2 \right) + + with :math:`x > 0`. + + Parameters + ========== + + mu : Real number, `\mu \geq \frac{1}{2}`, a shape + omega : Real number, `\omega > 0`, the spread + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Nakagami, density, E, variance, cdf + >>> from sympy import Symbol, simplify, pprint + + >>> mu = Symbol("mu", positive=True) + >>> omega = Symbol("omega", positive=True) + >>> z = Symbol("z") + + >>> X = Nakagami("x", mu, omega) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + 2 + -mu*z + ------- + mu -mu 2*mu - 1 omega + 2*mu *omega *z *e + ---------------------------------- + Gamma(mu) + + >>> simplify(E(X)) + sqrt(mu)*sqrt(omega)*gamma(mu + 1/2)/gamma(mu + 1) + + >>> V = simplify(variance(X)) + >>> pprint(V, use_unicode=False) + 2 + omega*Gamma (mu + 1/2) + omega - ----------------------- + Gamma(mu)*Gamma(mu + 1) + + >>> cdf(X)(z) + Piecewise((lowergamma(mu, mu*z**2/omega)/gamma(mu), z > 0), + (0, True)) + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Nakagami_distribution + + """ + + return rv(name, NakagamiDistribution, (mu, omega)) + +#------------------------------------------------------------------------------- +# Normal distribution ---------------------------------------------------------- + + +class NormalDistribution(SingleContinuousDistribution): + _argnames = ('mean', 'std') + + @staticmethod + def check(mean, std): + _value_check(std > 0, "Standard deviation must be positive") + + def pdf(self, x): + return exp(-(x - self.mean)**2 / (2*self.std**2)) / (sqrt(2*pi)*self.std) + + def _cdf(self, x): + mean, std = self.mean, self.std + return erf(sqrt(2)*(-mean + x)/(2*std))/2 + S.Half + + def _characteristic_function(self, t): + mean, std = self.mean, self.std + return exp(I*mean*t - std**2*t**2/2) + + def _moment_generating_function(self, t): + mean, std = self.mean, self.std + return exp(mean*t + std**2*t**2/2) + + def _quantile(self, p): + mean, std = self.mean, self.std + return mean + std*sqrt(2)*erfinv(2*p - 1) + + +def Normal(name, mean, std): + r""" + Create a continuous random variable with a Normal distribution. + + Explanation + =========== + + The density of the Normal distribution is given by + + .. math:: + f(x) := \frac{1}{\sigma\sqrt{2\pi}} e^{ -\frac{(x-\mu)^2}{2\sigma^2} } + + Parameters + ========== + + mu : Real number or a list representing the mean or the mean vector + sigma : Real number or a positive definite square matrix, + :math:`\sigma^2 > 0`, the variance + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Normal, density, E, std, cdf, skewness, quantile, marginal_distribution + >>> from sympy import Symbol, simplify, pprint + + >>> mu = Symbol("mu") + >>> sigma = Symbol("sigma", positive=True) + >>> z = Symbol("z") + >>> y = Symbol("y") + >>> p = Symbol("p") + >>> X = Normal("x", mu, sigma) + + >>> density(X)(z) + sqrt(2)*exp(-(-mu + z)**2/(2*sigma**2))/(2*sqrt(pi)*sigma) + + >>> C = simplify(cdf(X))(z) # it needs a little more help... + >>> pprint(C, use_unicode=False) + / ___ \ + |\/ 2 *(-mu + z)| + erf|---------------| + \ 2*sigma / 1 + -------------------- + - + 2 2 + + >>> quantile(X)(p) + mu + sqrt(2)*sigma*erfinv(2*p - 1) + + >>> simplify(skewness(X)) + 0 + + >>> X = Normal("x", 0, 1) # Mean 0, standard deviation 1 + >>> density(X)(z) + sqrt(2)*exp(-z**2/2)/(2*sqrt(pi)) + + >>> E(2*X + 1) + 1 + + >>> simplify(std(2*X + 1)) + 2 + + >>> m = Normal('X', [1, 2], [[2, 1], [1, 2]]) + >>> pprint(density(m)(y, z), use_unicode=False) + 2 2 + y y*z z + - -- + --- - -- + z - 1 + ___ 3 3 3 + \/ 3 *e + ------------------------------ + 6*pi + + >>> marginal_distribution(m, m[0])(1) + 1/(2*sqrt(pi)) + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Normal_distribution + .. [2] https://mathworld.wolfram.com/NormalDistributionFunction.html + + """ + + if isinstance(mean, list) or getattr(mean, 'is_Matrix', False) and\ + isinstance(std, list) or getattr(std, 'is_Matrix', False): + from sympy.stats.joint_rv_types import MultivariateNormal + return MultivariateNormal(name, mean, std) + return rv(name, NormalDistribution, (mean, std)) + + +#------------------------------------------------------------------------------- +# Inverse Gaussian distribution ---------------------------------------------------------- + + +class GaussianInverseDistribution(SingleContinuousDistribution): + _argnames = ('mean', 'shape') + + @property + def set(self): + return Interval(0, oo) + + @staticmethod + def check(mean, shape): + _value_check(shape > 0, "Shape parameter must be positive") + _value_check(mean > 0, "Mean must be positive") + + def pdf(self, x): + mu, s = self.mean, self.shape + return exp(-s*(x - mu)**2 / (2*x*mu**2)) * sqrt(s/(2*pi*x**3)) + + def _cdf(self, x): + from sympy.stats import cdf + mu, s = self.mean, self.shape + stdNormalcdf = cdf(Normal('x', 0, 1)) + + first_term = stdNormalcdf(sqrt(s/x) * ((x/mu) - S.One)) + second_term = exp(2*s/mu) * stdNormalcdf(-sqrt(s/x)*(x/mu + S.One)) + + return first_term + second_term + + def _characteristic_function(self, t): + mu, s = self.mean, self.shape + return exp((s/mu)*(1 - sqrt(1 - (2*mu**2*I*t)/s))) + + def _moment_generating_function(self, t): + mu, s = self.mean, self.shape + return exp((s/mu)*(1 - sqrt(1 - (2*mu**2*t)/s))) + + +def GaussianInverse(name, mean, shape): + r""" + Create a continuous random variable with an Inverse Gaussian distribution. + Inverse Gaussian distribution is also known as Wald distribution. + + Explanation + =========== + + The density of the Inverse Gaussian distribution is given by + + .. math:: + f(x) := \sqrt{\frac{\lambda}{2\pi x^3}} e^{-\frac{\lambda(x-\mu)^2}{2x\mu^2}} + + Parameters + ========== + + mu : + Positive number representing the mean. + lambda : + Positive number representing the shape parameter. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import GaussianInverse, density, E, std, skewness + >>> from sympy import Symbol, pprint + + >>> mu = Symbol("mu", positive=True) + >>> lamda = Symbol("lambda", positive=True) + >>> z = Symbol("z", positive=True) + >>> X = GaussianInverse("x", mu, lamda) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + 2 + -lambda*(-mu + z) + ------------------- + 2 + ___ ________ 2*mu *z + \/ 2 *\/ lambda *e + ------------------------------------- + ____ 3/2 + 2*\/ pi *z + + >>> E(X) + mu + + >>> std(X).expand() + mu**(3/2)/sqrt(lambda) + + >>> skewness(X).expand() + 3*sqrt(mu)/sqrt(lambda) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution + .. [2] https://mathworld.wolfram.com/InverseGaussianDistribution.html + + """ + + return rv(name, GaussianInverseDistribution, (mean, shape)) + +Wald = GaussianInverse + +#------------------------------------------------------------------------------- +# Pareto distribution ---------------------------------------------------------- + + +class ParetoDistribution(SingleContinuousDistribution): + _argnames = ('xm', 'alpha') + + @property + def set(self): + return Interval(self.xm, oo) + + @staticmethod + def check(xm, alpha): + _value_check(xm > 0, "Xm must be positive") + _value_check(alpha > 0, "Alpha must be positive") + + def pdf(self, x): + xm, alpha = self.xm, self.alpha + return alpha * xm**alpha / x**(alpha + 1) + + def _cdf(self, x): + xm, alpha = self.xm, self.alpha + return Piecewise( + (S.One - xm**alpha/x**alpha, x>=xm), + (0, True), + ) + + def _moment_generating_function(self, t): + xm, alpha = self.xm, self.alpha + return alpha * (-xm*t)**alpha * uppergamma(-alpha, -xm*t) + + def _characteristic_function(self, t): + xm, alpha = self.xm, self.alpha + return alpha * (-I * xm * t) ** alpha * uppergamma(-alpha, -I * xm * t) + + +def Pareto(name, xm, alpha): + r""" + Create a continuous random variable with the Pareto distribution. + + Explanation + =========== + + The density of the Pareto distribution is given by + + .. math:: + f(x) := \frac{\alpha\,x_m^\alpha}{x^{\alpha+1}} + + with :math:`x \in [x_m,\infty]`. + + Parameters + ========== + + xm : Real number, `x_m > 0`, a scale + alpha : Real number, `\alpha > 0`, a shape + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Pareto, density + >>> from sympy import Symbol + + >>> xm = Symbol("xm", positive=True) + >>> beta = Symbol("beta", positive=True) + >>> z = Symbol("z") + + >>> X = Pareto("x", xm, beta) + + >>> density(X)(z) + beta*xm**beta*z**(-beta - 1) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Pareto_distribution + .. [2] https://mathworld.wolfram.com/ParetoDistribution.html + + """ + + return rv(name, ParetoDistribution, (xm, alpha)) + +#------------------------------------------------------------------------------- +# PowerFunction distribution --------------------------------------------------- + + +class PowerFunctionDistribution(SingleContinuousDistribution): + _argnames=('alpha','a','b') + + @property + def set(self): + return Interval(self.a, self.b) + + @staticmethod + def check(alpha, a, b): + _value_check(a.is_real, "Continuous Boundary parameter should be real.") + _value_check(b.is_real, "Continuous Boundary parameter should be real.") + _value_check(a < b, " 'a' the left Boundary must be smaller than 'b' the right Boundary." ) + _value_check(alpha.is_positive, "Continuous Shape parameter should be positive.") + + def pdf(self, x): + alpha, a, b = self.alpha, self.a, self.b + num = alpha*(x - a)**(alpha - 1) + den = (b - a)**alpha + return num/den + +def PowerFunction(name, alpha, a, b): + r""" + Creates a continuous random variable with a Power Function Distribution. + + Explanation + =========== + + The density of PowerFunction distribution is given by + + .. math:: + f(x) := \frac{{\alpha}(x - a)^{\alpha - 1}}{(b - a)^{\alpha}} + + with :math:`x \in [a,b]`. + + Parameters + ========== + + alpha : Positive number, `0 < \alpha`, the shape parameter + a : Real number, :math:`-\infty < a`, the left boundary + b : Real number, :math:`a < b < \infty`, the right boundary + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import PowerFunction, density, cdf, E, variance + >>> from sympy import Symbol + >>> alpha = Symbol("alpha", positive=True) + >>> a = Symbol("a", real=True) + >>> b = Symbol("b", real=True) + >>> z = Symbol("z") + + >>> X = PowerFunction("X", 2, a, b) + + >>> density(X)(z) + (-2*a + 2*z)/(-a + b)**2 + + >>> cdf(X)(z) + Piecewise((a**2/(a**2 - 2*a*b + b**2) - 2*a*z/(a**2 - 2*a*b + b**2) + + z**2/(a**2 - 2*a*b + b**2), a <= z), (0, True)) + + >>> alpha = 2 + >>> a = 0 + >>> b = 1 + >>> Y = PowerFunction("Y", alpha, a, b) + + >>> E(Y) + 2/3 + + >>> variance(Y) + 1/18 + + References + ========== + + .. [1] https://web.archive.org/web/20200204081320/http://www.mathwave.com/help/easyfit/html/analyses/distributions/power_func.html + + """ + return rv(name, PowerFunctionDistribution, (alpha, a, b)) + +#------------------------------------------------------------------------------- +# QuadraticU distribution ------------------------------------------------------ + + +class QuadraticUDistribution(SingleContinuousDistribution): + _argnames = ('a', 'b') + + @property + def set(self): + return Interval(self.a, self.b) + + @staticmethod + def check(a, b): + _value_check(b > a, "Parameter b must be in range (%s, oo)."%(a)) + + def pdf(self, x): + a, b = self.a, self.b + alpha = 12 / (b-a)**3 + beta = (a+b) / 2 + return Piecewise( + (alpha * (x-beta)**2, And(a<=x, x<=b)), + (S.Zero, True)) + + def _moment_generating_function(self, t): + a, b = self.a, self.b + return -3 * (exp(a*t) * (4 + (a**2 + 2*a*(-2 + b) + b**2) * t) \ + - exp(b*t) * (4 + (-4*b + (a + b)**2) * t)) / ((a-b)**3 * t**2) + + def _characteristic_function(self, t): + a, b = self.a, self.b + return -3*I*(exp(I*a*t*exp(I*b*t)) * (4*I - (-4*b + (a+b)**2)*t)) \ + / ((a-b)**3 * t**2) + + +def QuadraticU(name, a, b): + r""" + Create a Continuous Random Variable with a U-quadratic distribution. + + Explanation + =========== + + The density of the U-quadratic distribution is given by + + .. math:: + f(x) := \alpha (x-\beta)^2 + + with :math:`x \in [a,b]`. + + Parameters + ========== + + a : Real number + b : Real number, :math:`a < b` + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import QuadraticU, density + >>> from sympy import Symbol, pprint + + >>> a = Symbol("a", real=True) + >>> b = Symbol("b", real=True) + >>> z = Symbol("z") + + >>> X = QuadraticU("x", a, b) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + / 2 + | / a b \ + |12*|- - - - + z| + | \ 2 2 / + <----------------- for And(b >= z, a <= z) + | 3 + | (-a + b) + | + \ 0 otherwise + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/U-quadratic_distribution + + """ + + return rv(name, QuadraticUDistribution, (a, b)) + +#------------------------------------------------------------------------------- +# RaisedCosine distribution ---------------------------------------------------- + + +class RaisedCosineDistribution(SingleContinuousDistribution): + _argnames = ('mu', 's') + + @property + def set(self): + return Interval(self.mu - self.s, self.mu + self.s) + + @staticmethod + def check(mu, s): + _value_check(s > 0, "s must be positive") + + def pdf(self, x): + mu, s = self.mu, self.s + return Piecewise( + ((1+cos(pi*(x-mu)/s)) / (2*s), And(mu-s<=x, x<=mu+s)), + (S.Zero, True)) + + def _characteristic_function(self, t): + mu, s = self.mu, self.s + return Piecewise((exp(-I*pi*mu/s)/2, Eq(t, -pi/s)), + (exp(I*pi*mu/s)/2, Eq(t, pi/s)), + (pi**2*sin(s*t)*exp(I*mu*t) / (s*t*(pi**2 - s**2*t**2)), True)) + + def _moment_generating_function(self, t): + mu, s = self.mu, self.s + return pi**2 * sinh(s*t) * exp(mu*t) / (s*t*(pi**2 + s**2*t**2)) + +def RaisedCosine(name, mu, s): + r""" + Create a Continuous Random Variable with a raised cosine distribution. + + Explanation + =========== + + The density of the raised cosine distribution is given by + + .. math:: + f(x) := \frac{1}{2s}\left(1+\cos\left(\frac{x-\mu}{s}\pi\right)\right) + + with :math:`x \in [\mu-s,\mu+s]`. + + Parameters + ========== + + mu : Real number + s : Real number, `s > 0` + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import RaisedCosine, density + >>> from sympy import Symbol, pprint + + >>> mu = Symbol("mu", real=True) + >>> s = Symbol("s", positive=True) + >>> z = Symbol("z") + + >>> X = RaisedCosine("x", mu, s) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + / /pi*(-mu + z)\ + |cos|------------| + 1 + | \ s / + <--------------------- for And(z >= mu - s, z <= mu + s) + | 2*s + | + \ 0 otherwise + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Raised_cosine_distribution + + """ + + return rv(name, RaisedCosineDistribution, (mu, s)) + +#------------------------------------------------------------------------------- +# Rayleigh distribution -------------------------------------------------------- + + +class RayleighDistribution(SingleContinuousDistribution): + _argnames = ('sigma',) + + set = Interval(0, oo) + + @staticmethod + def check(sigma): + _value_check(sigma > 0, "Scale parameter sigma must be positive.") + + def pdf(self, x): + sigma = self.sigma + return x/sigma**2*exp(-x**2/(2*sigma**2)) + + def _cdf(self, x): + sigma = self.sigma + return 1 - exp(-(x**2/(2*sigma**2))) + + def _characteristic_function(self, t): + sigma = self.sigma + return 1 - sigma*t*exp(-sigma**2*t**2/2) * sqrt(pi/2) * (erfi(sigma*t/sqrt(2)) - I) + + def _moment_generating_function(self, t): + sigma = self.sigma + return 1 + sigma*t*exp(sigma**2*t**2/2) * sqrt(pi/2) * (erf(sigma*t/sqrt(2)) + 1) + + +def Rayleigh(name, sigma): + r""" + Create a continuous random variable with a Rayleigh distribution. + + Explanation + =========== + + The density of the Rayleigh distribution is given by + + .. math :: + f(x) := \frac{x}{\sigma^2} e^{-x^2/2\sigma^2} + + with :math:`x > 0`. + + Parameters + ========== + + sigma : Real number, `\sigma > 0` + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Rayleigh, density, E, variance + >>> from sympy import Symbol + + >>> sigma = Symbol("sigma", positive=True) + >>> z = Symbol("z") + + >>> X = Rayleigh("x", sigma) + + >>> density(X)(z) + z*exp(-z**2/(2*sigma**2))/sigma**2 + + >>> E(X) + sqrt(2)*sqrt(pi)*sigma/2 + + >>> variance(X) + -pi*sigma**2/2 + 2*sigma**2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Rayleigh_distribution + .. [2] https://mathworld.wolfram.com/RayleighDistribution.html + + """ + + return rv(name, RayleighDistribution, (sigma, )) + +#------------------------------------------------------------------------------- +# Reciprocal distribution -------------------------------------------------------- + +class ReciprocalDistribution(SingleContinuousDistribution): + _argnames = ('a', 'b') + + @property + def set(self): + return Interval(self.a, self.b) + + @staticmethod + def check(a, b): + _value_check(a > 0, "Parameter > 0. a = %s"%a) + _value_check((a < b), + "Parameter b must be in range (%s, +oo]. b = %s"%(a, b)) + + def pdf(self, x): + a, b = self.a, self.b + return 1/(x*(log(b) - log(a))) + + +def Reciprocal(name, a, b): + r"""Creates a continuous random variable with a reciprocal distribution. + + + Parameters + ========== + + a : Real number, :math:`0 < a` + b : Real number, :math:`a < b` + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Reciprocal, density, cdf + >>> from sympy import symbols + >>> a, b, x = symbols('a, b, x', positive=True) + >>> R = Reciprocal('R', a, b) + + >>> density(R)(x) + 1/(x*(-log(a) + log(b))) + >>> cdf(R)(x) + Piecewise((log(a)/(log(a) - log(b)) - log(x)/(log(a) - log(b)), a <= x), (0, True)) + + Reference + ========= + + .. [1] https://en.wikipedia.org/wiki/Reciprocal_distribution + + """ + return rv(name, ReciprocalDistribution, (a, b)) + + +#------------------------------------------------------------------------------- +# Shifted Gompertz distribution ------------------------------------------------ + + +class ShiftedGompertzDistribution(SingleContinuousDistribution): + _argnames = ('b', 'eta') + + set = Interval(0, oo) + + @staticmethod + def check(b, eta): + _value_check(b > 0, "b must be positive") + _value_check(eta > 0, "eta must be positive") + + def pdf(self, x): + b, eta = self.b, self.eta + return b*exp(-b*x)*exp(-eta*exp(-b*x))*(1+eta*(1-exp(-b*x))) + +def ShiftedGompertz(name, b, eta): + r""" + Create a continuous random variable with a Shifted Gompertz distribution. + + Explanation + =========== + + The density of the Shifted Gompertz distribution is given by + + .. math:: + f(x) := b e^{-b x} e^{-\eta \exp(-b x)} \left[1 + \eta(1 - e^(-bx)) \right] + + with :math:`x \in [0, \infty)`. + + Parameters + ========== + + b : Real number, `b > 0`, a scale + eta : Real number, `\eta > 0`, a shape + + Returns + ======= + + RandomSymbol + + Examples + ======== + >>> from sympy.stats import ShiftedGompertz, density + >>> from sympy import Symbol + + >>> b = Symbol("b", positive=True) + >>> eta = Symbol("eta", positive=True) + >>> x = Symbol("x") + + >>> X = ShiftedGompertz("x", b, eta) + + >>> density(X)(x) + b*(eta*(1 - exp(-b*x)) + 1)*exp(-b*x)*exp(-eta*exp(-b*x)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Shifted_Gompertz_distribution + + """ + return rv(name, ShiftedGompertzDistribution, (b, eta)) + +#------------------------------------------------------------------------------- +# StudentT distribution -------------------------------------------------------- + + +class StudentTDistribution(SingleContinuousDistribution): + _argnames = ('nu',) + + set = Interval(-oo, oo) + + @staticmethod + def check(nu): + _value_check(nu > 0, "Degrees of freedom nu must be positive.") + + def pdf(self, x): + nu = self.nu + return 1/(sqrt(nu)*beta_fn(S.Half, nu/2))*(1 + x**2/nu)**(-(nu + 1)/2) + + def _cdf(self, x): + nu = self.nu + return S.Half + x*gamma((nu+1)/2)*hyper((S.Half, (nu+1)/2), + (Rational(3, 2),), -x**2/nu)/(sqrt(pi*nu)*gamma(nu/2)) + + def _moment_generating_function(self, t): + raise NotImplementedError('The moment generating function for the Student-T distribution is undefined.') + + +def StudentT(name, nu): + r""" + Create a continuous random variable with a student's t distribution. + + Explanation + =========== + + The density of the student's t distribution is given by + + .. math:: + f(x) := \frac{\Gamma \left(\frac{\nu+1}{2} \right)} + {\sqrt{\nu\pi}\Gamma \left(\frac{\nu}{2} \right)} + \left(1+\frac{x^2}{\nu} \right)^{-\frac{\nu+1}{2}} + + Parameters + ========== + + nu : Real number, `\nu > 0`, the degrees of freedom + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import StudentT, density, cdf + >>> from sympy import Symbol, pprint + + >>> nu = Symbol("nu", positive=True) + >>> z = Symbol("z") + + >>> X = StudentT("x", nu) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + nu 1 + - -- - - + 2 2 + / 2\ + | z | + |1 + --| + \ nu/ + ----------------- + ____ / nu\ + \/ nu *B|1/2, --| + \ 2 / + + >>> cdf(X)(z) + 1/2 + z*gamma(nu/2 + 1/2)*hyper((1/2, nu/2 + 1/2), (3/2,), + -z**2/nu)/(sqrt(pi)*sqrt(nu)*gamma(nu/2)) + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Student_t-distribution + .. [2] https://mathworld.wolfram.com/Studentst-Distribution.html + + """ + + return rv(name, StudentTDistribution, (nu, )) + +#------------------------------------------------------------------------------- +# Trapezoidal distribution ------------------------------------------------------ + + +class TrapezoidalDistribution(SingleContinuousDistribution): + _argnames = ('a', 'b', 'c', 'd') + + @property + def set(self): + return Interval(self.a, self.d) + + @staticmethod + def check(a, b, c, d): + _value_check(a < d, "Lower bound parameter a < %s. a = %s"%(d, a)) + _value_check((a <= b, b < c), + "Level start parameter b must be in range [%s, %s). b = %s"%(a, c, b)) + _value_check((b < c, c <= d), + "Level end parameter c must be in range (%s, %s]. c = %s"%(b, d, c)) + _value_check(d >= c, "Upper bound parameter d > %s. d = %s"%(c, d)) + + def pdf(self, x): + a, b, c, d = self.a, self.b, self.c, self.d + return Piecewise( + (2*(x-a) / ((b-a)*(d+c-a-b)), And(a <= x, x < b)), + (2 / (d+c-a-b), And(b <= x, x < c)), + (2*(d-x) / ((d-c)*(d+c-a-b)), And(c <= x, x <= d)), + (S.Zero, True)) + +def Trapezoidal(name, a, b, c, d): + r""" + Create a continuous random variable with a trapezoidal distribution. + + Explanation + =========== + + The density of the trapezoidal distribution is given by + + .. math:: + f(x) := \begin{cases} + 0 & \mathrm{for\ } x < a, \\ + \frac{2(x-a)}{(b-a)(d+c-a-b)} & \mathrm{for\ } a \le x < b, \\ + \frac{2}{d+c-a-b} & \mathrm{for\ } b \le x < c, \\ + \frac{2(d-x)}{(d-c)(d+c-a-b)} & \mathrm{for\ } c \le x < d, \\ + 0 & \mathrm{for\ } d < x. + \end{cases} + + Parameters + ========== + + a : Real number, :math:`a < d` + b : Real number, :math:`a \le b < c` + c : Real number, :math:`b < c \le d` + d : Real number + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Trapezoidal, density + >>> from sympy import Symbol, pprint + + >>> a = Symbol("a") + >>> b = Symbol("b") + >>> c = Symbol("c") + >>> d = Symbol("d") + >>> z = Symbol("z") + + >>> X = Trapezoidal("x", a,b,c,d) + + >>> pprint(density(X)(z), use_unicode=False) + / -2*a + 2*z + |------------------------- for And(a <= z, b > z) + |(-a + b)*(-a - b + c + d) + | + | 2 + | -------------- for And(b <= z, c > z) + < -a - b + c + d + | + | 2*d - 2*z + |------------------------- for And(d >= z, c <= z) + |(-c + d)*(-a - b + c + d) + | + \ 0 otherwise + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Trapezoidal_distribution + + """ + return rv(name, TrapezoidalDistribution, (a, b, c, d)) + +#------------------------------------------------------------------------------- +# Triangular distribution ------------------------------------------------------ + + +class TriangularDistribution(SingleContinuousDistribution): + _argnames = ('a', 'b', 'c') + + @property + def set(self): + return Interval(self.a, self.b) + + @staticmethod + def check(a, b, c): + _value_check(b > a, "Parameter b > %s. b = %s"%(a, b)) + _value_check((a <= c, c <= b), + "Parameter c must be in range [%s, %s]. c = %s"%(a, b, c)) + + def pdf(self, x): + a, b, c = self.a, self.b, self.c + return Piecewise( + (2*(x - a)/((b - a)*(c - a)), And(a <= x, x < c)), + (2/(b - a), Eq(x, c)), + (2*(b - x)/((b - a)*(b - c)), And(c < x, x <= b)), + (S.Zero, True)) + + def _characteristic_function(self, t): + a, b, c = self.a, self.b, self.c + return -2 *((b-c) * exp(I*a*t) - (b-a) * exp(I*c*t) + (c-a) * exp(I*b*t)) / ((b-a)*(c-a)*(b-c)*t**2) + + def _moment_generating_function(self, t): + a, b, c = self.a, self.b, self.c + return 2 * ((b - c) * exp(a * t) - (b - a) * exp(c * t) + (c - a) * exp(b * t)) / ( + (b - a) * (c - a) * (b - c) * t ** 2) + + +def Triangular(name, a, b, c): + r""" + Create a continuous random variable with a triangular distribution. + + Explanation + =========== + + The density of the triangular distribution is given by + + .. math:: + f(x) := \begin{cases} + 0 & \mathrm{for\ } x < a, \\ + \frac{2(x-a)}{(b-a)(c-a)} & \mathrm{for\ } a \le x < c, \\ + \frac{2}{b-a} & \mathrm{for\ } x = c, \\ + \frac{2(b-x)}{(b-a)(b-c)} & \mathrm{for\ } c < x \le b, \\ + 0 & \mathrm{for\ } b < x. + \end{cases} + + Parameters + ========== + + a : Real number, :math:`a \in \left(-\infty, \infty\right)` + b : Real number, :math:`a < b` + c : Real number, :math:`a \leq c \leq b` + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Triangular, density + >>> from sympy import Symbol, pprint + + >>> a = Symbol("a") + >>> b = Symbol("b") + >>> c = Symbol("c") + >>> z = Symbol("z") + + >>> X = Triangular("x", a,b,c) + + >>> pprint(density(X)(z), use_unicode=False) + / -2*a + 2*z + |----------------- for And(a <= z, c > z) + |(-a + b)*(-a + c) + | + | 2 + | ------ for c = z + < -a + b + | + | 2*b - 2*z + |---------------- for And(b >= z, c < z) + |(-a + b)*(b - c) + | + \ 0 otherwise + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Triangular_distribution + .. [2] https://mathworld.wolfram.com/TriangularDistribution.html + + """ + + return rv(name, TriangularDistribution, (a, b, c)) + +#------------------------------------------------------------------------------- +# Uniform distribution --------------------------------------------------------- + + +class UniformDistribution(SingleContinuousDistribution): + _argnames = ('left', 'right') + + @property + def set(self): + return Interval(self.left, self.right) + + @staticmethod + def check(left, right): + _value_check(left < right, "Lower limit should be less than Upper limit.") + + def pdf(self, x): + left, right = self.left, self.right + return Piecewise( + (S.One/(right - left), And(left <= x, x <= right)), + (S.Zero, True) + ) + + def _cdf(self, x): + left, right = self.left, self.right + return Piecewise( + (S.Zero, x < left), + ((x - left)/(right - left), x <= right), + (S.One, True) + ) + + def _characteristic_function(self, t): + left, right = self.left, self.right + return Piecewise(((exp(I*t*right) - exp(I*t*left)) / (I*t*(right - left)), Ne(t, 0)), + (S.One, True)) + + def _moment_generating_function(self, t): + left, right = self.left, self.right + return Piecewise(((exp(t*right) - exp(t*left)) / (t * (right - left)), Ne(t, 0)), + (S.One, True)) + + def expectation(self, expr, var, **kwargs): + kwargs['evaluate'] = True + result = SingleContinuousDistribution.expectation(self, expr, var, **kwargs) + result = result.subs({Max(self.left, self.right): self.right, + Min(self.left, self.right): self.left}) + return result + + +def Uniform(name, left, right): + r""" + Create a continuous random variable with a uniform distribution. + + Explanation + =========== + + The density of the uniform distribution is given by + + .. math:: + f(x) := \begin{cases} + \frac{1}{b - a} & \text{for } x \in [a,b] \\ + 0 & \text{otherwise} + \end{cases} + + with :math:`x \in [a,b]`. + + Parameters + ========== + + a : Real number, :math:`-\infty < a`, the left boundary + b : Real number, :math:`a < b < \infty`, the right boundary + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Uniform, density, cdf, E, variance + >>> from sympy import Symbol, simplify + + >>> a = Symbol("a", negative=True) + >>> b = Symbol("b", positive=True) + >>> z = Symbol("z") + + >>> X = Uniform("x", a, b) + + >>> density(X)(z) + Piecewise((1/(-a + b), (b >= z) & (a <= z)), (0, True)) + + >>> cdf(X)(z) + Piecewise((0, a > z), ((-a + z)/(-a + b), b >= z), (1, True)) + + >>> E(X) + a/2 + b/2 + + >>> simplify(variance(X)) + a**2/12 - a*b/6 + b**2/12 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Uniform_distribution_%28continuous%29 + .. [2] https://mathworld.wolfram.com/UniformDistribution.html + + """ + + return rv(name, UniformDistribution, (left, right)) + +#------------------------------------------------------------------------------- +# UniformSum distribution ------------------------------------------------------ + + +class UniformSumDistribution(SingleContinuousDistribution): + _argnames = ('n',) + + @property + def set(self): + return Interval(0, self.n) + + @staticmethod + def check(n): + _value_check((n > 0, n.is_integer), + "Parameter n must be positive integer.") + + def pdf(self, x): + n = self.n + k = Dummy("k") + return 1/factorial( + n - 1)*Sum((-1)**k*binomial(n, k)*(x - k)**(n - 1), (k, 0, floor(x))) + + def _cdf(self, x): + n = self.n + k = Dummy("k") + return Piecewise((S.Zero, x < 0), + (1/factorial(n)*Sum((-1)**k*binomial(n, k)*(x - k)**(n), + (k, 0, floor(x))), x <= n), + (S.One, True)) + + def _characteristic_function(self, t): + return ((exp(I*t) - 1) / (I*t))**self.n + + def _moment_generating_function(self, t): + return ((exp(t) - 1) / t)**self.n + +def UniformSum(name, n): + r""" + Create a continuous random variable with an Irwin-Hall distribution. + + Explanation + =========== + + The probability distribution function depends on a single parameter + $n$ which is an integer. + + The density of the Irwin-Hall distribution is given by + + .. math :: + f(x) := \frac{1}{(n-1)!}\sum_{k=0}^{\left\lfloor x\right\rfloor}(-1)^k + \binom{n}{k}(x-k)^{n-1} + + Parameters + ========== + + n : A positive integer, `n > 0` + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import UniformSum, density, cdf + >>> from sympy import Symbol, pprint + + >>> n = Symbol("n", integer=True) + >>> z = Symbol("z") + + >>> X = UniformSum("x", n) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + floor(z) + ___ + \ ` + \ k n - 1 /n\ + ) (-1) *(-k + z) *| | + / \k/ + /__, + k = 0 + -------------------------------- + (n - 1)! + + >>> cdf(X)(z) + Piecewise((0, z < 0), (Sum((-1)**_k*(-_k + z)**n*binomial(n, _k), + (_k, 0, floor(z)))/factorial(n), n >= z), (1, True)) + + + Compute cdf with specific 'x' and 'n' values as follows : + >>> cdf(UniformSum("x", 5), evaluate=False)(2).doit() + 9/40 + + The argument evaluate=False prevents an attempt at evaluation + of the sum for general n, before the argument 2 is passed. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Uniform_sum_distribution + .. [2] https://mathworld.wolfram.com/UniformSumDistribution.html + + """ + + return rv(name, UniformSumDistribution, (n, )) + +#------------------------------------------------------------------------------- +# VonMises distribution -------------------------------------------------------- + + +class VonMisesDistribution(SingleContinuousDistribution): + _argnames = ('mu', 'k') + + set = Interval(0, 2*pi) + + @staticmethod + def check(mu, k): + _value_check(k > 0, "k must be positive") + + def pdf(self, x): + mu, k = self.mu, self.k + return exp(k*cos(x-mu)) / (2*pi*besseli(0, k)) + +def VonMises(name, mu, k): + r""" + Create a Continuous Random Variable with a von Mises distribution. + + Explanation + =========== + + The density of the von Mises distribution is given by + + .. math:: + f(x) := \frac{e^{\kappa\cos(x-\mu)}}{2\pi I_0(\kappa)} + + with :math:`x \in [0,2\pi]`. + + Parameters + ========== + + mu : Real number + Measure of location. + k : Real number + Measure of concentration. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import VonMises, density + >>> from sympy import Symbol, pprint + + >>> mu = Symbol("mu") + >>> k = Symbol("k", positive=True) + >>> z = Symbol("z") + + >>> X = VonMises("x", mu, k) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + k*cos(mu - z) + e + ------------------ + 2*pi*besseli(0, k) + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Von_Mises_distribution + .. [2] https://mathworld.wolfram.com/vonMisesDistribution.html + + """ + + return rv(name, VonMisesDistribution, (mu, k)) + +#------------------------------------------------------------------------------- +# Weibull distribution --------------------------------------------------------- + + +class WeibullDistribution(SingleContinuousDistribution): + _argnames = ('alpha', 'beta') + + set = Interval(0, oo) + + @staticmethod + def check(alpha, beta): + _value_check(alpha > 0, "Alpha must be positive") + _value_check(beta > 0, "Beta must be positive") + + def pdf(self, x): + alpha, beta = self.alpha, self.beta + return beta * (x/alpha)**(beta - 1) * exp(-(x/alpha)**beta) / alpha + + +def Weibull(name, alpha, beta): + r""" + Create a continuous random variable with a Weibull distribution. + + Explanation + =========== + + The density of the Weibull distribution is given by + + .. math:: + f(x) := \begin{cases} + \frac{k}{\lambda}\left(\frac{x}{\lambda}\right)^{k-1} + e^{-(x/\lambda)^{k}} & x\geq0\\ + 0 & x<0 + \end{cases} + + Parameters + ========== + + lambda : Real number, $\lambda > 0$, a scale + k : Real number, $k > 0$, a shape + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Weibull, density, E, variance + >>> from sympy import Symbol, simplify + + >>> l = Symbol("lambda", positive=True) + >>> k = Symbol("k", positive=True) + >>> z = Symbol("z") + + >>> X = Weibull("x", l, k) + + >>> density(X)(z) + k*(z/lambda)**(k - 1)*exp(-(z/lambda)**k)/lambda + + >>> simplify(E(X)) + lambda*gamma(1 + 1/k) + + >>> simplify(variance(X)) + lambda**2*(-gamma(1 + 1/k)**2 + gamma(1 + 2/k)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Weibull_distribution + .. [2] https://mathworld.wolfram.com/WeibullDistribution.html + + """ + + return rv(name, WeibullDistribution, (alpha, beta)) + +#------------------------------------------------------------------------------- +# Wigner semicircle distribution ----------------------------------------------- + + +class WignerSemicircleDistribution(SingleContinuousDistribution): + _argnames = ('R',) + + @property + def set(self): + return Interval(-self.R, self.R) + + @staticmethod + def check(R): + _value_check(R > 0, "Radius R must be positive.") + + def pdf(self, x): + R = self.R + return 2/(pi*R**2)*sqrt(R**2 - x**2) + + def _characteristic_function(self, t): + return Piecewise((2 * besselj(1, self.R*t) / (self.R*t), Ne(t, 0)), + (S.One, True)) + + def _moment_generating_function(self, t): + return Piecewise((2 * besseli(1, self.R*t) / (self.R*t), Ne(t, 0)), + (S.One, True)) + +def WignerSemicircle(name, R): + r""" + Create a continuous random variable with a Wigner semicircle distribution. + + Explanation + =========== + + The density of the Wigner semicircle distribution is given by + + .. math:: + f(x) := \frac2{\pi R^2}\,\sqrt{R^2-x^2} + + with :math:`x \in [-R,R]`. + + Parameters + ========== + + R : Real number, `R > 0`, the radius + + Returns + ======= + + A RandomSymbol. + + Examples + ======== + + >>> from sympy.stats import WignerSemicircle, density, E + >>> from sympy import Symbol + + >>> R = Symbol("R", positive=True) + >>> z = Symbol("z") + + >>> X = WignerSemicircle("x", R) + + >>> density(X)(z) + 2*sqrt(R**2 - z**2)/(pi*R**2) + + >>> E(X) + 0 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Wigner_semicircle_distribution + .. [2] https://mathworld.wolfram.com/WignersSemicircleLaw.html + + """ + + return rv(name, WignerSemicircleDistribution, (R,)) diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/drv.py b/.venv/lib/python3.13/site-packages/sympy/stats/drv.py new file mode 100644 index 0000000000000000000000000000000000000000..dea14f2dfd1078c223c61bb5cd1373105e72ea28 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/drv.py @@ -0,0 +1,350 @@ +from sympy.concrete.summations import (Sum, summation) +from sympy.core.basic import Basic +from sympy.core.cache import cacheit +from sympy.core.function import Lambda +from sympy.core.numbers import I +from sympy.core.relational import (Eq, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.piecewise import Piecewise +from sympy.logic.boolalg import And +from sympy.polys.polytools import poly +from sympy.series.series import series + +from sympy.polys.polyerrors import PolynomialError +from sympy.stats.crv import reduce_rational_inequalities_wrap +from sympy.stats.rv import (NamedArgsMixin, SinglePSpace, SingleDomain, + random_symbols, PSpace, ConditionalDomain, RandomDomain, + ProductDomain, Distribution) +from sympy.stats.symbolic_probability import Probability +from sympy.sets.fancysets import Range, FiniteSet +from sympy.sets.sets import Union +from sympy.sets.contains import Contains +from sympy.utilities import filldedent +from sympy.core.sympify import _sympify + + +class DiscreteDistribution(Distribution): + def __call__(self, *args): + return self.pdf(*args) + + +class SingleDiscreteDistribution(DiscreteDistribution, NamedArgsMixin): + """ Discrete distribution of a single variable. + + Serves as superclass for PoissonDistribution etc.... + + Provides methods for pdf, cdf, and sampling + + See Also: + sympy.stats.crv_types.* + """ + + set = S.Integers + + def __new__(cls, *args): + args = list(map(sympify, args)) + return Basic.__new__(cls, *args) + + @staticmethod + def check(*args): + pass + + @cacheit + def compute_cdf(self, **kwargs): + """ Compute the CDF from the PDF. + + Returns a Lambda. + """ + x = symbols('x', integer=True, cls=Dummy) + z = symbols('z', real=True, cls=Dummy) + left_bound = self.set.inf + + # CDF is integral of PDF from left bound to z + pdf = self.pdf(x) + cdf = summation(pdf, (x, left_bound, floor(z)), **kwargs) + # CDF Ensure that CDF left of left_bound is zero + cdf = Piecewise((cdf, z >= left_bound), (0, True)) + return Lambda(z, cdf) + + def _cdf(self, x): + return None + + def cdf(self, x, **kwargs): + """ Cumulative density function """ + if not kwargs: + cdf = self._cdf(x) + if cdf is not None: + return cdf + return self.compute_cdf(**kwargs)(x) + + @cacheit + def compute_characteristic_function(self, **kwargs): + """ Compute the characteristic function from the PDF. + + Returns a Lambda. + """ + x, t = symbols('x, t', real=True, cls=Dummy) + pdf = self.pdf(x) + cf = summation(exp(I*t*x)*pdf, (x, self.set.inf, self.set.sup)) + return Lambda(t, cf) + + def _characteristic_function(self, t): + return None + + def characteristic_function(self, t, **kwargs): + """ Characteristic function """ + if not kwargs: + cf = self._characteristic_function(t) + if cf is not None: + return cf + return self.compute_characteristic_function(**kwargs)(t) + + @cacheit + def compute_moment_generating_function(self, **kwargs): + t = Dummy('t', real=True) + x = Dummy('x', integer=True) + pdf = self.pdf(x) + mgf = summation(exp(t*x)*pdf, (x, self.set.inf, self.set.sup)) + return Lambda(t, mgf) + + def _moment_generating_function(self, t): + return None + + def moment_generating_function(self, t, **kwargs): + if not kwargs: + mgf = self._moment_generating_function(t) + if mgf is not None: + return mgf + return self.compute_moment_generating_function(**kwargs)(t) + + @cacheit + def compute_quantile(self, **kwargs): + """ Compute the Quantile from the PDF. + + Returns a Lambda. + """ + x = Dummy('x', integer=True) + p = Dummy('p', real=True) + left_bound = self.set.inf + pdf = self.pdf(x) + cdf = summation(pdf, (x, left_bound, x), **kwargs) + set = ((x, p <= cdf), ) + return Lambda(p, Piecewise(*set)) + + def _quantile(self, x): + return None + + def quantile(self, x, **kwargs): + """ Cumulative density function """ + if not kwargs: + quantile = self._quantile(x) + if quantile is not None: + return quantile + return self.compute_quantile(**kwargs)(x) + + def expectation(self, expr, var, evaluate=True, **kwargs): + """ Expectation of expression over distribution """ + # TODO: support discrete sets with non integer stepsizes + + if evaluate: + try: + p = poly(expr, var) + + t = Dummy('t', real=True) + + mgf = self.moment_generating_function(t) + deg = p.degree() + taylor = poly(series(mgf, t, 0, deg + 1).removeO(), t) + result = 0 + for k in range(deg+1): + result += p.coeff_monomial(var ** k) * taylor.coeff_monomial(t ** k) * factorial(k) + + return result + + except PolynomialError: + return summation(expr * self.pdf(var), + (var, self.set.inf, self.set.sup), **kwargs) + + else: + return Sum(expr * self.pdf(var), + (var, self.set.inf, self.set.sup), **kwargs) + + def __call__(self, *args): + return self.pdf(*args) + + +class DiscreteDomain(RandomDomain): + """ + A domain with discrete support with step size one. + Represented using symbols and Range. + """ + is_Discrete = True + +class SingleDiscreteDomain(DiscreteDomain, SingleDomain): + def as_boolean(self): + return Contains(self.symbol, self.set) + + +class ConditionalDiscreteDomain(DiscreteDomain, ConditionalDomain): + """ + Domain with discrete support of step size one, that is restricted by + some condition. + """ + @property + def set(self): + rv = self.symbols + if len(self.symbols) > 1: + raise NotImplementedError(filldedent(''' + Multivariate conditional domains are not yet implemented.''')) + rv = list(rv)[0] + return reduce_rational_inequalities_wrap(self.condition, + rv).intersect(self.fulldomain.set) + + +class DiscretePSpace(PSpace): + is_real = True + is_Discrete = True + + @property + def pdf(self): + return self.density(*self.symbols) + + def where(self, condition): + rvs = random_symbols(condition) + assert all(r.symbol in self.symbols for r in rvs) + if len(rvs) > 1: + raise NotImplementedError(filldedent('''Multivariate discrete + random variables are not yet supported.''')) + conditional_domain = reduce_rational_inequalities_wrap(condition, + rvs[0]) + conditional_domain = conditional_domain.intersect(self.domain.set) + return SingleDiscreteDomain(rvs[0].symbol, conditional_domain) + + def probability(self, condition): + complement = isinstance(condition, Ne) + if complement: + condition = Eq(condition.args[0], condition.args[1]) + try: + _domain = self.where(condition).set + if condition == False or _domain is S.EmptySet: + return S.Zero + if condition == True or _domain == self.domain.set: + return S.One + prob = self.eval_prob(_domain) + except NotImplementedError: + from sympy.stats.rv import density + expr = condition.lhs - condition.rhs + dens = density(expr) + if not isinstance(dens, DiscreteDistribution): + from sympy.stats.drv_types import DiscreteDistributionHandmade + dens = DiscreteDistributionHandmade(dens) + z = Dummy('z', real=True) + space = SingleDiscretePSpace(z, dens) + prob = space.probability(condition.__class__(space.value, 0)) + if prob is None: + prob = Probability(condition) + return prob if not complement else S.One - prob + + def eval_prob(self, _domain): + sym = list(self.symbols)[0] + if isinstance(_domain, Range): + n = symbols('n', integer=True) + inf, sup, step = (r for r in _domain.args) + summand = ((self.pdf).replace( + sym, n*step)) + rv = summation(summand, + (n, inf/step, (sup)/step - 1)).doit() + return rv + elif isinstance(_domain, FiniteSet): + pdf = Lambda(sym, self.pdf) + rv = sum(pdf(x) for x in _domain) + return rv + elif isinstance(_domain, Union): + rv = sum(self.eval_prob(x) for x in _domain.args) + return rv + + def conditional_space(self, condition): + # XXX: Converting from set to tuple. The order matters to Lambda + # though so we should be starting with a set... + density = Lambda(tuple(self.symbols), self.pdf/self.probability(condition)) + condition = condition.xreplace({rv: rv.symbol for rv in self.values}) + domain = ConditionalDiscreteDomain(self.domain, condition) + return DiscretePSpace(domain, density) + +class ProductDiscreteDomain(ProductDomain, DiscreteDomain): + def as_boolean(self): + return And(*[domain.as_boolean for domain in self.domains]) + +class SingleDiscretePSpace(DiscretePSpace, SinglePSpace): + """ Discrete probability space over a single univariate variable """ + is_real = True + + @property + def set(self): + return self.distribution.set + + @property + def domain(self): + return SingleDiscreteDomain(self.symbol, self.set) + + def sample(self, size=(), library='scipy', seed=None): + """ + Internal sample method. + + Returns dictionary mapping RandomSymbol to realization value. + """ + return {self.value: self.distribution.sample(size, library=library, seed=seed)} + + def compute_expectation(self, expr, rvs=None, evaluate=True, **kwargs): + rvs = rvs or (self.value,) + if self.value not in rvs: + return expr + + expr = _sympify(expr) + expr = expr.xreplace({rv: rv.symbol for rv in rvs}) + + x = self.value.symbol + try: + return self.distribution.expectation(expr, x, evaluate=evaluate, + **kwargs) + except NotImplementedError: + return Sum(expr * self.pdf, (x, self.set.inf, self.set.sup), + **kwargs) + + def compute_cdf(self, expr, **kwargs): + if expr == self.value: + x = Dummy("x", real=True) + return Lambda(x, self.distribution.cdf(x, **kwargs)) + else: + raise NotImplementedError() + + def compute_density(self, expr, **kwargs): + if expr == self.value: + return self.distribution + raise NotImplementedError() + + def compute_characteristic_function(self, expr, **kwargs): + if expr == self.value: + t = Dummy("t", real=True) + return Lambda(t, self.distribution.characteristic_function(t, **kwargs)) + else: + raise NotImplementedError() + + def compute_moment_generating_function(self, expr, **kwargs): + if expr == self.value: + t = Dummy("t", real=True) + return Lambda(t, self.distribution.moment_generating_function(t, **kwargs)) + else: + raise NotImplementedError() + + def compute_quantile(self, expr, **kwargs): + if expr == self.value: + p = Dummy("p", real=True) + return Lambda(p, self.distribution.quantile(p, **kwargs)) + else: + raise NotImplementedError() diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/drv_types.py b/.venv/lib/python3.13/site-packages/sympy/stats/drv_types.py new file mode 100644 index 0000000000000000000000000000000000000000..84920d31c0083828efc2cd3f752d2c48f5430102 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/drv_types.py @@ -0,0 +1,849 @@ +""" + +Contains +======== +FlorySchulz +Geometric +Hermite +Logarithmic +NegativeBinomial +Poisson +Skellam +YuleSimon +Zeta +""" + + + +from sympy.concrete.summations import Sum +from sympy.core.basic import Basic +from sympy.core.function import Lambda +from sympy.core.numbers import I +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import (binomial, factorial, FallingFactorial) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.bessel import besseli +from sympy.functions.special.beta_functions import beta +from sympy.functions.special.hyper import hyper +from sympy.functions.special.zeta_functions import (polylog, zeta) +from sympy.stats.drv import SingleDiscreteDistribution, SingleDiscretePSpace +from sympy.stats.rv import _value_check, is_random + + +__all__ = ['FlorySchulz', +'Geometric', +'Hermite', +'Logarithmic', +'NegativeBinomial', +'Poisson', +'Skellam', +'YuleSimon', +'Zeta' +] + + +def rv(symbol, cls, *args, **kwargs): + args = list(map(sympify, args)) + dist = cls(*args) + if kwargs.pop('check', True): + dist.check(*args) + pspace = SingleDiscretePSpace(symbol, dist) + if any(is_random(arg) for arg in args): + from sympy.stats.compound_rv import CompoundPSpace, CompoundDistribution + pspace = CompoundPSpace(symbol, CompoundDistribution(dist)) + return pspace.value + + +class DiscreteDistributionHandmade(SingleDiscreteDistribution): + _argnames = ('pdf',) + + def __new__(cls, pdf, set=S.Integers): + return Basic.__new__(cls, pdf, set) + + @property + def set(self): + return self.args[1] + + @staticmethod + def check(pdf, set): + x = Dummy('x') + val = Sum(pdf(x), (x, set._inf, set._sup)).doit() + _value_check(Eq(val, 1) != S.false, "The pdf is incorrect on the given set.") + + + +def DiscreteRV(symbol, density, set=S.Integers, **kwargs): + """ + Create a Discrete Random Variable given the following: + + Parameters + ========== + + symbol : Symbol + Represents name of the random variable. + density : Expression containing symbol + Represents probability density function. + set : set + Represents the region where the pdf is valid, by default is real line. + check : bool + If True, it will check whether the given density + integrates to 1 over the given set. If False, it + will not perform this check. Default is False. + + Examples + ======== + + >>> from sympy.stats import DiscreteRV, P, E + >>> from sympy import Rational, Symbol + >>> x = Symbol('x') + >>> n = 10 + >>> density = Rational(1, 10) + >>> X = DiscreteRV(x, density, set=set(range(n))) + >>> E(X) + 9/2 + >>> P(X>3) + 3/5 + + Returns + ======= + + RandomSymbol + + """ + set = sympify(set) + pdf = Piecewise((density, set.as_relational(symbol)), (0, True)) + pdf = Lambda(symbol, pdf) + # have a default of False while `rv` should have a default of True + kwargs['check'] = kwargs.pop('check', False) + return rv(symbol.name, DiscreteDistributionHandmade, pdf, set, **kwargs) + + +#------------------------------------------------------------------------------- +# Flory-Schulz distribution ------------------------------------------------------------ + +class FlorySchulzDistribution(SingleDiscreteDistribution): + _argnames = ('a',) + set = S.Naturals + + @staticmethod + def check(a): + _value_check((0 < a, a < 1), "a must be between 0 and 1") + + def pdf(self, k): + a = self.a + return (a**2 * k * (1 - a)**(k - 1)) + + def _characteristic_function(self, t): + a = self.a + return a**2*exp(I*t)/((1 + (a - 1)*exp(I*t))**2) + + def _moment_generating_function(self, t): + a = self.a + return a**2*exp(t)/((1 + (a - 1)*exp(t))**2) + + +def FlorySchulz(name, a): + r""" + Create a discrete random variable with a FlorySchulz distribution. + + The density of the FlorySchulz distribution is given by + + .. math:: + f(k) := (a^2) k (1 - a)^{k-1} + + Parameters + ========== + + a : A real number between 0 and 1 + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import density, E, variance, FlorySchulz + >>> from sympy import Symbol, S + + >>> a = S.One / 5 + >>> z = Symbol("z") + + >>> X = FlorySchulz("x", a) + + >>> density(X)(z) + (4/5)**(z - 1)*z/25 + + >>> E(X) + 9 + + >>> variance(X) + 40 + + References + ========== + + https://en.wikipedia.org/wiki/Flory%E2%80%93Schulz_distribution + """ + return rv(name, FlorySchulzDistribution, a) + + +#------------------------------------------------------------------------------- +# Geometric distribution ------------------------------------------------------------ + +class GeometricDistribution(SingleDiscreteDistribution): + _argnames = ('p',) + set = S.Naturals + + @staticmethod + def check(p): + _value_check((0 < p, p <= 1), "p must be between 0 and 1") + + def pdf(self, k): + return (1 - self.p)**(k - 1) * self.p + + def _characteristic_function(self, t): + p = self.p + return p * exp(I*t) / (1 - (1 - p)*exp(I*t)) + + def _moment_generating_function(self, t): + p = self.p + return p * exp(t) / (1 - (1 - p) * exp(t)) + + +def Geometric(name, p): + r""" + Create a discrete random variable with a Geometric distribution. + + Explanation + =========== + + The density of the Geometric distribution is given by + + .. math:: + f(k) := p (1 - p)^{k - 1} + + Parameters + ========== + + p : A probability between 0 and 1 + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Geometric, density, E, variance + >>> from sympy import Symbol, S + + >>> p = S.One / 5 + >>> z = Symbol("z") + + >>> X = Geometric("x", p) + + >>> density(X)(z) + (4/5)**(z - 1)/5 + + >>> E(X) + 5 + + >>> variance(X) + 20 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Geometric_distribution + .. [2] https://mathworld.wolfram.com/GeometricDistribution.html + + """ + return rv(name, GeometricDistribution, p) + + +#------------------------------------------------------------------------------- +# Hermite distribution --------------------------------------------------------- + + +class HermiteDistribution(SingleDiscreteDistribution): + _argnames = ('a1', 'a2') + set = S.Naturals0 + + @staticmethod + def check(a1, a2): + _value_check(a1.is_nonnegative, 'Parameter a1 must be >= 0.') + _value_check(a2.is_nonnegative, 'Parameter a2 must be >= 0.') + + def pdf(self, k): + a1, a2 = self.a1, self.a2 + term1 = exp(-(a1 + a2)) + j = Dummy("j", integer=True) + num = a1**(k - 2*j) * a2**j + den = factorial(k - 2*j) * factorial(j) + return term1 * Sum(num/den, (j, 0, k//2)).doit() + + def _moment_generating_function(self, t): + a1, a2 = self.a1, self.a2 + term1 = a1 * (exp(t) - 1) + term2 = a2 * (exp(2*t) - 1) + return exp(term1 + term2) + + def _characteristic_function(self, t): + a1, a2 = self.a1, self.a2 + term1 = a1 * (exp(I*t) - 1) + term2 = a2 * (exp(2*I*t) - 1) + return exp(term1 + term2) + +def Hermite(name, a1, a2): + r""" + Create a discrete random variable with a Hermite distribution. + + Explanation + =========== + + The density of the Hermite distribution is given by + + .. math:: + f(x):= e^{-a_1 -a_2}\sum_{j=0}^{\left \lfloor x/2 \right \rfloor} + \frac{a_{1}^{x-2j}a_{2}^{j}}{(x-2j)!j!} + + Parameters + ========== + + a1 : A Positive number greater than equal to 0. + a2 : A Positive number greater than equal to 0. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Hermite, density, E, variance + >>> from sympy import Symbol + + >>> a1 = Symbol("a1", positive=True) + >>> a2 = Symbol("a2", positive=True) + >>> x = Symbol("x") + + >>> H = Hermite("H", a1=5, a2=4) + + >>> density(H)(2) + 33*exp(-9)/2 + + >>> E(H) + 13 + + >>> variance(H) + 21 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hermite_distribution + + """ + + return rv(name, HermiteDistribution, a1, a2) + + +#------------------------------------------------------------------------------- +# Logarithmic distribution ------------------------------------------------------------ + +class LogarithmicDistribution(SingleDiscreteDistribution): + _argnames = ('p',) + + set = S.Naturals + + @staticmethod + def check(p): + _value_check((p > 0, p < 1), "p should be between 0 and 1") + + def pdf(self, k): + p = self.p + return (-1) * p**k / (k * log(1 - p)) + + def _characteristic_function(self, t): + p = self.p + return log(1 - p * exp(I*t)) / log(1 - p) + + def _moment_generating_function(self, t): + p = self.p + return log(1 - p * exp(t)) / log(1 - p) + + +def Logarithmic(name, p): + r""" + Create a discrete random variable with a Logarithmic distribution. + + Explanation + =========== + + The density of the Logarithmic distribution is given by + + .. math:: + f(k) := \frac{-p^k}{k \ln{(1 - p)}} + + Parameters + ========== + + p : A value between 0 and 1 + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Logarithmic, density, E, variance + >>> from sympy import Symbol, S + + >>> p = S.One / 5 + >>> z = Symbol("z") + + >>> X = Logarithmic("x", p) + + >>> density(X)(z) + -1/(5**z*z*log(4/5)) + + >>> E(X) + -1/(-4*log(5) + 8*log(2)) + + >>> variance(X) + -1/((-4*log(5) + 8*log(2))*(-2*log(5) + 4*log(2))) + 1/(-64*log(2)*log(5) + 64*log(2)**2 + 16*log(5)**2) - 10/(-32*log(5) + 64*log(2)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Logarithmic_distribution + .. [2] https://mathworld.wolfram.com/LogarithmicDistribution.html + + """ + return rv(name, LogarithmicDistribution, p) + + +#------------------------------------------------------------------------------- +# Negative binomial distribution ------------------------------------------------------------ + +class NegativeBinomialDistribution(SingleDiscreteDistribution): + _argnames = ('r', 'p') + set = S.Naturals0 + + @staticmethod + def check(r, p): + _value_check(r > 0, 'r should be positive') + _value_check((p > 0, p < 1), 'p should be between 0 and 1') + + def pdf(self, k): + r = self.r + p = self.p + + return binomial(k + r - 1, k) * (1 - p)**k * p**r + + def _characteristic_function(self, t): + r = self.r + p = self.p + + return (p / (1 - (1 - p) * exp(I*t)))**r + + def _moment_generating_function(self, t): + r = self.r + p = self.p + + return (p / (1 - (1 - p) * exp(t)))**r + +def NegativeBinomial(name, r, p): + r""" + Create a discrete random variable with a Negative Binomial distribution. + + Explanation + =========== + + The density of the Negative Binomial distribution is given by + + .. math:: + f(k) := \binom{k + r - 1}{k} (1 - p)^k p^r + + Parameters + ========== + + r : A positive value + Number of successes until the experiment is stopped. + p : A value between 0 and 1 + Probability of success. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import NegativeBinomial, density, E, variance + >>> from sympy import Symbol, S + + >>> r = 5 + >>> p = S.One / 3 + >>> z = Symbol("z") + + >>> X = NegativeBinomial("x", r, p) + + >>> density(X)(z) + (2/3)**z*binomial(z + 4, z)/243 + + >>> E(X) + 10 + + >>> variance(X) + 30 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Negative_binomial_distribution + .. [2] https://mathworld.wolfram.com/NegativeBinomialDistribution.html + + """ + return rv(name, NegativeBinomialDistribution, r, p) + + +#------------------------------------------------------------------------------- +# Poisson distribution ------------------------------------------------------------ + +class PoissonDistribution(SingleDiscreteDistribution): + _argnames = ('lamda',) + + set = S.Naturals0 + + @staticmethod + def check(lamda): + _value_check(lamda > 0, "Lambda must be positive") + + def pdf(self, k): + return self.lamda**k / factorial(k) * exp(-self.lamda) + + def _characteristic_function(self, t): + return exp(self.lamda * (exp(I*t) - 1)) + + def _moment_generating_function(self, t): + return exp(self.lamda * (exp(t) - 1)) + + def expectation(self, expr, var, evaluate=True, **kwargs): + if evaluate: + if expr == var: + return self.lamda + if ( + isinstance(expr, FallingFactorial) + and expr.args[1].is_integer + and expr.args[1].is_positive + and expr.args[0] == var + ): + return self.lamda ** expr.args[1] + return super().expectation(expr, var, evaluate, **kwargs) + +def Poisson(name, lamda): + r""" + Create a discrete random variable with a Poisson distribution. + + Explanation + =========== + + The density of the Poisson distribution is given by + + .. math:: + f(k) := \frac{\lambda^{k} e^{- \lambda}}{k!} + + Parameters + ========== + + lamda : Positive number, a rate + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Poisson, density, E, variance + >>> from sympy import Symbol, simplify + + >>> rate = Symbol("lambda", positive=True) + >>> z = Symbol("z") + + >>> X = Poisson("x", rate) + + >>> density(X)(z) + lambda**z*exp(-lambda)/factorial(z) + + >>> E(X) + lambda + + >>> simplify(variance(X)) + lambda + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Poisson_distribution + .. [2] https://mathworld.wolfram.com/PoissonDistribution.html + + """ + return rv(name, PoissonDistribution, lamda) + + +# ----------------------------------------------------------------------------- +# Skellam distribution -------------------------------------------------------- + + +class SkellamDistribution(SingleDiscreteDistribution): + _argnames = ('mu1', 'mu2') + set = S.Integers + + @staticmethod + def check(mu1, mu2): + _value_check(mu1 >= 0, 'Parameter mu1 must be >= 0') + _value_check(mu2 >= 0, 'Parameter mu2 must be >= 0') + + def pdf(self, k): + (mu1, mu2) = (self.mu1, self.mu2) + term1 = exp(-(mu1 + mu2)) * (mu1 / mu2) ** (k / 2) + term2 = besseli(k, 2 * sqrt(mu1 * mu2)) + return term1 * term2 + + def _cdf(self, x): + raise NotImplementedError( + "Skellam doesn't have closed form for the CDF.") + + def _characteristic_function(self, t): + (mu1, mu2) = (self.mu1, self.mu2) + return exp(-(mu1 + mu2) + mu1 * exp(I * t) + mu2 * exp(-I * t)) + + def _moment_generating_function(self, t): + (mu1, mu2) = (self.mu1, self.mu2) + return exp(-(mu1 + mu2) + mu1 * exp(t) + mu2 * exp(-t)) + + +def Skellam(name, mu1, mu2): + r""" + Create a discrete random variable with a Skellam distribution. + + Explanation + =========== + + The Skellam is the distribution of the difference N1 - N2 + of two statistically independent random variables N1 and N2 + each Poisson-distributed with respective expected values mu1 and mu2. + + The density of the Skellam distribution is given by + + .. math:: + f(k) := e^{-(\mu_1+\mu_2)}(\frac{\mu_1}{\mu_2})^{k/2}I_k(2\sqrt{\mu_1\mu_2}) + + Parameters + ========== + + mu1 : A non-negative value + mu2 : A non-negative value + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Skellam, density, E, variance + >>> from sympy import Symbol, pprint + + >>> z = Symbol("z", integer=True) + >>> mu1 = Symbol("mu1", positive=True) + >>> mu2 = Symbol("mu2", positive=True) + >>> X = Skellam("x", mu1, mu2) + + >>> pprint(density(X)(z), use_unicode=False) + z + - + 2 + /mu1\ -mu1 - mu2 / _____ _____\ + |---| *e *besseli\z, 2*\/ mu1 *\/ mu2 / + \mu2/ + >>> E(X) + mu1 - mu2 + >>> variance(X).expand() + mu1 + mu2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Skellam_distribution + + """ + return rv(name, SkellamDistribution, mu1, mu2) + + +#------------------------------------------------------------------------------- +# Yule-Simon distribution ------------------------------------------------------------ + +class YuleSimonDistribution(SingleDiscreteDistribution): + _argnames = ('rho',) + set = S.Naturals + + @staticmethod + def check(rho): + _value_check(rho > 0, 'rho should be positive') + + def pdf(self, k): + rho = self.rho + return rho * beta(k, rho + 1) + + def _cdf(self, x): + return Piecewise((1 - floor(x) * beta(floor(x), self.rho + 1), x >= 1), (0, True)) + + def _characteristic_function(self, t): + rho = self.rho + return rho * hyper((1, 1), (rho + 2,), exp(I*t)) * exp(I*t) / (rho + 1) + + def _moment_generating_function(self, t): + rho = self.rho + return rho * hyper((1, 1), (rho + 2,), exp(t)) * exp(t) / (rho + 1) + + +def YuleSimon(name, rho): + r""" + Create a discrete random variable with a Yule-Simon distribution. + + Explanation + =========== + + The density of the Yule-Simon distribution is given by + + .. math:: + f(k) := \rho B(k, \rho + 1) + + Parameters + ========== + + rho : A positive value + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import YuleSimon, density, E, variance + >>> from sympy import Symbol, simplify + + >>> p = 5 + >>> z = Symbol("z") + + >>> X = YuleSimon("x", p) + + >>> density(X)(z) + 5*beta(z, 6) + + >>> simplify(E(X)) + 5/4 + + >>> simplify(variance(X)) + 25/48 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Yule%E2%80%93Simon_distribution + + """ + return rv(name, YuleSimonDistribution, rho) + + +#------------------------------------------------------------------------------- +# Zeta distribution ------------------------------------------------------------ + +class ZetaDistribution(SingleDiscreteDistribution): + _argnames = ('s',) + set = S.Naturals + + @staticmethod + def check(s): + _value_check(s > 1, 's should be greater than 1') + + def pdf(self, k): + s = self.s + return 1 / (k**s * zeta(s)) + + def _characteristic_function(self, t): + return polylog(self.s, exp(I*t)) / zeta(self.s) + + def _moment_generating_function(self, t): + return polylog(self.s, exp(t)) / zeta(self.s) + + +def Zeta(name, s): + r""" + Create a discrete random variable with a Zeta distribution. + + Explanation + =========== + + The density of the Zeta distribution is given by + + .. math:: + f(k) := \frac{1}{k^s \zeta{(s)}} + + Parameters + ========== + + s : A value greater than 1 + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Zeta, density, E, variance + >>> from sympy import Symbol + + >>> s = 5 + >>> z = Symbol("z") + + >>> X = Zeta("x", s) + + >>> density(X)(z) + 1/(z**5*zeta(5)) + + >>> E(X) + pi**4/(90*zeta(5)) + + >>> variance(X) + -pi**8/(8100*zeta(5)**2) + zeta(3)/zeta(5) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Zeta_distribution + + """ + return rv(name, ZetaDistribution, s) diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/error_prop.py b/.venv/lib/python3.13/site-packages/sympy/stats/error_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..e6cacb894307fe60cbf096c7760e6ed57f385a91 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/error_prop.py @@ -0,0 +1,100 @@ +"""Tools for arithmetic error propagation.""" + +from itertools import repeat, combinations + +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import exp +from sympy.simplify.simplify import simplify +from sympy.stats.symbolic_probability import RandomSymbol, Variance, Covariance +from sympy.stats.rv import is_random + +_arg0_or_var = lambda var: var.args[0] if len(var.args) > 0 else var + + +def variance_prop(expr, consts=(), include_covar=False): + r"""Symbolically propagates variance (`\sigma^2`) for expressions. + This is computed as as seen in [1]_. + + Parameters + ========== + + expr : Expr + A SymPy expression to compute the variance for. + consts : sequence of Symbols, optional + Represents symbols that are known constants in the expr, + and thus have zero variance. All symbols not in consts are + assumed to be variant. + include_covar : bool, optional + Flag for whether or not to include covariances, default=False. + + Returns + ======= + + var_expr : Expr + An expression for the total variance of the expr. + The variance for the original symbols (e.g. x) are represented + via instance of the Variance symbol (e.g. Variance(x)). + + Examples + ======== + + >>> from sympy import symbols, exp + >>> from sympy.stats.error_prop import variance_prop + >>> x, y = symbols('x y') + + >>> variance_prop(x + y) + Variance(x) + Variance(y) + + >>> variance_prop(x * y) + x**2*Variance(y) + y**2*Variance(x) + + >>> variance_prop(exp(2*x)) + 4*exp(4*x)*Variance(x) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Propagation_of_uncertainty + + """ + args = expr.args + if len(args) == 0: + if expr in consts: + return S.Zero + elif is_random(expr): + return Variance(expr).doit() + elif isinstance(expr, Symbol): + return Variance(RandomSymbol(expr)).doit() + else: + return S.Zero + nargs = len(args) + var_args = list(map(variance_prop, args, repeat(consts, nargs), + repeat(include_covar, nargs))) + if isinstance(expr, Add): + var_expr = Add(*var_args) + if include_covar: + terms = [2 * Covariance(_arg0_or_var(x), _arg0_or_var(y)).expand() \ + for x, y in combinations(var_args, 2)] + var_expr += Add(*terms) + elif isinstance(expr, Mul): + terms = [v/a**2 for a, v in zip(args, var_args)] + var_expr = simplify(expr**2 * Add(*terms)) + if include_covar: + terms = [2*Covariance(_arg0_or_var(x), _arg0_or_var(y)).expand()/(a*b) \ + for (a, b), (x, y) in zip(combinations(args, 2), + combinations(var_args, 2))] + var_expr += Add(*terms) + elif isinstance(expr, Pow): + b = args[1] + v = var_args[0] * (expr * b / args[0])**2 + var_expr = simplify(v) + elif isinstance(expr, exp): + var_expr = simplify(var_args[0] * expr**2) + else: + # unknown how to proceed, return variance of whole expr. + var_expr = Variance(expr) + return var_expr diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/frv.py b/.venv/lib/python3.13/site-packages/sympy/stats/frv.py new file mode 100644 index 0000000000000000000000000000000000000000..498d7e4006b2b8db306a0905ed67578021e220a8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/frv.py @@ -0,0 +1,512 @@ +""" +Finite Discrete Random Variables Module + +See Also +======== +sympy.stats.frv_types +sympy.stats.rv +sympy.stats.crv +""" +from itertools import product + +from sympy.concrete.summations import Sum +from sympy.core.basic import Basic +from sympy.core.cache import cacheit +from sympy.core.function import Lambda +from sympy.core.mul import Mul +from sympy.core.numbers import (I, nan) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol) +from sympy.core.sympify import sympify +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.piecewise import Piecewise +from sympy.logic.boolalg import (And, Or) +from sympy.sets.sets import Intersection +from sympy.core.containers import Dict +from sympy.core.logic import Logic +from sympy.core.relational import Relational +from sympy.core.sympify import _sympify +from sympy.sets.sets import FiniteSet +from sympy.stats.rv import (RandomDomain, ProductDomain, ConditionalDomain, + PSpace, IndependentProductPSpace, SinglePSpace, random_symbols, + sumsets, rv_subs, NamedArgsMixin, Density, Distribution) + + +class FiniteDensity(dict): + """ + A domain with Finite Density. + """ + def __call__(self, item): + """ + Make instance of a class callable. + + If item belongs to current instance of a class, return it. + + Otherwise, return 0. + """ + item = sympify(item) + if item in self: + return self[item] + else: + return 0 + + @property + def dict(self): + """ + Return item as dictionary. + """ + return dict(self) + +class FiniteDomain(RandomDomain): + """ + A domain with discrete finite support + + Represented using a FiniteSet. + """ + is_Finite = True + + @property + def symbols(self): + return FiniteSet(sym for sym, val in self.elements) + + @property + def elements(self): + return self.args[0] + + @property + def dict(self): + return FiniteSet(*[Dict(dict(el)) for el in self.elements]) + + def __contains__(self, other): + return other in self.elements + + def __iter__(self): + return self.elements.__iter__() + + def as_boolean(self): + return Or(*[And(*[Eq(sym, val) for sym, val in item]) for item in self]) + + +class SingleFiniteDomain(FiniteDomain): + """ + A FiniteDomain over a single symbol/set + + Example: The possibilities of a *single* die roll. + """ + + def __new__(cls, symbol, set): + if not isinstance(set, FiniteSet) and \ + not isinstance(set, Intersection): + set = FiniteSet(*set) + return Basic.__new__(cls, symbol, set) + + @property + def symbol(self): + return self.args[0] + + @property + def symbols(self): + return FiniteSet(self.symbol) + + @property + def set(self): + return self.args[1] + + @property + def elements(self): + return FiniteSet(*[frozenset(((self.symbol, elem), )) for elem in self.set]) + + def __iter__(self): + return (frozenset(((self.symbol, elem),)) for elem in self.set) + + def __contains__(self, other): + sym, val = tuple(other)[0] + return sym == self.symbol and val in self.set + + +class ProductFiniteDomain(ProductDomain, FiniteDomain): + """ + A Finite domain consisting of several other FiniteDomains + + Example: The possibilities of the rolls of three independent dice + """ + + def __iter__(self): + proditer = product(*self.domains) + return (sumsets(items) for items in proditer) + + @property + def elements(self): + return FiniteSet(*self) + + +class ConditionalFiniteDomain(ConditionalDomain, ProductFiniteDomain): + """ + A FiniteDomain that has been restricted by a condition + + Example: The possibilities of a die roll under the condition that the + roll is even. + """ + + def __new__(cls, domain, condition): + """ + Create a new instance of ConditionalFiniteDomain class + """ + if condition is True: + return domain + cond = rv_subs(condition) + return Basic.__new__(cls, domain, cond) + + def _test(self, elem): + """ + Test the value. If value is boolean, return it. If value is equality + relational (two objects are equal), return it with left-hand side + being equal to right-hand side. Otherwise, raise ValueError exception. + """ + val = self.condition.xreplace(dict(elem)) + if val in [True, False]: + return val + elif val.is_Equality: + return val.lhs == val.rhs + raise ValueError("Undecidable if %s" % str(val)) + + def __contains__(self, other): + return other in self.fulldomain and self._test(other) + + def __iter__(self): + return (elem for elem in self.fulldomain if self._test(elem)) + + @property + def set(self): + if isinstance(self.fulldomain, SingleFiniteDomain): + return FiniteSet(*[elem for elem in self.fulldomain.set + if frozenset(((self.fulldomain.symbol, elem),)) in self]) + else: + raise NotImplementedError( + "Not implemented on multi-dimensional conditional domain") + + def as_boolean(self): + return FiniteDomain.as_boolean(self) + + +class SingleFiniteDistribution(Distribution, NamedArgsMixin): + def __new__(cls, *args): + args = list(map(sympify, args)) + return Basic.__new__(cls, *args) + + @staticmethod + def check(*args): + pass + + @property # type: ignore + @cacheit + def dict(self): + if self.is_symbolic: + return Density(self) + return {k: self.pmf(k) for k in self.set} + + def pmf(self, *args): # to be overridden by specific distribution + raise NotImplementedError() + + @property + def set(self): # to be overridden by specific distribution + raise NotImplementedError() + + values = property(lambda self: self.dict.values) + items = property(lambda self: self.dict.items) + is_symbolic = property(lambda self: False) + __iter__ = property(lambda self: self.dict.__iter__) + __getitem__ = property(lambda self: self.dict.__getitem__) + + def __call__(self, *args): + return self.pmf(*args) + + def __contains__(self, other): + return other in self.set + + +#============================================= +#========= Probability Space =============== +#============================================= + + +class FinitePSpace(PSpace): + """ + A Finite Probability Space + + Represents the probabilities of a finite number of events. + """ + is_Finite = True + + def __new__(cls, domain, density): + density = {sympify(key): sympify(val) + for key, val in density.items()} + public_density = Dict(density) + + obj = PSpace.__new__(cls, domain, public_density) + obj._density = density + return obj + + def prob_of(self, elem): + elem = sympify(elem) + density = self._density + if isinstance(list(density.keys())[0], FiniteSet): + return density.get(elem, S.Zero) + return density.get(tuple(elem)[0][1], S.Zero) + + def where(self, condition): + assert all(r.symbol in self.symbols for r in random_symbols(condition)) + return ConditionalFiniteDomain(self.domain, condition) + + def compute_density(self, expr): + expr = rv_subs(expr, self.values) + d = FiniteDensity() + for elem in self.domain: + val = expr.xreplace(dict(elem)) + prob = self.prob_of(elem) + d[val] = d.get(val, S.Zero) + prob + return d + + @cacheit + def compute_cdf(self, expr): + d = self.compute_density(expr) + cum_prob = S.Zero + cdf = [] + for key in sorted(d): + prob = d[key] + cum_prob += prob + cdf.append((key, cum_prob)) + + return dict(cdf) + + @cacheit + def sorted_cdf(self, expr, python_float=False): + cdf = self.compute_cdf(expr) + items = list(cdf.items()) + sorted_items = sorted(items, key=lambda val_cumprob: val_cumprob[1]) + if python_float: + sorted_items = [(v, float(cum_prob)) + for v, cum_prob in sorted_items] + return sorted_items + + @cacheit + def compute_characteristic_function(self, expr): + d = self.compute_density(expr) + t = Dummy('t', real=True) + + return Lambda(t, sum(exp(I*k*t)*v for k,v in d.items())) + + @cacheit + def compute_moment_generating_function(self, expr): + d = self.compute_density(expr) + t = Dummy('t', real=True) + + return Lambda(t, sum(exp(k*t)*v for k,v in d.items())) + + def compute_expectation(self, expr, rvs=None, **kwargs): + rvs = rvs or self.values + expr = rv_subs(expr, rvs) + probs = [self.prob_of(elem) for elem in self.domain] + if isinstance(expr, (Logic, Relational)): + parse_domain = [tuple(elem)[0][1] for elem in self.domain] + bools = [expr.xreplace(dict(elem)) for elem in self.domain] + else: + parse_domain = [expr.xreplace(dict(elem)) for elem in self.domain] + bools = [True for elem in self.domain] + return sum(Piecewise((prob * elem, blv), (S.Zero, True)) + for prob, elem, blv in zip(probs, parse_domain, bools)) + + def compute_quantile(self, expr): + cdf = self.compute_cdf(expr) + p = Dummy('p', real=True) + set = ((nan, (p < 0) | (p > 1)),) + for key, value in cdf.items(): + set = set + ((key, p <= value), ) + return Lambda(p, Piecewise(*set)) + + def probability(self, condition): + cond_symbols = frozenset(rs.symbol for rs in random_symbols(condition)) + cond = rv_subs(condition) + if not cond_symbols.issubset(self.symbols): + raise ValueError("Cannot compare foreign random symbols, %s" + %(str(cond_symbols - self.symbols))) + if isinstance(condition, Relational) and \ + (not cond.free_symbols.issubset(self.domain.free_symbols)): + rv = condition.lhs if isinstance(condition.rhs, Symbol) else condition.rhs + return sum(Piecewise( + (self.prob_of(elem), condition.subs(rv, list(elem)[0][1])), + (S.Zero, True)) for elem in self.domain) + return sympify(sum(self.prob_of(elem) for elem in self.where(condition))) + + def conditional_space(self, condition): + domain = self.where(condition) + prob = self.probability(condition) + density = {key: val / prob + for key, val in self._density.items() if domain._test(key)} + return FinitePSpace(domain, density) + + def sample(self, size=(), library='scipy', seed=None): + """ + Internal sample method + + Returns dictionary mapping RandomSymbol to realization value. + """ + return {self.value: self.distribution.sample(size, library, seed)} + + +class SingleFinitePSpace(SinglePSpace, FinitePSpace): + """ + A single finite probability space + + Represents the probabilities of a set of random events that can be + attributed to a single variable/symbol. + + This class is implemented by many of the standard FiniteRV types such as + Die, Bernoulli, Coin, etc.... + """ + @property + def domain(self): + return SingleFiniteDomain(self.symbol, self.distribution.set) + + @property + def _is_symbolic(self): + """ + Helper property to check if the distribution + of the random variable is having symbolic + dimension. + """ + return self.distribution.is_symbolic + + @property + def distribution(self): + return self.args[1] + + def pmf(self, expr): + return self.distribution.pmf(expr) + + @property # type: ignore + @cacheit + def _density(self): + return {FiniteSet((self.symbol, val)): prob + for val, prob in self.distribution.dict.items()} + + @cacheit + def compute_characteristic_function(self, expr): + if self._is_symbolic: + d = self.compute_density(expr) + t = Dummy('t', real=True) + ki = Dummy('ki') + return Lambda(t, Sum(d(ki)*exp(I*ki*t), (ki, self.args[1].low, self.args[1].high))) + expr = rv_subs(expr, self.values) + return FinitePSpace(self.domain, self.distribution).compute_characteristic_function(expr) + + @cacheit + def compute_moment_generating_function(self, expr): + if self._is_symbolic: + d = self.compute_density(expr) + t = Dummy('t', real=True) + ki = Dummy('ki') + return Lambda(t, Sum(d(ki)*exp(ki*t), (ki, self.args[1].low, self.args[1].high))) + expr = rv_subs(expr, self.values) + return FinitePSpace(self.domain, self.distribution).compute_moment_generating_function(expr) + + def compute_quantile(self, expr): + if self._is_symbolic: + raise NotImplementedError("Computing quantile for random variables " + "with symbolic dimension because the bounds of searching the required " + "value is undetermined.") + expr = rv_subs(expr, self.values) + return FinitePSpace(self.domain, self.distribution).compute_quantile(expr) + + def compute_density(self, expr): + if self._is_symbolic: + rv = list(random_symbols(expr))[0] + k = Dummy('k', integer=True) + cond = True if not isinstance(expr, (Relational, Logic)) \ + else expr.subs(rv, k) + return Lambda(k, + Piecewise((self.pmf(k), And(k >= self.args[1].low, + k <= self.args[1].high, cond)), (S.Zero, True))) + expr = rv_subs(expr, self.values) + return FinitePSpace(self.domain, self.distribution).compute_density(expr) + + def compute_cdf(self, expr): + if self._is_symbolic: + d = self.compute_density(expr) + k = Dummy('k') + ki = Dummy('ki') + return Lambda(k, Sum(d(ki), (ki, self.args[1].low, k))) + expr = rv_subs(expr, self.values) + return FinitePSpace(self.domain, self.distribution).compute_cdf(expr) + + def compute_expectation(self, expr, rvs=None, **kwargs): + if self._is_symbolic: + rv = random_symbols(expr)[0] + k = Dummy('k', integer=True) + expr = expr.subs(rv, k) + cond = True if not isinstance(expr, (Relational, Logic)) \ + else expr + func = self.pmf(k) * k if cond != True else self.pmf(k) * expr + return Sum(Piecewise((func, cond), (S.Zero, True)), + (k, self.distribution.low, self.distribution.high)).doit() + + expr = _sympify(expr) + expr = rv_subs(expr, rvs) + return FinitePSpace(self.domain, self.distribution).compute_expectation(expr, rvs, **kwargs) + + def probability(self, condition): + if self._is_symbolic: + #TODO: Implement the mechanism for handling queries for symbolic sized distributions. + raise NotImplementedError("Currently, probability queries are not " + "supported for random variables with symbolic sized distributions.") + condition = rv_subs(condition) + return FinitePSpace(self.domain, self.distribution).probability(condition) + + def conditional_space(self, condition): + """ + This method is used for transferring the + computation to probability method because + conditional space of random variables with + symbolic dimensions is currently not possible. + """ + if self._is_symbolic: + self + domain = self.where(condition) + prob = self.probability(condition) + density = {key: val / prob + for key, val in self._density.items() if domain._test(key)} + return FinitePSpace(domain, density) + + +class ProductFinitePSpace(IndependentProductPSpace, FinitePSpace): + """ + A collection of several independent finite probability spaces + """ + @property + def domain(self): + return ProductFiniteDomain(*[space.domain for space in self.spaces]) + + @property # type: ignore + @cacheit + def _density(self): + proditer = product(*[iter(space._density.items()) + for space in self.spaces]) + d = {} + for items in proditer: + elems, probs = list(zip(*items)) + elem = sumsets(elems) + prob = Mul(*probs) + d[elem] = d.get(elem, S.Zero) + prob + return Dict(d) + + @property # type: ignore + @cacheit + def density(self): + return Dict(self._density) + + def probability(self, condition): + return FinitePSpace.probability(self, condition) + + def compute_density(self, expr): + return FinitePSpace.compute_density(self, expr) diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/frv_types.py b/.venv/lib/python3.13/site-packages/sympy/stats/frv_types.py new file mode 100644 index 0000000000000000000000000000000000000000..bde656c219791c287ff445d5d215e3759271e923 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/frv_types.py @@ -0,0 +1,873 @@ +""" +Finite Discrete Random Variables - Prebuilt variable types + +Contains +======== +FiniteRV +DiscreteUniform +Die +Bernoulli +Coin +Binomial +BetaBinomial +Hypergeometric +Rademacher +IdealSoliton +RobustSoliton +""" + + +from sympy.core.cache import cacheit +from sympy.core.function import Lambda +from sympy.core.numbers import (Integer, Rational) +from sympy.core.relational import (Eq, Ge, Gt, Le, Lt) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol) +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import binomial +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.piecewise import Piecewise +from sympy.logic.boolalg import Or +from sympy.sets.contains import Contains +from sympy.sets.fancysets import Range +from sympy.sets.sets import (Intersection, Interval) +from sympy.functions.special.beta_functions import beta as beta_fn +from sympy.stats.frv import (SingleFiniteDistribution, + SingleFinitePSpace) +from sympy.stats.rv import _value_check, Density, is_random +from sympy.utilities.iterables import multiset +from sympy.utilities.misc import filldedent + + +__all__ = ['FiniteRV', +'DiscreteUniform', +'Die', +'Bernoulli', +'Coin', +'Binomial', +'BetaBinomial', +'Hypergeometric', +'Rademacher', +'IdealSoliton', +'RobustSoliton', +] + +def rv(name, cls, *args, **kwargs): + args = list(map(sympify, args)) + dist = cls(*args) + if kwargs.pop('check', True): + dist.check(*args) + pspace = SingleFinitePSpace(name, dist) + if any(is_random(arg) for arg in args): + from sympy.stats.compound_rv import CompoundPSpace, CompoundDistribution + pspace = CompoundPSpace(name, CompoundDistribution(dist)) + return pspace.value + +class FiniteDistributionHandmade(SingleFiniteDistribution): + + @property + def dict(self): + return self.args[0] + + def pmf(self, x): + x = Symbol('x') + return Lambda(x, Piecewise(*( + [(v, Eq(k, x)) for k, v in self.dict.items()] + [(S.Zero, True)]))) + + @property + def set(self): + return set(self.dict.keys()) + + @staticmethod + def check(density): + for p in density.values(): + _value_check((p >= 0, p <= 1), + "Probability at a point must be between 0 and 1.") + val = sum(density.values()) + _value_check(Eq(val, 1) != S.false, "Total Probability must be 1.") + +def FiniteRV(name, density, **kwargs): + r""" + Create a Finite Random Variable given a dict representing the density. + + Parameters + ========== + + name : Symbol + Represents name of the random variable. + density : dict + Dictionary containing the pdf of finite distribution + check : bool + If True, it will check whether the given density + integrates to 1 over the given set. If False, it + will not perform this check. Default is False. + + Examples + ======== + + >>> from sympy.stats import FiniteRV, P, E + + >>> density = {0: .1, 1: .2, 2: .3, 3: .4} + >>> X = FiniteRV('X', density) + + >>> E(X) + 2.00000000000000 + >>> P(X >= 2) + 0.700000000000000 + + Returns + ======= + + RandomSymbol + + """ + # have a default of False while `rv` should have a default of True + kwargs['check'] = kwargs.pop('check', False) + return rv(name, FiniteDistributionHandmade, density, **kwargs) + +class DiscreteUniformDistribution(SingleFiniteDistribution): + + @staticmethod + def check(*args): + # not using _value_check since there is a + # suggestion for the user + if len(set(args)) != len(args): + weights = multiset(args) + n = Integer(len(args)) + for k in weights: + weights[k] /= n + raise ValueError(filldedent(""" + Repeated args detected but set expected. For a + distribution having different weights for each + item use the following:""") + ( + '\nS("FiniteRV(%s, %s)")' % ("'X'", weights))) + + @property + def p(self): + return Rational(1, len(self.args)) + + @property # type: ignore + @cacheit + def dict(self): + return dict.fromkeys(self.set, self.p) + + @property + def set(self): + return set(self.args) + + def pmf(self, x): + if x in self.args: + return self.p + else: + return S.Zero + + +def DiscreteUniform(name, items): + r""" + Create a Finite Random Variable representing a uniform distribution over + the input set. + + Parameters + ========== + + items : list/tuple + Items over which Uniform distribution is to be made + + Examples + ======== + + >>> from sympy.stats import DiscreteUniform, density + >>> from sympy import symbols + + >>> X = DiscreteUniform('X', symbols('a b c')) # equally likely over a, b, c + >>> density(X).dict + {a: 1/3, b: 1/3, c: 1/3} + + >>> Y = DiscreteUniform('Y', list(range(5))) # distribution over a range + >>> density(Y).dict + {0: 1/5, 1: 1/5, 2: 1/5, 3: 1/5, 4: 1/5} + + Returns + ======= + + RandomSymbol + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Discrete_uniform_distribution + .. [2] https://mathworld.wolfram.com/DiscreteUniformDistribution.html + + """ + return rv(name, DiscreteUniformDistribution, *items) + + +class DieDistribution(SingleFiniteDistribution): + _argnames = ('sides',) + + @staticmethod + def check(sides): + _value_check((sides.is_positive, sides.is_integer), + "number of sides must be a positive integer.") + + @property + def is_symbolic(self): + return not self.sides.is_number + + @property + def high(self): + return self.sides + + @property + def low(self): + return S.One + + @property + def set(self): + if self.is_symbolic: + return Intersection(S.Naturals0, Interval(0, self.sides)) + return set(map(Integer, range(1, self.sides + 1))) + + def pmf(self, x): + x = sympify(x) + if not (x.is_number or x.is_Symbol or is_random(x)): + raise ValueError("'x' expected as an argument of type 'number', 'Symbol', or " + "'RandomSymbol' not %s" % (type(x))) + cond = Ge(x, 1) & Le(x, self.sides) & Contains(x, S.Integers) + return Piecewise((S.One/self.sides, cond), (S.Zero, True)) + +def Die(name, sides=6): + r""" + Create a Finite Random Variable representing a fair die. + + Parameters + ========== + + sides : Integer + Represents the number of sides of the Die, by default is 6 + + Examples + ======== + + >>> from sympy.stats import Die, density + >>> from sympy import Symbol + + >>> D6 = Die('D6', 6) # Six sided Die + >>> density(D6).dict + {1: 1/6, 2: 1/6, 3: 1/6, 4: 1/6, 5: 1/6, 6: 1/6} + + >>> D4 = Die('D4', 4) # Four sided Die + >>> density(D4).dict + {1: 1/4, 2: 1/4, 3: 1/4, 4: 1/4} + + >>> n = Symbol('n', positive=True, integer=True) + >>> Dn = Die('Dn', n) # n sided Die + >>> density(Dn).dict + Density(DieDistribution(n)) + >>> density(Dn).dict.subs(n, 4).doit() + {1: 1/4, 2: 1/4, 3: 1/4, 4: 1/4} + + Returns + ======= + + RandomSymbol + """ + + return rv(name, DieDistribution, sides) + + +class BernoulliDistribution(SingleFiniteDistribution): + _argnames = ('p', 'succ', 'fail') + + @staticmethod + def check(p, succ, fail): + _value_check((p >= 0, p <= 1), + "p should be in range [0, 1].") + + @property + def set(self): + return {self.succ, self.fail} + + def pmf(self, x): + if isinstance(self.succ, Symbol) and isinstance(self.fail, Symbol): + return Piecewise((self.p, x == self.succ), + (1 - self.p, x == self.fail), + (S.Zero, True)) + return Piecewise((self.p, Eq(x, self.succ)), + (1 - self.p, Eq(x, self.fail)), + (S.Zero, True)) + + +def Bernoulli(name, p, succ=1, fail=0): + r""" + Create a Finite Random Variable representing a Bernoulli process. + + Parameters + ========== + + p : Rational number between 0 and 1 + Represents probability of success + succ : Integer/symbol/string + Represents event of success + fail : Integer/symbol/string + Represents event of failure + + Examples + ======== + + >>> from sympy.stats import Bernoulli, density + >>> from sympy import S + + >>> X = Bernoulli('X', S(3)/4) # 1-0 Bernoulli variable, probability = 3/4 + >>> density(X).dict + {0: 1/4, 1: 3/4} + + >>> X = Bernoulli('X', S.Half, 'Heads', 'Tails') # A fair coin toss + >>> density(X).dict + {Heads: 1/2, Tails: 1/2} + + Returns + ======= + + RandomSymbol + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Bernoulli_distribution + .. [2] https://mathworld.wolfram.com/BernoulliDistribution.html + + """ + + return rv(name, BernoulliDistribution, p, succ, fail) + + +def Coin(name, p=S.Half): + r""" + Create a Finite Random Variable representing a Coin toss. + + This is an equivalent of a Bernoulli random variable with + "H" and "T" as success and failure events respectively. + + Parameters + ========== + + p : Rational Number between 0 and 1 + Represents probability of getting "Heads", by default is Half + + Examples + ======== + + >>> from sympy.stats import Coin, density + >>> from sympy import Rational + + >>> C = Coin('C') # A fair coin toss + >>> density(C).dict + {H: 1/2, T: 1/2} + + >>> C2 = Coin('C2', Rational(3, 5)) # An unfair coin + >>> density(C2).dict + {H: 3/5, T: 2/5} + + Returns + ======= + + RandomSymbol + + See Also + ======== + + sympy.stats.Binomial + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Coin_flipping + + """ + return rv(name, BernoulliDistribution, p, 'H', 'T') + + +class BinomialDistribution(SingleFiniteDistribution): + _argnames = ('n', 'p', 'succ', 'fail') + + @staticmethod + def check(n, p, succ, fail): + _value_check((n.is_integer, n.is_nonnegative), + "'n' must be nonnegative integer.") + _value_check((p <= 1, p >= 0), + "p should be in range [0, 1].") + + @property + def high(self): + return self.n + + @property + def low(self): + return S.Zero + + @property + def is_symbolic(self): + return not self.n.is_number + + @property + def set(self): + if self.is_symbolic: + return Intersection(S.Naturals0, Interval(0, self.n)) + return set(self.dict.keys()) + + def pmf(self, x): + n, p = self.n, self.p + x = sympify(x) + if not (x.is_number or x.is_Symbol or is_random(x)): + raise ValueError("'x' expected as an argument of type 'number', 'Symbol', or " + "'RandomSymbol' not %s" % (type(x))) + cond = Ge(x, 0) & Le(x, n) & Contains(x, S.Integers) + return Piecewise((binomial(n, x) * p**x * (1 - p)**(n - x), cond), (S.Zero, True)) + + @property # type: ignore + @cacheit + def dict(self): + if self.is_symbolic: + return Density(self) + return {k*self.succ + (self.n-k)*self.fail: self.pmf(k) + for k in range(0, self.n + 1)} + + +def Binomial(name, n, p, succ=1, fail=0): + r""" + Create a Finite Random Variable representing a binomial distribution. + + Parameters + ========== + + n : Positive Integer + Represents number of trials + p : Rational Number between 0 and 1 + Represents probability of success + succ : Integer/symbol/string + Represents event of success, by default is 1 + fail : Integer/symbol/string + Represents event of failure, by default is 0 + + Examples + ======== + + >>> from sympy.stats import Binomial, density + >>> from sympy import S, Symbol + + >>> X = Binomial('X', 4, S.Half) # Four "coin flips" + >>> density(X).dict + {0: 1/16, 1: 1/4, 2: 3/8, 3: 1/4, 4: 1/16} + + >>> n = Symbol('n', positive=True, integer=True) + >>> p = Symbol('p', positive=True) + >>> X = Binomial('X', n, S.Half) # n "coin flips" + >>> density(X).dict + Density(BinomialDistribution(n, 1/2, 1, 0)) + >>> density(X).dict.subs(n, 4).doit() + {0: 1/16, 1: 1/4, 2: 3/8, 3: 1/4, 4: 1/16} + + Returns + ======= + + RandomSymbol + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Binomial_distribution + .. [2] https://mathworld.wolfram.com/BinomialDistribution.html + + """ + + return rv(name, BinomialDistribution, n, p, succ, fail) + +#------------------------------------------------------------------------------- +# Beta-binomial distribution ---------------------------------------------------------- + +class BetaBinomialDistribution(SingleFiniteDistribution): + _argnames = ('n', 'alpha', 'beta') + + @staticmethod + def check(n, alpha, beta): + _value_check((n.is_integer, n.is_nonnegative), + "'n' must be nonnegative integer. n = %s." % str(n)) + _value_check((alpha > 0), + "'alpha' must be: alpha > 0 . alpha = %s" % str(alpha)) + _value_check((beta > 0), + "'beta' must be: beta > 0 . beta = %s" % str(beta)) + + @property + def high(self): + return self.n + + @property + def low(self): + return S.Zero + + @property + def is_symbolic(self): + return not self.n.is_number + + @property + def set(self): + if self.is_symbolic: + return Intersection(S.Naturals0, Interval(0, self.n)) + return set(map(Integer, range(self.n + 1))) + + def pmf(self, k): + n, a, b = self.n, self.alpha, self.beta + return binomial(n, k) * beta_fn(k + a, n - k + b) / beta_fn(a, b) + + +def BetaBinomial(name, n, alpha, beta): + r""" + Create a Finite Random Variable representing a Beta-binomial distribution. + + Parameters + ========== + + n : Positive Integer + Represents number of trials + alpha : Real positive number + beta : Real positive number + + Examples + ======== + + >>> from sympy.stats import BetaBinomial, density + + >>> X = BetaBinomial('X', 2, 1, 1) + >>> density(X).dict + {0: 1/3, 1: 2*beta(2, 2), 2: 1/3} + + Returns + ======= + + RandomSymbol + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Beta-binomial_distribution + .. [2] https://mathworld.wolfram.com/BetaBinomialDistribution.html + + """ + + return rv(name, BetaBinomialDistribution, n, alpha, beta) + + +class HypergeometricDistribution(SingleFiniteDistribution): + _argnames = ('N', 'm', 'n') + + @staticmethod + def check(n, N, m): + _value_check((N.is_integer, N.is_nonnegative), + "'N' must be nonnegative integer. N = %s." % str(N)) + _value_check((n.is_integer, n.is_nonnegative), + "'n' must be nonnegative integer. n = %s." % str(n)) + _value_check((m.is_integer, m.is_nonnegative), + "'m' must be nonnegative integer. m = %s." % str(m)) + + @property + def is_symbolic(self): + return not all(x.is_number for x in (self.N, self.m, self.n)) + + @property + def high(self): + return Piecewise((self.n, Lt(self.n, self.m) != False), (self.m, True)) + + @property + def low(self): + return Piecewise((0, Gt(0, self.n + self.m - self.N) != False), (self.n + self.m - self.N, True)) + + @property + def set(self): + N, m, n = self.N, self.m, self.n + if self.is_symbolic: + return Intersection(S.Naturals0, Interval(self.low, self.high)) + return set(range(max(0, n + m - N), min(n, m) + 1)) + + def pmf(self, k): + N, m, n = self.N, self.m, self.n + return S(binomial(m, k) * binomial(N - m, n - k))/binomial(N, n) + + +def Hypergeometric(name, N, m, n): + r""" + Create a Finite Random Variable representing a hypergeometric distribution. + + Parameters + ========== + + N : Positive Integer + Represents finite population of size N. + m : Positive Integer + Represents number of trials with required feature. + n : Positive Integer + Represents numbers of draws. + + + Examples + ======== + + >>> from sympy.stats import Hypergeometric, density + + >>> X = Hypergeometric('X', 10, 5, 3) # 10 marbles, 5 white (success), 3 draws + >>> density(X).dict + {0: 1/12, 1: 5/12, 2: 5/12, 3: 1/12} + + Returns + ======= + + RandomSymbol + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hypergeometric_distribution + .. [2] https://mathworld.wolfram.com/HypergeometricDistribution.html + + """ + return rv(name, HypergeometricDistribution, N, m, n) + + +class RademacherDistribution(SingleFiniteDistribution): + + @property + def set(self): + return {-1, 1} + + @property + def pmf(self): + k = Dummy('k') + return Lambda(k, Piecewise((S.Half, Or(Eq(k, -1), Eq(k, 1))), (S.Zero, True))) + +def Rademacher(name): + r""" + Create a Finite Random Variable representing a Rademacher distribution. + + Examples + ======== + + >>> from sympy.stats import Rademacher, density + + >>> X = Rademacher('X') + >>> density(X).dict + {-1: 1/2, 1: 1/2} + + Returns + ======= + + RandomSymbol + + See Also + ======== + + sympy.stats.Bernoulli + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Rademacher_distribution + + """ + return rv(name, RademacherDistribution) + +class IdealSolitonDistribution(SingleFiniteDistribution): + _argnames = ('k',) + + @staticmethod + def check(k): + _value_check(k.is_integer and k.is_positive, + "'k' must be a positive integer.") + + @property + def low(self): + return S.One + + @property + def high(self): + return self.k + + @property + def set(self): + return set(map(Integer, range(1, self.k + 1))) + + @property # type: ignore + @cacheit + def dict(self): + if self.k.is_Symbol: + return Density(self) + d = {1: Rational(1, self.k)} + d.update({i: Rational(1, i*(i - 1)) for i in range(2, self.k + 1)}) + return d + + def pmf(self, x): + x = sympify(x) + if not (x.is_number or x.is_Symbol or is_random(x)): + raise ValueError("'x' expected as an argument of type 'number', 'Symbol', or " + "'RandomSymbol' not %s" % (type(x))) + cond1 = Eq(x, 1) & x.is_integer + cond2 = Ge(x, 1) & Le(x, self.k) & x.is_integer + return Piecewise((1/self.k, cond1), (1/(x*(x - 1)), cond2), (S.Zero, True)) + +def IdealSoliton(name, k): + r""" + Create a Finite Random Variable of Ideal Soliton Distribution + + Parameters + ========== + + k : Positive Integer + Represents the number of input symbols in an LT (Luby Transform) code. + + Examples + ======== + + >>> from sympy.stats import IdealSoliton, density, P, E + >>> sol = IdealSoliton('sol', 5) + >>> density(sol).dict + {1: 1/5, 2: 1/2, 3: 1/6, 4: 1/12, 5: 1/20} + >>> density(sol).set + {1, 2, 3, 4, 5} + + >>> from sympy import Symbol + >>> k = Symbol('k', positive=True, integer=True) + >>> sol = IdealSoliton('sol', k) + >>> density(sol).dict + Density(IdealSolitonDistribution(k)) + >>> density(sol).dict.subs(k, 10).doit() + {1: 1/10, 2: 1/2, 3: 1/6, 4: 1/12, 5: 1/20, 6: 1/30, 7: 1/42, 8: 1/56, 9: 1/72, 10: 1/90} + + >>> E(sol.subs(k, 10)) + 7381/2520 + + >>> P(sol.subs(k, 4) > 2) + 1/4 + + Returns + ======= + + RandomSymbol + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Soliton_distribution#Ideal_distribution + .. [2] https://pages.cs.wisc.edu/~suman/courses/740/papers/luby02lt.pdf + + """ + return rv(name, IdealSolitonDistribution, k) + +class RobustSolitonDistribution(SingleFiniteDistribution): + _argnames= ('k', 'delta', 'c') + + @staticmethod + def check(k, delta, c): + _value_check(k.is_integer and k.is_positive, + "'k' must be a positive integer") + _value_check(Gt(delta, 0) and Le(delta, 1), + "'delta' must be a real number in the interval (0,1)") + _value_check(c.is_positive, + "'c' must be a positive real number.") + + @property + def R(self): + return self.c * log(self.k/self.delta) * self.k**0.5 + + @property + def Z(self): + z = 0 + for i in Range(1, round(self.k/self.R)): + z += (1/i) + z += log(self.R/self.delta) + return 1 + z * self.R/self.k + + @property + def low(self): + return S.One + + @property + def high(self): + return self.k + + @property + def set(self): + return set(map(Integer, range(1, self.k + 1))) + + @property + def is_symbolic(self): + return not (self.k.is_number and self.c.is_number and self.delta.is_number) + + def pmf(self, x): + x = sympify(x) + if not (x.is_number or x.is_Symbol or is_random(x)): + raise ValueError("'x' expected as an argument of type 'number', 'Symbol', or " + "'RandomSymbol' not %s" % (type(x))) + + cond1 = Eq(x, 1) & x.is_integer + cond2 = Ge(x, 1) & Le(x, self.k) & x.is_integer + rho = Piecewise((Rational(1, self.k), cond1), (Rational(1, x*(x-1)), cond2), (S.Zero, True)) + + cond1 = Ge(x, 1) & Le(x, round(self.k/self.R)-1) + cond2 = Eq(x, round(self.k/self.R)) + tau = Piecewise((self.R/(self.k * x), cond1), (self.R * log(self.R/self.delta)/self.k, cond2), (S.Zero, True)) + + return (rho + tau)/self.Z + +def RobustSoliton(name, k, delta, c): + r''' + Create a Finite Random Variable of Robust Soliton Distribution + + Parameters + ========== + + k : Positive Integer + Represents the number of input symbols in an LT (Luby Transform) code. + delta : Positive Rational Number + Represents the failure probability. Must be in the interval (0,1). + c : Positive Rational Number + Constant of proportionality. Values close to 1 are recommended + + Examples + ======== + + >>> from sympy.stats import RobustSoliton, density, P, E + >>> robSol = RobustSoliton('robSol', 5, 0.5, 0.01) + >>> density(robSol).dict + {1: 0.204253668152708, 2: 0.490631107897393, 3: 0.165210624506162, 4: 0.0834387731899302, 5: 0.0505633404760675} + >>> density(robSol).set + {1, 2, 3, 4, 5} + + >>> from sympy import Symbol + >>> k = Symbol('k', positive=True, integer=True) + >>> c = Symbol('c', positive=True) + >>> robSol = RobustSoliton('robSol', k, 0.5, c) + >>> density(robSol).dict + Density(RobustSolitonDistribution(k, 0.5, c)) + >>> density(robSol).dict.subs(k, 10).subs(c, 0.03).doit() + {1: 0.116641095387194, 2: 0.467045731687165, 3: 0.159984123349381, 4: 0.0821431680681869, 5: 0.0505765646770100, + 6: 0.0345781523420719, 7: 0.0253132820710503, 8: 0.0194459129233227, 9: 0.0154831166726115, 10: 0.0126733075238887} + + >>> E(robSol.subs(k, 10).subs(c, 0.05)) + 2.91358846104106 + + >>> P(robSol.subs(k, 4).subs(c, 0.1) > 2) + 0.243650614389834 + + Returns + ======= + + RandomSymbol + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Soliton_distribution#Robust_distribution + .. [2] https://www.inference.org.uk/mackay/itprnn/ps/588.596.pdf + .. [3] https://pages.cs.wisc.edu/~suman/courses/740/papers/luby02lt.pdf + + ''' + return rv(name, RobustSolitonDistribution, k, delta, c) diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/joint_rv.py b/.venv/lib/python3.13/site-packages/sympy/stats/joint_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..d147942f08b998e167b246628360fa27fc8ef348 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/joint_rv.py @@ -0,0 +1,426 @@ +""" +Joint Random Variables Module + +See Also +======== +sympy.stats.rv +sympy.stats.frv +sympy.stats.crv +sympy.stats.drv +""" +from math import prod + +from sympy.core.basic import Basic +from sympy.core.function import Lambda +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol) +from sympy.core.sympify import sympify +from sympy.sets.sets import ProductSet +from sympy.tensor.indexed import Indexed +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum, summation +from sympy.core.containers import Tuple +from sympy.integrals.integrals import Integral, integrate +from sympy.matrices import ImmutableMatrix, matrix2numpy, list2numpy +from sympy.stats.crv import SingleContinuousDistribution, SingleContinuousPSpace +from sympy.stats.drv import SingleDiscreteDistribution, SingleDiscretePSpace +from sympy.stats.rv import (ProductPSpace, NamedArgsMixin, Distribution, + ProductDomain, RandomSymbol, random_symbols, + SingleDomain, _symbol_converter) +from sympy.utilities.iterables import iterable +from sympy.utilities.misc import filldedent +from sympy.external import import_module + +# __all__ = ['marginal_distribution'] + +class JointPSpace(ProductPSpace): + """ + Represents a joint probability space. Represented using symbols for + each component and a distribution. + """ + def __new__(cls, sym, dist): + if isinstance(dist, SingleContinuousDistribution): + return SingleContinuousPSpace(sym, dist) + if isinstance(dist, SingleDiscreteDistribution): + return SingleDiscretePSpace(sym, dist) + sym = _symbol_converter(sym) + return Basic.__new__(cls, sym, dist) + + @property + def set(self): + return self.domain.set + + @property + def symbol(self): + return self.args[0] + + @property + def distribution(self): + return self.args[1] + + @property + def value(self): + return JointRandomSymbol(self.symbol, self) + + @property + def component_count(self): + _set = self.distribution.set + if isinstance(_set, ProductSet): + return S(len(_set.args)) + elif isinstance(_set, Product): + return _set.limits[0][-1] + return S.One + + @property + def pdf(self): + sym = [Indexed(self.symbol, i) for i in range(self.component_count)] + return self.distribution(*sym) + + @property + def domain(self): + rvs = random_symbols(self.distribution) + if not rvs: + return SingleDomain(self.symbol, self.distribution.set) + return ProductDomain(*[rv.pspace.domain for rv in rvs]) + + def component_domain(self, index): + return self.set.args[index] + + def marginal_distribution(self, *indices): + count = self.component_count + if count.atoms(Symbol): + raise ValueError("Marginal distributions cannot be computed " + "for symbolic dimensions. It is a work under progress.") + orig = [Indexed(self.symbol, i) for i in range(count)] + all_syms = [Symbol(str(i)) for i in orig] + replace_dict = dict(zip(all_syms, orig)) + sym = tuple(Symbol(str(Indexed(self.symbol, i))) for i in indices) + limits = [[i,] for i in all_syms if i not in sym] + index = 0 + for i in range(count): + if i not in indices: + limits[index].append(self.distribution.set.args[i]) + limits[index] = tuple(limits[index]) + index += 1 + if self.distribution.is_Continuous: + f = Lambda(sym, integrate(self.distribution(*all_syms), *limits)) + elif self.distribution.is_Discrete: + f = Lambda(sym, summation(self.distribution(*all_syms), *limits)) + return f.xreplace(replace_dict) + + def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs): + syms = tuple(self.value[i] for i in range(self.component_count)) + rvs = rvs or syms + if not any(i in rvs for i in syms): + return expr + expr = expr*self.pdf + for rv in rvs: + if isinstance(rv, Indexed): + expr = expr.xreplace({rv: Indexed(str(rv.base), rv.args[1])}) + elif isinstance(rv, RandomSymbol): + expr = expr.xreplace({rv: rv.symbol}) + if self.value in random_symbols(expr): + raise NotImplementedError(filldedent(''' + Expectations of expression with unindexed joint random symbols + cannot be calculated yet.''')) + limits = tuple((Indexed(str(rv.base),rv.args[1]), + self.distribution.set.args[rv.args[1]]) for rv in syms) + return Integral(expr, *limits) + + def where(self, condition): + raise NotImplementedError() + + def compute_density(self, expr): + raise NotImplementedError() + + def sample(self, size=(), library='scipy', seed=None): + """ + Internal sample method + + Returns dictionary mapping RandomSymbol to realization value. + """ + return {RandomSymbol(self.symbol, self): self.distribution.sample(size, + library=library, seed=seed)} + + def probability(self, condition): + raise NotImplementedError() + + +class SampleJointScipy: + """Returns the sample from scipy of the given distribution""" + def __new__(cls, dist, size, seed=None): + return cls._sample_scipy(dist, size, seed) + + @classmethod + def _sample_scipy(cls, dist, size, seed): + """Sample from SciPy.""" + + import numpy + if seed is None or isinstance(seed, int): + rand_state = numpy.random.default_rng(seed=seed) + else: + rand_state = seed + from scipy import stats as scipy_stats + scipy_rv_map = { + 'MultivariateNormalDistribution': lambda dist, size: scipy_stats.multivariate_normal.rvs( + mean=matrix2numpy(dist.mu).flatten(), + cov=matrix2numpy(dist.sigma), size=size, random_state=rand_state), + 'MultivariateBetaDistribution': lambda dist, size: scipy_stats.dirichlet.rvs( + alpha=list2numpy(dist.alpha, float).flatten(), size=size, random_state=rand_state), + 'MultinomialDistribution': lambda dist, size: scipy_stats.multinomial.rvs( + n=int(dist.n), p=list2numpy(dist.p, float).flatten(), size=size, random_state=rand_state) + } + + sample_shape = { + 'MultivariateNormalDistribution': lambda dist: matrix2numpy(dist.mu).flatten().shape, + 'MultivariateBetaDistribution': lambda dist: list2numpy(dist.alpha).flatten().shape, + 'MultinomialDistribution': lambda dist: list2numpy(dist.p).flatten().shape + } + + dist_list = scipy_rv_map.keys() + + if dist.__class__.__name__ not in dist_list: + return None + + samples = scipy_rv_map[dist.__class__.__name__](dist, size) + return samples.reshape(size + sample_shape[dist.__class__.__name__](dist)) + +class SampleJointNumpy: + """Returns the sample from numpy of the given distribution""" + + def __new__(cls, dist, size, seed=None): + return cls._sample_numpy(dist, size, seed) + + @classmethod + def _sample_numpy(cls, dist, size, seed): + """Sample from NumPy.""" + + import numpy + if seed is None or isinstance(seed, int): + rand_state = numpy.random.default_rng(seed=seed) + else: + rand_state = seed + numpy_rv_map = { + 'MultivariateNormalDistribution': lambda dist, size: rand_state.multivariate_normal( + mean=matrix2numpy(dist.mu, float).flatten(), + cov=matrix2numpy(dist.sigma, float), size=size), + 'MultivariateBetaDistribution': lambda dist, size: rand_state.dirichlet( + alpha=list2numpy(dist.alpha, float).flatten(), size=size), + 'MultinomialDistribution': lambda dist, size: rand_state.multinomial( + n=int(dist.n), pvals=list2numpy(dist.p, float).flatten(), size=size) + } + + sample_shape = { + 'MultivariateNormalDistribution': lambda dist: matrix2numpy(dist.mu).flatten().shape, + 'MultivariateBetaDistribution': lambda dist: list2numpy(dist.alpha).flatten().shape, + 'MultinomialDistribution': lambda dist: list2numpy(dist.p).flatten().shape + } + + dist_list = numpy_rv_map.keys() + + if dist.__class__.__name__ not in dist_list: + return None + + samples = numpy_rv_map[dist.__class__.__name__](dist, prod(size)) + return samples.reshape(size + sample_shape[dist.__class__.__name__](dist)) + +class SampleJointPymc: + """Returns the sample from pymc of the given distribution""" + + def __new__(cls, dist, size, seed=None): + return cls._sample_pymc(dist, size, seed) + + @classmethod + def _sample_pymc(cls, dist, size, seed): + """Sample from PyMC.""" + + try: + import pymc + except ImportError: + import pymc3 as pymc + pymc_rv_map = { + 'MultivariateNormalDistribution': lambda dist: + pymc.MvNormal('X', mu=matrix2numpy(dist.mu, float).flatten(), + cov=matrix2numpy(dist.sigma, float), shape=(1, dist.mu.shape[0])), + 'MultivariateBetaDistribution': lambda dist: + pymc.Dirichlet('X', a=list2numpy(dist.alpha, float).flatten()), + 'MultinomialDistribution': lambda dist: + pymc.Multinomial('X', n=int(dist.n), + p=list2numpy(dist.p, float).flatten(), shape=(1, len(dist.p))) + } + + sample_shape = { + 'MultivariateNormalDistribution': lambda dist: matrix2numpy(dist.mu).flatten().shape, + 'MultivariateBetaDistribution': lambda dist: list2numpy(dist.alpha).flatten().shape, + 'MultinomialDistribution': lambda dist: list2numpy(dist.p).flatten().shape + } + + dist_list = pymc_rv_map.keys() + + if dist.__class__.__name__ not in dist_list: + return None + + import logging + logging.getLogger("pymc3").setLevel(logging.ERROR) + with pymc.Model(): + pymc_rv_map[dist.__class__.__name__](dist) + samples = pymc.sample(draws=prod(size), chains=1, progressbar=False, random_seed=seed, return_inferencedata=False, compute_convergence_checks=False)[:]['X'] + return samples.reshape(size + sample_shape[dist.__class__.__name__](dist)) + + +_get_sample_class_jrv = { + 'scipy': SampleJointScipy, + 'pymc3': SampleJointPymc, + 'pymc': SampleJointPymc, + 'numpy': SampleJointNumpy +} + +class JointDistribution(Distribution, NamedArgsMixin): + """ + Represented by the random variables part of the joint distribution. + Contains methods for PDF, CDF, sampling, marginal densities, etc. + """ + + _argnames = ('pdf', ) + + def __new__(cls, *args): + args = list(map(sympify, args)) + for i in range(len(args)): + if isinstance(args[i], list): + args[i] = ImmutableMatrix(args[i]) + return Basic.__new__(cls, *args) + + @property + def domain(self): + return ProductDomain(self.symbols) + + @property + def pdf(self): + return self.density.args[1] + + def cdf(self, other): + if not isinstance(other, dict): + raise ValueError("%s should be of type dict, got %s"%(other, type(other))) + rvs = other.keys() + _set = self.domain.set.sets + expr = self.pdf(tuple(i.args[0] for i in self.symbols)) + for i in range(len(other)): + if rvs[i].is_Continuous: + density = Integral(expr, (rvs[i], _set[i].inf, + other[rvs[i]])) + elif rvs[i].is_Discrete: + density = Sum(expr, (rvs[i], _set[i].inf, + other[rvs[i]])) + return density + + def sample(self, size=(), library='scipy', seed=None): + """ A random realization from the distribution """ + + libraries = ('scipy', 'numpy', 'pymc3', 'pymc') + if library not in libraries: + raise NotImplementedError("Sampling from %s is not supported yet." + % str(library)) + if not import_module(library): + raise ValueError("Failed to import %s" % library) + + samps = _get_sample_class_jrv[library](self, size, seed=seed) + + if samps is not None: + return samps + raise NotImplementedError( + "Sampling for %s is not currently implemented from %s" + % (self.__class__.__name__, library) + ) + + def __call__(self, *args): + return self.pdf(*args) + +class JointRandomSymbol(RandomSymbol): + """ + Representation of random symbols with joint probability distributions + to allow indexing." + """ + def __getitem__(self, key): + if isinstance(self.pspace, JointPSpace): + if (self.pspace.component_count <= key) == True: + raise ValueError("Index keys for %s can only up to %s." % + (self.name, self.pspace.component_count - 1)) + return Indexed(self, key) + + + +class MarginalDistribution(Distribution): + """ + Represents the marginal distribution of a joint probability space. + + Initialised using a probability distribution and random variables(or + their indexed components) which should be a part of the resultant + distribution. + """ + + def __new__(cls, dist, *rvs): + if len(rvs) == 1 and iterable(rvs[0]): + rvs = tuple(rvs[0]) + if not all(isinstance(rv, (Indexed, RandomSymbol)) for rv in rvs): + raise ValueError(filldedent('''Marginal distribution can be + intitialised only in terms of random variables or indexed random + variables''')) + rvs = Tuple.fromiter(rv for rv in rvs) + if not isinstance(dist, JointDistribution) and len(random_symbols(dist)) == 0: + return dist + return Basic.__new__(cls, dist, rvs) + + def check(self): + pass + + @property + def set(self): + rvs = [i for i in self.args[1] if isinstance(i, RandomSymbol)] + return ProductSet(*[rv.pspace.set for rv in rvs]) + + @property + def symbols(self): + rvs = self.args[1] + return {rv.pspace.symbol for rv in rvs} + + def pdf(self, *x): + expr, rvs = self.args[0], self.args[1] + marginalise_out = [i for i in random_symbols(expr) if i not in rvs] + if isinstance(expr, JointDistribution): + count = len(expr.domain.args) + x = Dummy('x', real=True) + syms = tuple(Indexed(x, i) for i in count) + expr = expr.pdf(syms) + else: + syms = tuple(rv.pspace.symbol if isinstance(rv, RandomSymbol) else rv.args[0] for rv in rvs) + return Lambda(syms, self.compute_pdf(expr, marginalise_out))(*x) + + def compute_pdf(self, expr, rvs): + for rv in rvs: + lpdf = 1 + if isinstance(rv, RandomSymbol): + lpdf = rv.pspace.pdf + expr = self.marginalise_out(expr*lpdf, rv) + return expr + + def marginalise_out(self, expr, rv): + from sympy.concrete.summations import Sum + if isinstance(rv, RandomSymbol): + dom = rv.pspace.set + elif isinstance(rv, Indexed): + dom = rv.base.component_domain( + rv.pspace.component_domain(rv.args[1])) + expr = expr.xreplace({rv: rv.pspace.symbol}) + if rv.pspace.is_Continuous: + #TODO: Modify to support integration + #for all kinds of sets. + expr = Integral(expr, (rv.pspace.symbol, dom)) + elif rv.pspace.is_Discrete: + #incorporate this into `Sum`/`summation` + if dom in (S.Integers, S.Naturals, S.Naturals0): + dom = (dom.inf, dom.sup) + expr = Sum(expr, (rv.pspace.symbol, dom)) + return expr + + def __call__(self, *args): + return self.pdf(*args) diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/joint_rv_types.py b/.venv/lib/python3.13/site-packages/sympy/stats/joint_rv_types.py new file mode 100644 index 0000000000000000000000000000000000000000..6cee9f9aa30897593ffb7c7b930a55a38f0c518a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/joint_rv_types.py @@ -0,0 +1,945 @@ +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.function import Lambda +from sympy.core.mul import Mul +from sympy.core.numbers import (Integer, Rational, pi) +from sympy.core.power import Pow +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import (rf, factorial) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.bessel import besselk +from sympy.functions.special.gamma_functions import gamma +from sympy.matrices.dense import (Matrix, ones) +from sympy.sets.fancysets import Range +from sympy.sets.sets import (Intersection, Interval) +from sympy.tensor.indexed import (Indexed, IndexedBase) +from sympy.matrices import ImmutableMatrix, MatrixSymbol +from sympy.matrices.expressions.determinant import det +from sympy.matrices.expressions.matexpr import MatrixElement +from sympy.stats.joint_rv import JointDistribution, JointPSpace, MarginalDistribution +from sympy.stats.rv import _value_check, random_symbols + +__all__ = ['JointRV', +'MultivariateNormal', +'MultivariateLaplace', +'Dirichlet', +'GeneralizedMultivariateLogGamma', +'GeneralizedMultivariateLogGammaOmega', +'Multinomial', +'MultivariateBeta', +'MultivariateEwens', +'MultivariateT', +'NegativeMultinomial', +'NormalGamma' +] + +def multivariate_rv(cls, sym, *args): + args = list(map(sympify, args)) + dist = cls(*args) + args = dist.args + dist.check(*args) + return JointPSpace(sym, dist).value + + +def marginal_distribution(rv, *indices): + """ + Marginal distribution function of a joint random variable. + + Parameters + ========== + + rv : A random variable with a joint probability distribution. + indices : Component indices or the indexed random symbol + for which the joint distribution is to be calculated + + Returns + ======= + + A Lambda expression in `sym`. + + Examples + ======== + + >>> from sympy.stats import MultivariateNormal, marginal_distribution + >>> m = MultivariateNormal('X', [1, 2], [[2, 1], [1, 2]]) + >>> marginal_distribution(m, m[0])(1) + 1/(2*sqrt(pi)) + + """ + indices = list(indices) + for i in range(len(indices)): + if isinstance(indices[i], Indexed): + indices[i] = indices[i].args[1] + prob_space = rv.pspace + if not indices: + raise ValueError( + "At least one component for marginal density is needed.") + if hasattr(prob_space.distribution, '_marginal_distribution'): + return prob_space.distribution._marginal_distribution(indices, rv.symbol) + return prob_space.marginal_distribution(*indices) + + +class JointDistributionHandmade(JointDistribution): + + _argnames = ('pdf',) + is_Continuous = True + + @property + def set(self): + return self.args[1] + + +def JointRV(symbol, pdf, _set=None): + """ + Create a Joint Random Variable where each of its component is continuous, + given the following: + + Parameters + ========== + + symbol : Symbol + Represents name of the random variable. + pdf : A PDF in terms of indexed symbols of the symbol given + as the first argument + + NOTE + ==== + + As of now, the set for each component for a ``JointRV`` is + equal to the set of all integers, which cannot be changed. + + Examples + ======== + + >>> from sympy import exp, pi, Indexed, S + >>> from sympy.stats import density, JointRV + >>> x1, x2 = (Indexed('x', i) for i in (1, 2)) + >>> pdf = exp(-x1**2/2 + x1 - x2**2/2 - S(1)/2)/(2*pi) + >>> N1 = JointRV('x', pdf) #Multivariate Normal distribution + >>> density(N1)(1, 2) + exp(-2)/(2*pi) + + Returns + ======= + + RandomSymbol + + """ + #TODO: Add support for sets provided by the user + symbol = sympify(symbol) + syms = [i for i in pdf.free_symbols if isinstance(i, Indexed) + and i.base == IndexedBase(symbol)] + syms = tuple(sorted(syms, key = lambda index: index.args[1])) + _set = S.Reals**len(syms) + pdf = Lambda(syms, pdf) + dist = JointDistributionHandmade(pdf, _set) + jrv = JointPSpace(symbol, dist).value + rvs = random_symbols(pdf) + if len(rvs) != 0: + dist = MarginalDistribution(dist, (jrv,)) + return JointPSpace(symbol, dist).value + return jrv + +#------------------------------------------------------------------------------- +# Multivariate Normal distribution --------------------------------------------- + +class MultivariateNormalDistribution(JointDistribution): + _argnames = ('mu', 'sigma') + + is_Continuous=True + + @property + def set(self): + k = self.mu.shape[0] + return S.Reals**k + + @staticmethod + def check(mu, sigma): + _value_check(mu.shape[0] == sigma.shape[0], + "Size of the mean vector and covariance matrix are incorrect.") + #check if covariance matrix is positive semi definite or not. + if not isinstance(sigma, MatrixSymbol): + _value_check(sigma.is_positive_semidefinite, + "The covariance matrix must be positive semi definite. ") + + def pdf(self, *args): + mu, sigma = self.mu, self.sigma + k = mu.shape[0] + if len(args) == 1 and args[0].is_Matrix: + args = args[0] + else: + args = ImmutableMatrix(args) + x = args - mu + density = S.One/sqrt((2*pi)**(k)*det(sigma))*exp( + Rational(-1, 2)*x.transpose()*(sigma.inv()*x)) + return MatrixElement(density, 0, 0) + + def _marginal_distribution(self, indices, sym): + sym = ImmutableMatrix([Indexed(sym, i) for i in indices]) + _mu, _sigma = self.mu, self.sigma + k = self.mu.shape[0] + for i in range(k): + if i not in indices: + _mu = _mu.row_del(i) + _sigma = _sigma.col_del(i) + _sigma = _sigma.row_del(i) + return Lambda(tuple(sym), S.One/sqrt((2*pi)**(len(_mu))*det(_sigma))*exp( + Rational(-1, 2)*(_mu - sym).transpose()*(_sigma.inv()*\ + (_mu - sym)))[0]) + +def MultivariateNormal(name, mu, sigma): + r""" + Creates a continuous random variable with Multivariate Normal + Distribution. + + The density of the multivariate normal distribution can be found at [1]. + + Parameters + ========== + + mu : List representing the mean or the mean vector + sigma : Positive semidefinite square matrix + Represents covariance Matrix. + If `\sigma` is noninvertible then only sampling is supported currently + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import MultivariateNormal, density, marginal_distribution + >>> from sympy import symbols, MatrixSymbol + >>> X = MultivariateNormal('X', [3, 4], [[2, 1], [1, 2]]) + >>> y, z = symbols('y z') + >>> density(X)(y, z) + sqrt(3)*exp(-y**2/3 + y*z/3 + 2*y/3 - z**2/3 + 5*z/3 - 13/3)/(6*pi) + >>> density(X)(1, 2) + sqrt(3)*exp(-4/3)/(6*pi) + >>> marginal_distribution(X, X[1])(y) + exp(-(y - 4)**2/4)/(2*sqrt(pi)) + >>> marginal_distribution(X, X[0])(y) + exp(-(y - 3)**2/4)/(2*sqrt(pi)) + + The example below shows that it is also possible to use + symbolic parameters to define the MultivariateNormal class. + + >>> n = symbols('n', integer=True, positive=True) + >>> Sg = MatrixSymbol('Sg', n, n) + >>> mu = MatrixSymbol('mu', n, 1) + >>> obs = MatrixSymbol('obs', n, 1) + >>> X = MultivariateNormal('X', mu, Sg) + + The density of a multivariate normal can be + calculated using a matrix argument, as shown below. + + >>> density(X)(obs) + (exp(((1/2)*mu.T - (1/2)*obs.T)*Sg**(-1)*(-mu + obs))/sqrt((2*pi)**n*Determinant(Sg)))[0, 0] + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Multivariate_normal_distribution + + """ + return multivariate_rv(MultivariateNormalDistribution, name, mu, sigma) + +#------------------------------------------------------------------------------- +# Multivariate Laplace distribution -------------------------------------------- + +class MultivariateLaplaceDistribution(JointDistribution): + _argnames = ('mu', 'sigma') + is_Continuous=True + + @property + def set(self): + k = self.mu.shape[0] + return S.Reals**k + + @staticmethod + def check(mu, sigma): + _value_check(mu.shape[0] == sigma.shape[0], + "Size of the mean vector and covariance matrix are incorrect.") + # check if covariance matrix is positive definite or not. + if not isinstance(sigma, MatrixSymbol): + _value_check(sigma.is_positive_definite, + "The covariance matrix must be positive definite. ") + + def pdf(self, *args): + mu, sigma = self.mu, self.sigma + mu_T = mu.transpose() + k = S(mu.shape[0]) + sigma_inv = sigma.inv() + args = ImmutableMatrix(args) + args_T = args.transpose() + x = (mu_T*sigma_inv*mu)[0] + y = (args_T*sigma_inv*args)[0] + v = 1 - k/2 + return (2 * (y/(2 + x))**(v/2) * besselk(v, sqrt((2 + x)*y)) * + exp((args_T * sigma_inv * mu)[0]) / + ((2 * pi)**(k/2) * sqrt(det(sigma)))) + + +def MultivariateLaplace(name, mu, sigma): + """ + Creates a continuous random variable with Multivariate Laplace + Distribution. + + The density of the multivariate Laplace distribution can be found at [1]. + + Parameters + ========== + + mu : List representing the mean or the mean vector + sigma : Positive definite square matrix + Represents covariance Matrix + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import MultivariateLaplace, density + >>> from sympy import symbols + >>> y, z = symbols('y z') + >>> X = MultivariateLaplace('X', [2, 4], [[3, 1], [1, 3]]) + >>> density(X)(y, z) + sqrt(2)*exp(y/4 + 5*z/4)*besselk(0, sqrt(15*y*(3*y/8 - z/8)/2 + 15*z*(-y/8 + 3*z/8)/2))/(4*pi) + >>> density(X)(1, 2) + sqrt(2)*exp(11/4)*besselk(0, sqrt(165)/4)/(4*pi) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Multivariate_Laplace_distribution + + """ + return multivariate_rv(MultivariateLaplaceDistribution, name, mu, sigma) + +#------------------------------------------------------------------------------- +# Multivariate StudentT distribution ------------------------------------------- + +class MultivariateTDistribution(JointDistribution): + _argnames = ('mu', 'shape_mat', 'dof') + is_Continuous=True + + @property + def set(self): + k = self.mu.shape[0] + return S.Reals**k + + @staticmethod + def check(mu, sigma, v): + _value_check(mu.shape[0] == sigma.shape[0], + "Size of the location vector and shape matrix are incorrect.") + # check if covariance matrix is positive definite or not. + if not isinstance(sigma, MatrixSymbol): + _value_check(sigma.is_positive_definite, + "The shape matrix must be positive definite. ") + + def pdf(self, *args): + mu, sigma = self.mu, self.shape_mat + v = S(self.dof) + k = S(mu.shape[0]) + sigma_inv = sigma.inv() + args = ImmutableMatrix(args) + x = args - mu + return gamma((k + v)/2)/(gamma(v/2)*(v*pi)**(k/2)*sqrt(det(sigma)))\ + *(1 + 1/v*(x.transpose()*sigma_inv*x)[0])**((-v - k)/2) + +def MultivariateT(syms, mu, sigma, v): + """ + Creates a joint random variable with multivariate T-distribution. + + Parameters + ========== + + syms : A symbol/str + For identifying the random variable. + mu : A list/matrix + Representing the location vector + sigma : The shape matrix for the distribution + + Examples + ======== + + >>> from sympy.stats import density, MultivariateT + >>> from sympy import Symbol + + >>> x = Symbol("x") + >>> X = MultivariateT("x", [1, 1], [[1, 0], [0, 1]], 2) + + >>> density(X)(1, 2) + 2/(9*pi) + + Returns + ======= + + RandomSymbol + + """ + return multivariate_rv(MultivariateTDistribution, syms, mu, sigma, v) + + +#------------------------------------------------------------------------------- +# Multivariate Normal Gamma distribution --------------------------------------- + +class NormalGammaDistribution(JointDistribution): + + _argnames = ('mu', 'lamda', 'alpha', 'beta') + is_Continuous=True + + @staticmethod + def check(mu, lamda, alpha, beta): + _value_check(mu.is_real, "Location must be real.") + _value_check(lamda > 0, "Lambda must be positive") + _value_check(alpha > 0, "alpha must be positive") + _value_check(beta > 0, "beta must be positive") + + @property + def set(self): + return S.Reals*Interval(0, S.Infinity) + + def pdf(self, x, tau): + beta, alpha, lamda = self.beta, self.alpha, self.lamda + mu = self.mu + + return beta**alpha*sqrt(lamda)/(gamma(alpha)*sqrt(2*pi))*\ + tau**(alpha - S.Half)*exp(-1*beta*tau)*\ + exp(-1*(lamda*tau*(x - mu)**2)/S(2)) + + def _marginal_distribution(self, indices, *sym): + if len(indices) == 2: + return self.pdf(*sym) + if indices[0] == 0: + #For marginal over `x`, return non-standardized Student-T's + #distribution + x = sym[0] + v, mu, sigma = self.alpha - S.Half, self.mu, \ + S(self.beta)/(self.lamda * self.alpha) + return Lambda(sym, gamma((v + 1)/2)/(gamma(v/2)*sqrt(pi*v)*sigma)*\ + (1 + 1/v*((x - mu)/sigma)**2)**((-v -1)/2)) + #For marginal over `tau`, return Gamma distribution as per construction + from sympy.stats.crv_types import GammaDistribution + return Lambda(sym, GammaDistribution(self.alpha, self.beta)(sym[0])) + +def NormalGamma(sym, mu, lamda, alpha, beta): + """ + Creates a bivariate joint random variable with multivariate Normal gamma + distribution. + + Parameters + ========== + + sym : A symbol/str + For identifying the random variable. + mu : A real number + The mean of the normal distribution + lamda : A positive integer + Parameter of joint distribution + alpha : A positive integer + Parameter of joint distribution + beta : A positive integer + Parameter of joint distribution + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import density, NormalGamma + >>> from sympy import symbols + + >>> X = NormalGamma('x', 0, 1, 2, 3) + >>> y, z = symbols('y z') + + >>> density(X)(y, z) + 9*sqrt(2)*z**(3/2)*exp(-3*z)*exp(-y**2*z/2)/(2*sqrt(pi)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Normal-gamma_distribution + + """ + return multivariate_rv(NormalGammaDistribution, sym, mu, lamda, alpha, beta) + +#------------------------------------------------------------------------------- +# Multivariate Beta/Dirichlet distribution ------------------------------------- + +class MultivariateBetaDistribution(JointDistribution): + + _argnames = ('alpha',) + is_Continuous = True + + @staticmethod + def check(alpha): + _value_check(len(alpha) >= 2, "At least two categories should be passed.") + for a_k in alpha: + _value_check((a_k > 0) != False, "Each concentration parameter" + " should be positive.") + + @property + def set(self): + k = len(self.alpha) + return Interval(0, 1)**k + + def pdf(self, *syms): + alpha = self.alpha + B = Mul.fromiter(map(gamma, alpha))/gamma(Add(*alpha)) + return Mul.fromiter(sym**(a_k - 1) for a_k, sym in zip(alpha, syms))/B + +def MultivariateBeta(syms, *alpha): + """ + Creates a continuous random variable with Dirichlet/Multivariate Beta + Distribution. + + The density of the Dirichlet distribution can be found at [1]. + + Parameters + ========== + + alpha : Positive real numbers + Signifies concentration numbers. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import density, MultivariateBeta, marginal_distribution + >>> from sympy import Symbol + >>> a1 = Symbol('a1', positive=True) + >>> a2 = Symbol('a2', positive=True) + >>> B = MultivariateBeta('B', [a1, a2]) + >>> C = MultivariateBeta('C', a1, a2) + >>> x = Symbol('x') + >>> y = Symbol('y') + >>> density(B)(x, y) + x**(a1 - 1)*y**(a2 - 1)*gamma(a1 + a2)/(gamma(a1)*gamma(a2)) + >>> marginal_distribution(C, C[0])(x) + x**(a1 - 1)*gamma(a1 + a2)/(a2*gamma(a1)*gamma(a2)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Dirichlet_distribution + .. [2] https://mathworld.wolfram.com/DirichletDistribution.html + + """ + if not isinstance(alpha[0], list): + alpha = (list(alpha),) + return multivariate_rv(MultivariateBetaDistribution, syms, alpha[0]) + +Dirichlet = MultivariateBeta + +#------------------------------------------------------------------------------- +# Multivariate Ewens distribution ---------------------------------------------- + +class MultivariateEwensDistribution(JointDistribution): + + _argnames = ('n', 'theta') + is_Discrete = True + is_Continuous = False + + @staticmethod + def check(n, theta): + _value_check((n > 0), + "sample size should be positive integer.") + _value_check(theta.is_positive, "mutation rate should be positive.") + + @property + def set(self): + if not isinstance(self.n, Integer): + i = Symbol('i', integer=True, positive=True) + return Product(Intersection(S.Naturals0, Interval(0, self.n//i)), + (i, 1, self.n)) + prod_set = Range(0, self.n + 1) + for i in range(2, self.n + 1): + prod_set *= Range(0, self.n//i + 1) + return prod_set.flatten() + + def pdf(self, *syms): + n, theta = self.n, self.theta + condi = isinstance(self.n, Integer) + if not (isinstance(syms[0], IndexedBase) or condi): + raise ValueError("Please use IndexedBase object for syms as " + "the dimension is symbolic") + term_1 = factorial(n)/rf(theta, n) + if condi: + term_2 = Mul.fromiter(theta**syms[j]/((j+1)**syms[j]*factorial(syms[j])) + for j in range(n)) + cond = Eq(sum((k + 1)*syms[k] for k in range(n)), n) + return Piecewise((term_1 * term_2, cond), (0, True)) + syms = syms[0] + j, k = symbols('j, k', positive=True, integer=True) + term_2 = Product(theta**syms[j]/((j+1)**syms[j]*factorial(syms[j])), + (j, 0, n - 1)) + cond = Eq(Sum((k + 1)*syms[k], (k, 0, n - 1)), n) + return Piecewise((term_1 * term_2, cond), (0, True)) + + +def MultivariateEwens(syms, n, theta): + """ + Creates a discrete random variable with Multivariate Ewens + Distribution. + + The density of the said distribution can be found at [1]. + + Parameters + ========== + + n : Positive integer + Size of the sample or the integer whose partitions are considered + theta : Positive real number + Denotes Mutation rate + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import density, marginal_distribution, MultivariateEwens + >>> from sympy import Symbol + >>> a1 = Symbol('a1', positive=True) + >>> a2 = Symbol('a2', positive=True) + >>> ed = MultivariateEwens('E', 2, 1) + >>> density(ed)(a1, a2) + Piecewise((1/(2**a2*factorial(a1)*factorial(a2)), Eq(a1 + 2*a2, 2)), (0, True)) + >>> marginal_distribution(ed, ed[0])(a1) + Piecewise((1/factorial(a1), Eq(a1, 2)), (0, True)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Ewens%27s_sampling_formula + .. [2] https://www.jstor.org/stable/24780825 + """ + return multivariate_rv(MultivariateEwensDistribution, syms, n, theta) + +#------------------------------------------------------------------------------- +# Generalized Multivariate Log Gamma distribution ------------------------------ + +class GeneralizedMultivariateLogGammaDistribution(JointDistribution): + + _argnames = ('delta', 'v', 'lamda', 'mu') + is_Continuous=True + + def check(self, delta, v, l, mu): + _value_check((delta >= 0, delta <= 1), "delta must be in range [0, 1].") + _value_check((v > 0), "v must be positive") + for lk in l: + _value_check((lk > 0), "lamda must be a positive vector.") + for muk in mu: + _value_check((muk > 0), "mu must be a positive vector.") + _value_check(len(l) > 1,"the distribution should have at least" + " two random variables.") + + @property + def set(self): + return S.Reals**len(self.lamda) + + def pdf(self, *y): + d, v, l, mu = self.delta, self.v, self.lamda, self.mu + n = Symbol('n', negative=False, integer=True) + k = len(l) + sterm1 = Pow((1 - d), n)/\ + ((gamma(v + n)**(k - 1))*gamma(v)*gamma(n + 1)) + sterm2 = Mul.fromiter(mui*li**(-v - n) for mui, li in zip(mu, l)) + term1 = sterm1 * sterm2 + sterm3 = (v + n) * sum(mui * yi for mui, yi in zip(mu, y)) + sterm4 = sum(exp(mui * yi)/li for (mui, yi, li) in zip(mu, y, l)) + term2 = exp(sterm3 - sterm4) + return Pow(d, v) * Sum(term1 * term2, (n, 0, S.Infinity)) + +def GeneralizedMultivariateLogGamma(syms, delta, v, lamda, mu): + """ + Creates a joint random variable with generalized multivariate log gamma + distribution. + + The joint pdf can be found at [1]. + + Parameters + ========== + + syms : list/tuple/set of symbols for identifying each component + delta : A constant in range $[0, 1]$ + v : Positive real number + lamda : List of positive real numbers + mu : List of positive real numbers + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import density + >>> from sympy.stats.joint_rv_types import GeneralizedMultivariateLogGamma + >>> from sympy import symbols, S + >>> v = 1 + >>> l, mu = [1, 1, 1], [1, 1, 1] + >>> d = S.Half + >>> y = symbols('y_1:4', positive=True) + >>> Gd = GeneralizedMultivariateLogGamma('G', d, v, l, mu) + >>> density(Gd)(y[0], y[1], y[2]) + Sum(exp((n + 1)*(y_1 + y_2 + y_3) - exp(y_1) - exp(y_2) - + exp(y_3))/(2**n*gamma(n + 1)**3), (n, 0, oo))/2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Generalized_multivariate_log-gamma_distribution + .. [2] https://www.researchgate.net/publication/234137346_On_a_multivariate_log-gamma_distribution_and_the_use_of_the_distribution_in_the_Bayesian_analysis + + Note + ==== + + If the GeneralizedMultivariateLogGamma is too long to type use, + + >>> from sympy.stats.joint_rv_types import GeneralizedMultivariateLogGamma as GMVLG + >>> Gd = GMVLG('G', d, v, l, mu) + + If you want to pass the matrix omega instead of the constant delta, then use + ``GeneralizedMultivariateLogGammaOmega``. + + """ + return multivariate_rv(GeneralizedMultivariateLogGammaDistribution, + syms, delta, v, lamda, mu) + +def GeneralizedMultivariateLogGammaOmega(syms, omega, v, lamda, mu): + """ + Extends GeneralizedMultivariateLogGamma. + + Parameters + ========== + + syms : list/tuple/set of symbols + For identifying each component + omega : A square matrix + Every element of square matrix must be absolute value of + square root of correlation coefficient + v : Positive real number + lamda : List of positive real numbers + mu : List of positive real numbers + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import density + >>> from sympy.stats.joint_rv_types import GeneralizedMultivariateLogGammaOmega + >>> from sympy import Matrix, symbols, S + >>> omega = Matrix([[1, S.Half, S.Half], [S.Half, 1, S.Half], [S.Half, S.Half, 1]]) + >>> v = 1 + >>> l, mu = [1, 1, 1], [1, 1, 1] + >>> G = GeneralizedMultivariateLogGammaOmega('G', omega, v, l, mu) + >>> y = symbols('y_1:4', positive=True) + >>> density(G)(y[0], y[1], y[2]) + sqrt(2)*Sum((1 - sqrt(2)/2)**n*exp((n + 1)*(y_1 + y_2 + y_3) - exp(y_1) - + exp(y_2) - exp(y_3))/gamma(n + 1)**3, (n, 0, oo))/2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Generalized_multivariate_log-gamma_distribution + .. [2] https://www.researchgate.net/publication/234137346_On_a_multivariate_log-gamma_distribution_and_the_use_of_the_distribution_in_the_Bayesian_analysis + + Notes + ===== + + If the GeneralizedMultivariateLogGammaOmega is too long to type use, + + >>> from sympy.stats.joint_rv_types import GeneralizedMultivariateLogGammaOmega as GMVLGO + >>> G = GMVLGO('G', omega, v, l, mu) + + """ + _value_check((omega.is_square, isinstance(omega, Matrix)), "omega must be a" + " square matrix") + for val in omega.values(): + _value_check((val >= 0, val <= 1), + "all values in matrix must be between 0 and 1(both inclusive).") + _value_check(omega.diagonal().equals(ones(1, omega.shape[0])), + "all the elements of diagonal should be 1.") + _value_check((omega.shape[0] == len(lamda), len(lamda) == len(mu)), + "lamda, mu should be of same length and omega should " + " be of shape (length of lamda, length of mu)") + _value_check(len(lamda) > 1,"the distribution should have at least" + " two random variables.") + delta = Pow(Rational(omega.det()), Rational(1, len(lamda) - 1)) + return GeneralizedMultivariateLogGamma(syms, delta, v, lamda, mu) + + +#------------------------------------------------------------------------------- +# Multinomial distribution ----------------------------------------------------- + +class MultinomialDistribution(JointDistribution): + + _argnames = ('n', 'p') + is_Continuous=False + is_Discrete = True + + @staticmethod + def check(n, p): + _value_check(n > 0, + "number of trials must be a positive integer") + for p_k in p: + _value_check((p_k >= 0, p_k <= 1), + "probability must be in range [0, 1]") + _value_check(Eq(sum(p), 1), + "probabilities must sum to 1") + + @property + def set(self): + return Intersection(S.Naturals0, Interval(0, self.n))**len(self.p) + + def pdf(self, *x): + n, p = self.n, self.p + term_1 = factorial(n)/Mul.fromiter(factorial(x_k) for x_k in x) + term_2 = Mul.fromiter(p_k**x_k for p_k, x_k in zip(p, x)) + return Piecewise((term_1 * term_2, Eq(sum(x), n)), (0, True)) + +def Multinomial(syms, n, *p): + """ + Creates a discrete random variable with Multinomial Distribution. + + The density of the said distribution can be found at [1]. + + Parameters + ========== + + n : Positive integer + Represents number of trials + p : List of event probabilities + Must be in the range of $[0, 1]$. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import density, Multinomial, marginal_distribution + >>> from sympy import symbols + >>> x1, x2, x3 = symbols('x1, x2, x3', nonnegative=True, integer=True) + >>> p1, p2, p3 = symbols('p1, p2, p3', positive=True) + >>> M = Multinomial('M', 3, p1, p2, p3) + >>> density(M)(x1, x2, x3) + Piecewise((6*p1**x1*p2**x2*p3**x3/(factorial(x1)*factorial(x2)*factorial(x3)), + Eq(x1 + x2 + x3, 3)), (0, True)) + >>> marginal_distribution(M, M[0])(x1).subs(x1, 1) + 3*p1*p2**2 + 6*p1*p2*p3 + 3*p1*p3**2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Multinomial_distribution + .. [2] https://mathworld.wolfram.com/MultinomialDistribution.html + + """ + if not isinstance(p[0], list): + p = (list(p), ) + return multivariate_rv(MultinomialDistribution, syms, n, p[0]) + +#------------------------------------------------------------------------------- +# Negative Multinomial Distribution -------------------------------------------- + +class NegativeMultinomialDistribution(JointDistribution): + + _argnames = ('k0', 'p') + is_Continuous=False + is_Discrete = True + + @staticmethod + def check(k0, p): + _value_check(k0 > 0, + "number of failures must be a positive integer") + for p_k in p: + _value_check((p_k >= 0, p_k <= 1), + "probability must be in range [0, 1].") + _value_check(sum(p) <= 1, + "success probabilities must not be greater than 1.") + + @property + def set(self): + return Range(0, S.Infinity)**len(self.p) + + def pdf(self, *k): + k0, p = self.k0, self.p + term_1 = (gamma(k0 + sum(k))*(1 - sum(p))**k0)/gamma(k0) + term_2 = Mul.fromiter(pi**ki/factorial(ki) for pi, ki in zip(p, k)) + return term_1 * term_2 + +def NegativeMultinomial(syms, k0, *p): + """ + Creates a discrete random variable with Negative Multinomial Distribution. + + The density of the said distribution can be found at [1]. + + Parameters + ========== + + k0 : positive integer + Represents number of failures before the experiment is stopped + p : List of event probabilities + Must be in the range of $[0, 1]$ + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import density, NegativeMultinomial, marginal_distribution + >>> from sympy import symbols + >>> x1, x2, x3 = symbols('x1, x2, x3', nonnegative=True, integer=True) + >>> p1, p2, p3 = symbols('p1, p2, p3', positive=True) + >>> N = NegativeMultinomial('M', 3, p1, p2, p3) + >>> N_c = NegativeMultinomial('M', 3, 0.1, 0.1, 0.1) + >>> density(N)(x1, x2, x3) + p1**x1*p2**x2*p3**x3*(-p1 - p2 - p3 + 1)**3*gamma(x1 + x2 + + x3 + 3)/(2*factorial(x1)*factorial(x2)*factorial(x3)) + >>> marginal_distribution(N_c, N_c[0])(1).evalf().round(2) + 0.25 + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Negative_multinomial_distribution + .. [2] https://mathworld.wolfram.com/NegativeBinomialDistribution.html + + """ + if not isinstance(p[0], list): + p = (list(p), ) + return multivariate_rv(NegativeMultinomialDistribution, syms, k0, p[0]) diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/matrix_distributions.py b/.venv/lib/python3.13/site-packages/sympy/stats/matrix_distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..9a43c0226bc25702211a910ebbe30e280ad0cf50 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/matrix_distributions.py @@ -0,0 +1,610 @@ +from math import prod + +from sympy.core.basic import Basic +from sympy.core.numbers import pi +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp +from sympy.functions.special.gamma_functions import multigamma +from sympy.core.sympify import sympify, _sympify +from sympy.matrices import (ImmutableMatrix, Inverse, Trace, Determinant, + MatrixSymbol, MatrixBase, Transpose, MatrixSet, + matrix2numpy) +from sympy.stats.rv import (_value_check, RandomMatrixSymbol, NamedArgsMixin, PSpace, + _symbol_converter, MatrixDomain, Distribution) +from sympy.external import import_module + + +################################################################################ +#------------------------Matrix Probability Space------------------------------# +################################################################################ +class MatrixPSpace(PSpace): + """ + Represents probability space for + Matrix Distributions. + """ + def __new__(cls, sym, distribution, dim_n, dim_m): + sym = _symbol_converter(sym) + dim_n, dim_m = _sympify(dim_n), _sympify(dim_m) + if not (dim_n.is_integer and dim_m.is_integer): + raise ValueError("Dimensions should be integers") + return Basic.__new__(cls, sym, distribution, dim_n, dim_m) + + distribution = property(lambda self: self.args[1]) + symbol = property(lambda self: self.args[0]) + + @property + def domain(self): + return MatrixDomain(self.symbol, self.distribution.set) + + @property + def value(self): + return RandomMatrixSymbol(self.symbol, self.args[2], self.args[3], self) + + @property + def values(self): + return {self.value} + + def compute_density(self, expr, *args): + rms = expr.atoms(RandomMatrixSymbol) + if len(rms) > 1 or (not isinstance(expr, RandomMatrixSymbol)): + raise NotImplementedError("Currently, no algorithm has been " + "implemented to handle general expressions containing " + "multiple matrix distributions.") + return self.distribution.pdf(expr) + + def sample(self, size=(), library='scipy', seed=None): + """ + Internal sample method + + Returns dictionary mapping RandomMatrixSymbol to realization value. + """ + return {self.value: self.distribution.sample(size, library=library, seed=seed)} + + +def rv(symbol, cls, args): + args = list(map(sympify, args)) + dist = cls(*args) + dist.check(*args) + dim = dist.dimension + pspace = MatrixPSpace(symbol, dist, dim[0], dim[1]) + return pspace.value + + +class SampleMatrixScipy: + """Returns the sample from scipy of the given distribution""" + def __new__(cls, dist, size, seed=None): + return cls._sample_scipy(dist, size, seed) + + @classmethod + def _sample_scipy(cls, dist, size, seed): + """Sample from SciPy.""" + + from scipy import stats as scipy_stats + import numpy + scipy_rv_map = { + 'WishartDistribution': lambda dist, size, rand_state: scipy_stats.wishart.rvs( + df=int(dist.n), scale=matrix2numpy(dist.scale_matrix, float), size=size), + 'MatrixNormalDistribution': lambda dist, size, rand_state: scipy_stats.matrix_normal.rvs( + mean=matrix2numpy(dist.location_matrix, float), + rowcov=matrix2numpy(dist.scale_matrix_1, float), + colcov=matrix2numpy(dist.scale_matrix_2, float), size=size, random_state=rand_state) + } + + sample_shape = { + 'WishartDistribution': lambda dist: dist.scale_matrix.shape, + 'MatrixNormalDistribution' : lambda dist: dist.location_matrix.shape + } + + dist_list = scipy_rv_map.keys() + + if dist.__class__.__name__ not in dist_list: + return None + + if seed is None or isinstance(seed, int): + rand_state = numpy.random.default_rng(seed=seed) + else: + rand_state = seed + samp = scipy_rv_map[dist.__class__.__name__](dist, prod(size), rand_state) + return samp.reshape(size + sample_shape[dist.__class__.__name__](dist)) + + +class SampleMatrixNumpy: + """Returns the sample from numpy of the given distribution""" + + ### TODO: Add tests after adding matrix distributions in numpy_rv_map + def __new__(cls, dist, size, seed=None): + return cls._sample_numpy(dist, size, seed) + + @classmethod + def _sample_numpy(cls, dist, size, seed): + """Sample from NumPy.""" + + numpy_rv_map = { + } + + sample_shape = { + } + + dist_list = numpy_rv_map.keys() + + if dist.__class__.__name__ not in dist_list: + return None + + import numpy + if seed is None or isinstance(seed, int): + rand_state = numpy.random.default_rng(seed=seed) + else: + rand_state = seed + samp = numpy_rv_map[dist.__class__.__name__](dist, prod(size), rand_state) + return samp.reshape(size + sample_shape[dist.__class__.__name__](dist)) + + +class SampleMatrixPymc: + """Returns the sample from pymc of the given distribution""" + + def __new__(cls, dist, size, seed=None): + return cls._sample_pymc(dist, size, seed) + + @classmethod + def _sample_pymc(cls, dist, size, seed): + """Sample from PyMC.""" + + try: + import pymc + except ImportError: + import pymc3 as pymc + pymc_rv_map = { + 'MatrixNormalDistribution': lambda dist: pymc.MatrixNormal('X', + mu=matrix2numpy(dist.location_matrix, float), + rowcov=matrix2numpy(dist.scale_matrix_1, float), + colcov=matrix2numpy(dist.scale_matrix_2, float), + shape=dist.location_matrix.shape), + 'WishartDistribution': lambda dist: pymc.WishartBartlett('X', + nu=int(dist.n), S=matrix2numpy(dist.scale_matrix, float)) + } + + sample_shape = { + 'WishartDistribution': lambda dist: dist.scale_matrix.shape, + 'MatrixNormalDistribution' : lambda dist: dist.location_matrix.shape + } + + dist_list = pymc_rv_map.keys() + + if dist.__class__.__name__ not in dist_list: + return None + import logging + logging.getLogger("pymc").setLevel(logging.ERROR) + with pymc.Model(): + pymc_rv_map[dist.__class__.__name__](dist) + samps = pymc.sample(draws=prod(size), chains=1, progressbar=False, random_seed=seed, return_inferencedata=False, compute_convergence_checks=False)['X'] + return samps.reshape(size + sample_shape[dist.__class__.__name__](dist)) + +_get_sample_class_matrixrv = { + 'scipy': SampleMatrixScipy, + 'pymc3': SampleMatrixPymc, + 'pymc': SampleMatrixPymc, + 'numpy': SampleMatrixNumpy +} + +################################################################################ +#-------------------------Matrix Distribution----------------------------------# +################################################################################ + +class MatrixDistribution(Distribution, NamedArgsMixin): + """ + Abstract class for Matrix Distribution. + """ + def __new__(cls, *args): + args = [ImmutableMatrix(arg) if isinstance(arg, list) + else _sympify(arg) for arg in args] + return Basic.__new__(cls, *args) + + @staticmethod + def check(*args): + pass + + def __call__(self, expr): + if isinstance(expr, list): + expr = ImmutableMatrix(expr) + return self.pdf(expr) + + def sample(self, size=(), library='scipy', seed=None): + """ + Internal sample method + + Returns dictionary mapping RandomSymbol to realization value. + """ + + libraries = ['scipy', 'numpy', 'pymc3', 'pymc'] + if library not in libraries: + raise NotImplementedError("Sampling from %s is not supported yet." + % str(library)) + if not import_module(library): + raise ValueError("Failed to import %s" % library) + + samps = _get_sample_class_matrixrv[library](self, size, seed) + + if samps is not None: + return samps + raise NotImplementedError( + "Sampling for %s is not currently implemented from %s" + % (self.__class__.__name__, library) + ) + +################################################################################ +#------------------------Matrix Distribution Types-----------------------------# +################################################################################ + +#------------------------------------------------------------------------------- +# Matrix Gamma distribution ---------------------------------------------------- + +class MatrixGammaDistribution(MatrixDistribution): + + _argnames = ('alpha', 'beta', 'scale_matrix') + + @staticmethod + def check(alpha, beta, scale_matrix): + if not isinstance(scale_matrix, MatrixSymbol): + _value_check(scale_matrix.is_positive_definite, "The shape " + "matrix must be positive definite.") + _value_check(scale_matrix.is_square, "Should " + "be square matrix") + _value_check(alpha.is_positive, "Shape parameter should be positive.") + _value_check(beta.is_positive, "Scale parameter should be positive.") + + @property + def set(self): + k = self.scale_matrix.shape[0] + return MatrixSet(k, k, S.Reals) + + @property + def dimension(self): + return self.scale_matrix.shape + + def pdf(self, x): + alpha, beta, scale_matrix = self.alpha, self.beta, self.scale_matrix + p = scale_matrix.shape[0] + if isinstance(x, list): + x = ImmutableMatrix(x) + if not isinstance(x, (MatrixBase, MatrixSymbol)): + raise ValueError("%s should be an isinstance of Matrix " + "or MatrixSymbol" % str(x)) + sigma_inv_x = - Inverse(scale_matrix)*x / beta + term1 = exp(Trace(sigma_inv_x))/((beta**(p*alpha)) * multigamma(alpha, p)) + term2 = (Determinant(scale_matrix))**(-alpha) + term3 = (Determinant(x))**(alpha - S(p + 1)/2) + return term1 * term2 * term3 + +def MatrixGamma(symbol, alpha, beta, scale_matrix): + """ + Creates a random variable with Matrix Gamma Distribution. + + The density of the said distribution can be found at [1]. + + Parameters + ========== + + alpha: Positive Real number + Shape Parameter + beta: Positive Real number + Scale Parameter + scale_matrix: Positive definite real square matrix + Scale Matrix + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import density, MatrixGamma + >>> from sympy import MatrixSymbol, symbols + >>> a, b = symbols('a b', positive=True) + >>> M = MatrixGamma('M', a, b, [[2, 1], [1, 2]]) + >>> X = MatrixSymbol('X', 2, 2) + >>> density(M)(X).doit() + exp(Trace(Matrix([ + [-2/3, 1/3], + [ 1/3, -2/3]])*X)/b)*Determinant(X)**(a - 3/2)/(3**a*sqrt(pi)*b**(2*a)*gamma(a)*gamma(a - 1/2)) + >>> density(M)([[1, 0], [0, 1]]).doit() + exp(-4/(3*b))/(3**a*sqrt(pi)*b**(2*a)*gamma(a)*gamma(a - 1/2)) + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Matrix_gamma_distribution + + """ + if isinstance(scale_matrix, list): + scale_matrix = ImmutableMatrix(scale_matrix) + return rv(symbol, MatrixGammaDistribution, (alpha, beta, scale_matrix)) + +#------------------------------------------------------------------------------- +# Wishart Distribution --------------------------------------------------------- + +class WishartDistribution(MatrixDistribution): + + _argnames = ('n', 'scale_matrix') + + @staticmethod + def check(n, scale_matrix): + if not isinstance(scale_matrix, MatrixSymbol): + _value_check(scale_matrix.is_positive_definite, "The shape " + "matrix must be positive definite.") + _value_check(scale_matrix.is_square, "Should " + "be square matrix") + _value_check(n.is_positive, "Shape parameter should be positive.") + + @property + def set(self): + k = self.scale_matrix.shape[0] + return MatrixSet(k, k, S.Reals) + + @property + def dimension(self): + return self.scale_matrix.shape + + def pdf(self, x): + n, scale_matrix = self.n, self.scale_matrix + p = scale_matrix.shape[0] + if isinstance(x, list): + x = ImmutableMatrix(x) + if not isinstance(x, (MatrixBase, MatrixSymbol)): + raise ValueError("%s should be an isinstance of Matrix " + "or MatrixSymbol" % str(x)) + sigma_inv_x = - Inverse(scale_matrix)*x / S(2) + term1 = exp(Trace(sigma_inv_x))/((2**(p*n/S(2))) * multigamma(n/S(2), p)) + term2 = (Determinant(scale_matrix))**(-n/S(2)) + term3 = (Determinant(x))**(S(n - p - 1)/2) + return term1 * term2 * term3 + +def Wishart(symbol, n, scale_matrix): + """ + Creates a random variable with Wishart Distribution. + + The density of the said distribution can be found at [1]. + + Parameters + ========== + + n: Positive Real number + Represents degrees of freedom + scale_matrix: Positive definite real square matrix + Scale Matrix + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import density, Wishart + >>> from sympy import MatrixSymbol, symbols + >>> n = symbols('n', positive=True) + >>> W = Wishart('W', n, [[2, 1], [1, 2]]) + >>> X = MatrixSymbol('X', 2, 2) + >>> density(W)(X).doit() + exp(Trace(Matrix([ + [-1/3, 1/6], + [ 1/6, -1/3]])*X))*Determinant(X)**(n/2 - 3/2)/(2**n*3**(n/2)*sqrt(pi)*gamma(n/2)*gamma(n/2 - 1/2)) + >>> density(W)([[1, 0], [0, 1]]).doit() + exp(-2/3)/(2**n*3**(n/2)*sqrt(pi)*gamma(n/2)*gamma(n/2 - 1/2)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Wishart_distribution + + """ + if isinstance(scale_matrix, list): + scale_matrix = ImmutableMatrix(scale_matrix) + return rv(symbol, WishartDistribution, (n, scale_matrix)) + +#------------------------------------------------------------------------------- +# Matrix Normal distribution --------------------------------------------------- + +class MatrixNormalDistribution(MatrixDistribution): + + _argnames = ('location_matrix', 'scale_matrix_1', 'scale_matrix_2') + + @staticmethod + def check(location_matrix, scale_matrix_1, scale_matrix_2): + if not isinstance(scale_matrix_1, MatrixSymbol): + _value_check(scale_matrix_1.is_positive_definite, "The shape " + "matrix must be positive definite.") + if not isinstance(scale_matrix_2, MatrixSymbol): + _value_check(scale_matrix_2.is_positive_definite, "The shape " + "matrix must be positive definite.") + _value_check(scale_matrix_1.is_square, "Scale matrix 1 should be " + "be square matrix") + _value_check(scale_matrix_2.is_square, "Scale matrix 2 should be " + "be square matrix") + n = location_matrix.shape[0] + p = location_matrix.shape[1] + _value_check(scale_matrix_1.shape[0] == n, "Scale matrix 1 should be" + " of shape %s x %s"% (str(n), str(n))) + _value_check(scale_matrix_2.shape[0] == p, "Scale matrix 2 should be" + " of shape %s x %s"% (str(p), str(p))) + + @property + def set(self): + n, p = self.location_matrix.shape + return MatrixSet(n, p, S.Reals) + + @property + def dimension(self): + return self.location_matrix.shape + + def pdf(self, x): + M, U, V = self.location_matrix, self.scale_matrix_1, self.scale_matrix_2 + n, p = M.shape + if isinstance(x, list): + x = ImmutableMatrix(x) + if not isinstance(x, (MatrixBase, MatrixSymbol)): + raise ValueError("%s should be an isinstance of Matrix " + "or MatrixSymbol" % str(x)) + term1 = Inverse(V)*Transpose(x - M)*Inverse(U)*(x - M) + num = exp(-Trace(term1)/S(2)) + den = (2*pi)**(S(n*p)/2) * Determinant(U)**(S(p)/2) * Determinant(V)**(S(n)/2) + return num/den + +def MatrixNormal(symbol, location_matrix, scale_matrix_1, scale_matrix_2): + """ + Creates a random variable with Matrix Normal Distribution. + + The density of the said distribution can be found at [1]. + + Parameters + ========== + + location_matrix: Real ``n x p`` matrix + Represents degrees of freedom + scale_matrix_1: Positive definite matrix + Scale Matrix of shape ``n x n`` + scale_matrix_2: Positive definite matrix + Scale Matrix of shape ``p x p`` + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy import MatrixSymbol + >>> from sympy.stats import density, MatrixNormal + >>> M = MatrixNormal('M', [[1, 2]], [1], [[1, 0], [0, 1]]) + >>> X = MatrixSymbol('X', 1, 2) + >>> density(M)(X).doit() + exp(-Trace((Matrix([ + [-1], + [-2]]) + X.T)*(Matrix([[-1, -2]]) + X))/2)/(2*pi) + >>> density(M)([[3, 4]]).doit() + exp(-4)/(2*pi) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Matrix_normal_distribution + + """ + if isinstance(location_matrix, list): + location_matrix = ImmutableMatrix(location_matrix) + if isinstance(scale_matrix_1, list): + scale_matrix_1 = ImmutableMatrix(scale_matrix_1) + if isinstance(scale_matrix_2, list): + scale_matrix_2 = ImmutableMatrix(scale_matrix_2) + args = (location_matrix, scale_matrix_1, scale_matrix_2) + return rv(symbol, MatrixNormalDistribution, args) + +#------------------------------------------------------------------------------- +# Matrix Student's T distribution --------------------------------------------------- + +class MatrixStudentTDistribution(MatrixDistribution): + + _argnames = ('nu', 'location_matrix', 'scale_matrix_1', 'scale_matrix_2') + + @staticmethod + def check(nu, location_matrix, scale_matrix_1, scale_matrix_2): + if not isinstance(scale_matrix_1, MatrixSymbol): + _value_check(scale_matrix_1.is_positive_definite != False, "The shape " + "matrix must be positive definite.") + if not isinstance(scale_matrix_2, MatrixSymbol): + _value_check(scale_matrix_2.is_positive_definite != False, "The shape " + "matrix must be positive definite.") + _value_check(scale_matrix_1.is_square != False, "Scale matrix 1 should be " + "be square matrix") + _value_check(scale_matrix_2.is_square != False, "Scale matrix 2 should be " + "be square matrix") + n = location_matrix.shape[0] + p = location_matrix.shape[1] + _value_check(scale_matrix_1.shape[0] == p, "Scale matrix 1 should be" + " of shape %s x %s" % (str(p), str(p))) + _value_check(scale_matrix_2.shape[0] == n, "Scale matrix 2 should be" + " of shape %s x %s" % (str(n), str(n))) + _value_check(nu.is_positive != False, "Degrees of freedom must be positive") + + @property + def set(self): + n, p = self.location_matrix.shape + return MatrixSet(n, p, S.Reals) + + @property + def dimension(self): + return self.location_matrix.shape + + def pdf(self, x): + from sympy.matrices.dense import eye + if isinstance(x, list): + x = ImmutableMatrix(x) + if not isinstance(x, (MatrixBase, MatrixSymbol)): + raise ValueError("%s should be an isinstance of Matrix " + "or MatrixSymbol" % str(x)) + nu, M, Omega, Sigma = self.nu, self.location_matrix, self.scale_matrix_1, self.scale_matrix_2 + n, p = M.shape + + K = multigamma((nu + n + p - 1)/2, p) * Determinant(Omega)**(-n/2) * Determinant(Sigma)**(-p/2) \ + / ((pi)**(n*p/2) * multigamma((nu + p - 1)/2, p)) + return K * (Determinant(eye(n) + Inverse(Sigma)*(x - M)*Inverse(Omega)*Transpose(x - M))) \ + **(-(nu + n + p -1)/2) + + + +def MatrixStudentT(symbol, nu, location_matrix, scale_matrix_1, scale_matrix_2): + """ + Creates a random variable with Matrix Gamma Distribution. + + The density of the said distribution can be found at [1]. + + Parameters + ========== + + nu: Positive Real number + degrees of freedom + location_matrix: Positive definite real square matrix + Location Matrix of shape ``n x p`` + scale_matrix_1: Positive definite real square matrix + Scale Matrix of shape ``p x p`` + scale_matrix_2: Positive definite real square matrix + Scale Matrix of shape ``n x n`` + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy import MatrixSymbol,symbols + >>> from sympy.stats import density, MatrixStudentT + >>> v = symbols('v',positive=True) + >>> M = MatrixStudentT('M', v, [[1, 2]], [[1, 0], [0, 1]], [1]) + >>> X = MatrixSymbol('X', 1, 2) + >>> density(M)(X) + gamma(v/2 + 1)*Determinant((Matrix([[-1, -2]]) + X)*(Matrix([ + [-1], + [-2]]) + X.T) + Matrix([[1]]))**(-v/2 - 1)/(pi**1.0*gamma(v/2)*Determinant(Matrix([[1]]))**1.0*Determinant(Matrix([ + [1, 0], + [0, 1]]))**0.5) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Matrix_t-distribution + + """ + if isinstance(location_matrix, list): + location_matrix = ImmutableMatrix(location_matrix) + if isinstance(scale_matrix_1, list): + scale_matrix_1 = ImmutableMatrix(scale_matrix_1) + if isinstance(scale_matrix_2, list): + scale_matrix_2 = ImmutableMatrix(scale_matrix_2) + args = (nu, location_matrix, scale_matrix_1, scale_matrix_2) + return rv(symbol, MatrixStudentTDistribution, args) diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/random_matrix.py b/.venv/lib/python3.13/site-packages/sympy/stats/random_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..fdd25cb9ad23fed9d3a85982b24bef33d04928f0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/random_matrix.py @@ -0,0 +1,30 @@ +from sympy.core.basic import Basic +from sympy.stats.rv import PSpace, _symbol_converter, RandomMatrixSymbol + +class RandomMatrixPSpace(PSpace): + """ + Represents probability space for + random matrices. It contains the mechanics + for handling the API calls for random matrices. + """ + def __new__(cls, sym, model=None): + sym = _symbol_converter(sym) + if model: + return Basic.__new__(cls, sym, model) + else: + return Basic.__new__(cls, sym) + + @property + def model(self): + try: + return self.args[1] + except IndexError: + return None + + def compute_density(self, expr, *args): + rms = expr.atoms(RandomMatrixSymbol) + if len(rms) > 2 or (not isinstance(expr, RandomMatrixSymbol)): + raise NotImplementedError("Currently, no algorithm has been " + "implemented to handle general expressions containing " + "multiple random matrices.") + return self.model.density(expr) diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/random_matrix_models.py b/.venv/lib/python3.13/site-packages/sympy/stats/random_matrix_models.py new file mode 100644 index 0000000000000000000000000000000000000000..6327a248ea5919c0bbb0ffc2c984105e04fe20e9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/random_matrix_models.py @@ -0,0 +1,457 @@ +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.basic import Basic +from sympy.core.function import Lambda +from sympy.core.numbers import (I, pi) +from sympy.core.singleton import S +from sympy.core.symbol import Dummy +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import exp +from sympy.functions.special.gamma_functions import gamma +from sympy.integrals.integrals import Integral +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.trace import Trace +from sympy.tensor.indexed import IndexedBase +from sympy.core.sympify import _sympify +from sympy.stats.rv import _symbol_converter, Density, RandomMatrixSymbol, is_random +from sympy.stats.joint_rv_types import JointDistributionHandmade +from sympy.stats.random_matrix import RandomMatrixPSpace +from sympy.tensor.array import ArrayComprehension + +__all__ = [ + 'CircularEnsemble', + 'CircularUnitaryEnsemble', + 'CircularOrthogonalEnsemble', + 'CircularSymplecticEnsemble', + 'GaussianEnsemble', + 'GaussianUnitaryEnsemble', + 'GaussianOrthogonalEnsemble', + 'GaussianSymplecticEnsemble', + 'joint_eigen_distribution', + 'JointEigenDistribution', + 'level_spacing_distribution' +] + +@is_random.register(RandomMatrixSymbol) +def _(x): + return True + + +class RandomMatrixEnsembleModel(Basic): + """ + Base class for random matrix ensembles. + It acts as an umbrella and contains + the methods common to all the ensembles + defined in sympy.stats.random_matrix_models. + """ + def __new__(cls, sym, dim=None): + sym, dim = _symbol_converter(sym), _sympify(dim) + if dim.is_integer == False: + raise ValueError("Dimension of the random matrices must be " + "integers, received %s instead."%(dim)) + return Basic.__new__(cls, sym, dim) + + symbol = property(lambda self: self.args[0]) + dimension = property(lambda self: self.args[1]) + + def density(self, expr): + return Density(expr) + + def __call__(self, expr): + return self.density(expr) + +class GaussianEnsembleModel(RandomMatrixEnsembleModel): + """ + Abstract class for Gaussian ensembles. + Contains the properties common to all the + gaussian ensembles. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Random_matrix#Gaussian_ensembles + .. [2] https://arxiv.org/pdf/1712.07903.pdf + """ + def _compute_normalization_constant(self, beta, n): + """ + Helper function for computing normalization + constant for joint probability density of eigen + values of Gaussian ensembles. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Selberg_integral#Mehta's_integral + """ + n = S(n) + prod_term = lambda j: gamma(1 + beta*S(j)/2)/gamma(S.One + beta/S(2)) + j = Dummy('j', integer=True, positive=True) + term1 = Product(prod_term(j), (j, 1, n)).doit() + term2 = (2/(beta*n))**(beta*n*(n - 1)/4 + n/2) + term3 = (2*pi)**(n/2) + return term1 * term2 * term3 + + def _compute_joint_eigen_distribution(self, beta): + """ + Helper function for computing the joint + probability distribution of eigen values + of the random matrix. + """ + n = self.dimension + Zbn = self._compute_normalization_constant(beta, n) + l = IndexedBase('l') + i = Dummy('i', integer=True, positive=True) + j = Dummy('j', integer=True, positive=True) + k = Dummy('k', integer=True, positive=True) + term1 = exp((-S(n)/2) * Sum(l[k]**2, (k, 1, n)).doit()) + sub_term = Lambda(i, Product(Abs(l[j] - l[i])**beta, (j, i + 1, n))) + term2 = Product(sub_term(i).doit(), (i, 1, n - 1)).doit() + syms = ArrayComprehension(l[k], (k, 1, n)).doit() + return Lambda(tuple(syms), (term1 * term2)/Zbn) + +class GaussianUnitaryEnsembleModel(GaussianEnsembleModel): + @property + def normalization_constant(self): + n = self.dimension + return 2**(S(n)/2) * pi**(S(n**2)/2) + + def density(self, expr): + n, ZGUE = self.dimension, self.normalization_constant + h_pspace = RandomMatrixPSpace('P', model=self) + H = RandomMatrixSymbol('H', n, n, pspace=h_pspace) + return Lambda(H, exp(-S(n)/2 * Trace(H**2))/ZGUE)(expr) + + def joint_eigen_distribution(self): + return self._compute_joint_eigen_distribution(S(2)) + + def level_spacing_distribution(self): + s = Dummy('s') + f = (32/pi**2)*(s**2)*exp((-4/pi)*s**2) + return Lambda(s, f) + +class GaussianOrthogonalEnsembleModel(GaussianEnsembleModel): + @property + def normalization_constant(self): + n = self.dimension + _H = MatrixSymbol('_H', n, n) + return Integral(exp(-S(n)/4 * Trace(_H**2))) + + def density(self, expr): + n, ZGOE = self.dimension, self.normalization_constant + h_pspace = RandomMatrixPSpace('P', model=self) + H = RandomMatrixSymbol('H', n, n, pspace=h_pspace) + return Lambda(H, exp(-S(n)/4 * Trace(H**2))/ZGOE)(expr) + + def joint_eigen_distribution(self): + return self._compute_joint_eigen_distribution(S.One) + + def level_spacing_distribution(self): + s = Dummy('s') + f = (pi/2)*s*exp((-pi/4)*s**2) + return Lambda(s, f) + +class GaussianSymplecticEnsembleModel(GaussianEnsembleModel): + @property + def normalization_constant(self): + n = self.dimension + _H = MatrixSymbol('_H', n, n) + return Integral(exp(-S(n) * Trace(_H**2))) + + def density(self, expr): + n, ZGSE = self.dimension, self.normalization_constant + h_pspace = RandomMatrixPSpace('P', model=self) + H = RandomMatrixSymbol('H', n, n, pspace=h_pspace) + return Lambda(H, exp(-S(n) * Trace(H**2))/ZGSE)(expr) + + def joint_eigen_distribution(self): + return self._compute_joint_eigen_distribution(S(4)) + + def level_spacing_distribution(self): + s = Dummy('s') + f = ((S(2)**18)/((S(3)**6)*(pi**3)))*(s**4)*exp((-64/(9*pi))*s**2) + return Lambda(s, f) + +def GaussianEnsemble(sym, dim): + sym, dim = _symbol_converter(sym), _sympify(dim) + model = GaussianEnsembleModel(sym, dim) + rmp = RandomMatrixPSpace(sym, model=model) + return RandomMatrixSymbol(sym, dim, dim, pspace=rmp) + +def GaussianUnitaryEnsemble(sym, dim): + """ + Represents Gaussian Unitary Ensembles. + + Examples + ======== + + >>> from sympy.stats import GaussianUnitaryEnsemble as GUE, density + >>> from sympy import MatrixSymbol + >>> G = GUE('U', 2) + >>> X = MatrixSymbol('X', 2, 2) + >>> density(G)(X) + exp(-Trace(X**2))/(2*pi**2) + """ + sym, dim = _symbol_converter(sym), _sympify(dim) + model = GaussianUnitaryEnsembleModel(sym, dim) + rmp = RandomMatrixPSpace(sym, model=model) + return RandomMatrixSymbol(sym, dim, dim, pspace=rmp) + +def GaussianOrthogonalEnsemble(sym, dim): + """ + Represents Gaussian Orthogonal Ensembles. + + Examples + ======== + + >>> from sympy.stats import GaussianOrthogonalEnsemble as GOE, density + >>> from sympy import MatrixSymbol + >>> G = GOE('U', 2) + >>> X = MatrixSymbol('X', 2, 2) + >>> density(G)(X) + exp(-Trace(X**2)/2)/Integral(exp(-Trace(_H**2)/2), _H) + """ + sym, dim = _symbol_converter(sym), _sympify(dim) + model = GaussianOrthogonalEnsembleModel(sym, dim) + rmp = RandomMatrixPSpace(sym, model=model) + return RandomMatrixSymbol(sym, dim, dim, pspace=rmp) + +def GaussianSymplecticEnsemble(sym, dim): + """ + Represents Gaussian Symplectic Ensembles. + + Examples + ======== + + >>> from sympy.stats import GaussianSymplecticEnsemble as GSE, density + >>> from sympy import MatrixSymbol + >>> G = GSE('U', 2) + >>> X = MatrixSymbol('X', 2, 2) + >>> density(G)(X) + exp(-2*Trace(X**2))/Integral(exp(-2*Trace(_H**2)), _H) + """ + sym, dim = _symbol_converter(sym), _sympify(dim) + model = GaussianSymplecticEnsembleModel(sym, dim) + rmp = RandomMatrixPSpace(sym, model=model) + return RandomMatrixSymbol(sym, dim, dim, pspace=rmp) + +class CircularEnsembleModel(RandomMatrixEnsembleModel): + """ + Abstract class for Circular ensembles. + Contains the properties and methods + common to all the circular ensembles. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Circular_ensemble + """ + def density(self, expr): + # TODO : Add support for Lie groups(as extensions of sympy.diffgeom) + # and define measures on them + raise NotImplementedError("Support for Haar measure hasn't been " + "implemented yet, therefore the density of " + "%s cannot be computed."%(self)) + + def _compute_joint_eigen_distribution(self, beta): + """ + Helper function to compute the joint distribution of phases + of the complex eigen values of matrices belonging to any + circular ensembles. + """ + n = self.dimension + Zbn = ((2*pi)**n)*(gamma(beta*n/2 + 1)/S(gamma(beta/2 + 1))**n) + t = IndexedBase('t') + i, j, k = (Dummy('i', integer=True), Dummy('j', integer=True), + Dummy('k', integer=True)) + syms = ArrayComprehension(t[i], (i, 1, n)).doit() + f = Product(Product(Abs(exp(I*t[k]) - exp(I*t[j]))**beta, (j, k + 1, n)).doit(), + (k, 1, n - 1)).doit() + return Lambda(tuple(syms), f/Zbn) + +class CircularUnitaryEnsembleModel(CircularEnsembleModel): + def joint_eigen_distribution(self): + return self._compute_joint_eigen_distribution(S(2)) + +class CircularOrthogonalEnsembleModel(CircularEnsembleModel): + def joint_eigen_distribution(self): + return self._compute_joint_eigen_distribution(S.One) + +class CircularSymplecticEnsembleModel(CircularEnsembleModel): + def joint_eigen_distribution(self): + return self._compute_joint_eigen_distribution(S(4)) + +def CircularEnsemble(sym, dim): + sym, dim = _symbol_converter(sym), _sympify(dim) + model = CircularEnsembleModel(sym, dim) + rmp = RandomMatrixPSpace(sym, model=model) + return RandomMatrixSymbol(sym, dim, dim, pspace=rmp) + +def CircularUnitaryEnsemble(sym, dim): + """ + Represents Circular Unitary Ensembles. + + Examples + ======== + + >>> from sympy.stats import CircularUnitaryEnsemble as CUE + >>> from sympy.stats import joint_eigen_distribution + >>> C = CUE('U', 1) + >>> joint_eigen_distribution(C) + Lambda(t[1], Product(Abs(exp(I*t[_j]) - exp(I*t[_k]))**2, (_j, _k + 1, 1), (_k, 1, 0))/(2*pi)) + + Note + ==== + + As can be seen above in the example, density of CiruclarUnitaryEnsemble + is not evaluated because the exact definition is based on haar measure of + unitary group which is not unique. + """ + sym, dim = _symbol_converter(sym), _sympify(dim) + model = CircularUnitaryEnsembleModel(sym, dim) + rmp = RandomMatrixPSpace(sym, model=model) + return RandomMatrixSymbol(sym, dim, dim, pspace=rmp) + +def CircularOrthogonalEnsemble(sym, dim): + """ + Represents Circular Orthogonal Ensembles. + + Examples + ======== + + >>> from sympy.stats import CircularOrthogonalEnsemble as COE + >>> from sympy.stats import joint_eigen_distribution + >>> C = COE('O', 1) + >>> joint_eigen_distribution(C) + Lambda(t[1], Product(Abs(exp(I*t[_j]) - exp(I*t[_k])), (_j, _k + 1, 1), (_k, 1, 0))/(2*pi)) + + Note + ==== + + As can be seen above in the example, density of CiruclarOrthogonalEnsemble + is not evaluated because the exact definition is based on haar measure of + unitary group which is not unique. + """ + sym, dim = _symbol_converter(sym), _sympify(dim) + model = CircularOrthogonalEnsembleModel(sym, dim) + rmp = RandomMatrixPSpace(sym, model=model) + return RandomMatrixSymbol(sym, dim, dim, pspace=rmp) + +def CircularSymplecticEnsemble(sym, dim): + """ + Represents Circular Symplectic Ensembles. + + Examples + ======== + + >>> from sympy.stats import CircularSymplecticEnsemble as CSE + >>> from sympy.stats import joint_eigen_distribution + >>> C = CSE('S', 1) + >>> joint_eigen_distribution(C) + Lambda(t[1], Product(Abs(exp(I*t[_j]) - exp(I*t[_k]))**4, (_j, _k + 1, 1), (_k, 1, 0))/(2*pi)) + + Note + ==== + + As can be seen above in the example, density of CiruclarSymplecticEnsemble + is not evaluated because the exact definition is based on haar measure of + unitary group which is not unique. + """ + sym, dim = _symbol_converter(sym), _sympify(dim) + model = CircularSymplecticEnsembleModel(sym, dim) + rmp = RandomMatrixPSpace(sym, model=model) + return RandomMatrixSymbol(sym, dim, dim, pspace=rmp) + +def joint_eigen_distribution(mat): + """ + For obtaining joint probability distribution + of eigen values of random matrix. + + Parameters + ========== + + mat: RandomMatrixSymbol + The matrix symbol whose eigen values are to be considered. + + Returns + ======= + + Lambda + + Examples + ======== + + >>> from sympy.stats import GaussianUnitaryEnsemble as GUE + >>> from sympy.stats import joint_eigen_distribution + >>> U = GUE('U', 2) + >>> joint_eigen_distribution(U) + Lambda((l[1], l[2]), exp(-l[1]**2 - l[2]**2)*Product(Abs(l[_i] - l[_j])**2, (_j, _i + 1, 2), (_i, 1, 1))/pi) + """ + if not isinstance(mat, RandomMatrixSymbol): + raise ValueError("%s is not of type, RandomMatrixSymbol."%(mat)) + return mat.pspace.model.joint_eigen_distribution() + +def JointEigenDistribution(mat): + """ + Creates joint distribution of eigen values of matrices with random + expressions. + + Parameters + ========== + + mat: Matrix + The matrix under consideration. + + Returns + ======= + + JointDistributionHandmade + + Examples + ======== + + >>> from sympy.stats import Normal, JointEigenDistribution + >>> from sympy import Matrix + >>> A = [[Normal('A00', 0, 1), Normal('A01', 0, 1)], + ... [Normal('A10', 0, 1), Normal('A11', 0, 1)]] + >>> JointEigenDistribution(Matrix(A)) + JointDistributionHandmade(-sqrt(A00**2 - 2*A00*A11 + 4*A01*A10 + A11**2)/2 + + A00/2 + A11/2, sqrt(A00**2 - 2*A00*A11 + 4*A01*A10 + A11**2)/2 + A00/2 + A11/2) + + """ + eigenvals = mat.eigenvals(multiple=True) + if not all(is_random(eigenval) for eigenval in set(eigenvals)): + raise ValueError("Eigen values do not have any random expression, " + "joint distribution cannot be generated.") + return JointDistributionHandmade(*eigenvals) + +def level_spacing_distribution(mat): + """ + For obtaining distribution of level spacings. + + Parameters + ========== + + mat: RandomMatrixSymbol + The random matrix symbol whose eigen values are + to be considered for finding the level spacings. + + Returns + ======= + + Lambda + + Examples + ======== + + >>> from sympy.stats import GaussianUnitaryEnsemble as GUE + >>> from sympy.stats import level_spacing_distribution + >>> U = GUE('U', 2) + >>> level_spacing_distribution(U) + Lambda(_s, 32*_s**2*exp(-4*_s**2/pi)/pi**2) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Random_matrix#Distribution_of_level_spacings + """ + return mat.pspace.model.level_spacing_distribution() diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/rv.py b/.venv/lib/python3.13/site-packages/sympy/stats/rv.py new file mode 100644 index 0000000000000000000000000000000000000000..75ab54deb551b7ff3d4d06f37482a1f16a789ba6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/rv.py @@ -0,0 +1,1798 @@ +""" +Main Random Variables Module + +Defines abstract random variable type. +Contains interfaces for probability space object (PSpace) as well as standard +operators, P, E, sample, density, where, quantile + +See Also +======== + +sympy.stats.crv +sympy.stats.frv +sympy.stats.rv_interface +""" + +from __future__ import annotations +from functools import singledispatch +from math import prod + +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.function import (Function, Lambda) +from sympy.core.logic import fuzzy_and +from sympy.core.mul import Mul +from sympy.core.relational import (Eq, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol) +from sympy.core.sympify import sympify +from sympy.functions.special.delta_functions import DiracDelta +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.logic.boolalg import (And, Or) +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.tensor.indexed import Indexed +from sympy.utilities.lambdify import lambdify +from sympy.core.relational import Relational +from sympy.core.sympify import _sympify +from sympy.sets.sets import FiniteSet, ProductSet, Intersection +from sympy.solvers.solveset import solveset +from sympy.external import import_module +from sympy.utilities.decorator import doctest_depends_on +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import iterable + + +__doctest_requires__ = {('sample',): ['scipy']} + + +x = Symbol('x') + + +@singledispatch +def is_random(x): + return False + + +@is_random.register(Basic) +def _(x): + atoms = x.free_symbols + return any(is_random(i) for i in atoms) + + +class RandomDomain(Basic): + """ + Represents a set of variables and the values which they can take. + + See Also + ======== + + sympy.stats.crv.ContinuousDomain + sympy.stats.frv.FiniteDomain + """ + + is_ProductDomain = False + is_Finite = False + is_Continuous = False + is_Discrete = False + + def __new__(cls, symbols, *args): + symbols = FiniteSet(*symbols) + return Basic.__new__(cls, symbols, *args) + + @property + def symbols(self): + return self.args[0] + + @property + def set(self): + return self.args[1] + + def __contains__(self, other): + raise NotImplementedError() + + def compute_expectation(self, expr): + raise NotImplementedError() + + +class SingleDomain(RandomDomain): + """ + A single variable and its domain. + + See Also + ======== + + sympy.stats.crv.SingleContinuousDomain + sympy.stats.frv.SingleFiniteDomain + """ + def __new__(cls, symbol, set): + assert symbol.is_Symbol + return Basic.__new__(cls, symbol, set) + + @property + def symbol(self): + return self.args[0] + + @property + def symbols(self): + return FiniteSet(self.symbol) + + def __contains__(self, other): + if len(other) != 1: + return False + sym, val = tuple(other)[0] + return self.symbol == sym and val in self.set + + +class MatrixDomain(RandomDomain): + """ + A Random Matrix variable and its domain. + + """ + def __new__(cls, symbol, set): + symbol, set = _symbol_converter(symbol), _sympify(set) + return Basic.__new__(cls, symbol, set) + + @property + def symbol(self): + return self.args[0] + + @property + def symbols(self): + return FiniteSet(self.symbol) + + +class ConditionalDomain(RandomDomain): + """ + A RandomDomain with an attached condition. + + See Also + ======== + + sympy.stats.crv.ConditionalContinuousDomain + sympy.stats.frv.ConditionalFiniteDomain + """ + def __new__(cls, fulldomain, condition): + condition = condition.xreplace({rs: rs.symbol + for rs in random_symbols(condition)}) + return Basic.__new__(cls, fulldomain, condition) + + @property + def symbols(self): + return self.fulldomain.symbols + + @property + def fulldomain(self): + return self.args[0] + + @property + def condition(self): + return self.args[1] + + @property + def set(self): + raise NotImplementedError("Set of Conditional Domain not Implemented") + + def as_boolean(self): + return And(self.fulldomain.as_boolean(), self.condition) + + +class PSpace(Basic): + """ + A Probability Space. + + Explanation + =========== + + Probability Spaces encode processes that equal different values + probabilistically. These underly Random Symbols which occur in SymPy + expressions and contain the mechanics to evaluate statistical statements. + + See Also + ======== + + sympy.stats.crv.ContinuousPSpace + sympy.stats.frv.FinitePSpace + """ + + is_Finite: bool | None = None # Fails test if not set to None + is_Continuous: bool | None = None # Fails test if not set to None + is_Discrete: bool | None = None # Fails test if not set to None + is_real: bool | None + + @property + def domain(self): + return self.args[0] + + @property + def density(self): + return self.args[1] + + @property + def values(self): + return frozenset(RandomSymbol(sym, self) for sym in self.symbols) + + @property + def symbols(self): + return self.domain.symbols + + def where(self, condition): + raise NotImplementedError() + + def compute_density(self, expr): + raise NotImplementedError() + + def sample(self, size=(), library='scipy', seed=None): + raise NotImplementedError() + + def probability(self, condition): + raise NotImplementedError() + + def compute_expectation(self, expr): + raise NotImplementedError() + + +class SinglePSpace(PSpace): + """ + Represents the probabilities of a set of random events that can be + attributed to a single variable/symbol. + """ + def __new__(cls, s, distribution): + s = _symbol_converter(s) + return Basic.__new__(cls, s, distribution) + + @property + def value(self): + return RandomSymbol(self.symbol, self) + + @property + def symbol(self): + return self.args[0] + + @property + def distribution(self): + return self.args[1] + + @property + def pdf(self): + return self.distribution.pdf(self.symbol) + + +class RandomSymbol(Expr): + """ + Random Symbols represent ProbabilitySpaces in SymPy Expressions. + In principle they can take on any value that their symbol can take on + within the associated PSpace with probability determined by the PSpace + Density. + + Explanation + =========== + + Random Symbols contain pspace and symbol properties. + The pspace property points to the represented Probability Space + The symbol is a standard SymPy Symbol that is used in that probability space + for example in defining a density. + + You can form normal SymPy expressions using RandomSymbols and operate on + those expressions with the Functions + + E - Expectation of a random expression + P - Probability of a condition + density - Probability Density of an expression + given - A new random expression (with new random symbols) given a condition + + An object of the RandomSymbol type should almost never be created by the + user. They tend to be created instead by the PSpace class's value method. + Traditionally a user does not even do this but instead calls one of the + convenience functions Normal, Exponential, Coin, Die, FiniteRV, etc.... + """ + + def __new__(cls, symbol, pspace=None): + from sympy.stats.joint_rv import JointRandomSymbol + if pspace is None: + # Allow single arg, representing pspace == PSpace() + pspace = PSpace() + symbol = _symbol_converter(symbol) + if not isinstance(pspace, PSpace): + raise TypeError("pspace variable should be of type PSpace") + if cls == JointRandomSymbol and isinstance(pspace, SinglePSpace): + cls = RandomSymbol + return Basic.__new__(cls, symbol, pspace) + + is_finite = True + is_symbol = True + is_Atom = True + + _diff_wrt = True + + pspace = property(lambda self: self.args[1]) + symbol = property(lambda self: self.args[0]) + name = property(lambda self: self.symbol.name) + + def _eval_is_positive(self): + return self.symbol.is_positive + + def _eval_is_integer(self): + return self.symbol.is_integer + + def _eval_is_real(self): + return self.symbol.is_real or self.pspace.is_real + + @property + def is_commutative(self): + return self.symbol.is_commutative + + @property + def free_symbols(self): + return {self} + +class RandomIndexedSymbol(RandomSymbol): + + def __new__(cls, idx_obj, pspace=None): + if pspace is None: + # Allow single arg, representing pspace == PSpace() + pspace = PSpace() + if not isinstance(idx_obj, (Indexed, Function)): + raise TypeError("An Function or Indexed object is expected not %s"%(idx_obj)) + return Basic.__new__(cls, idx_obj, pspace) + + symbol = property(lambda self: self.args[0]) + name = property(lambda self: str(self.args[0])) + + @property + def key(self): + if isinstance(self.symbol, Indexed): + return self.symbol.args[1] + elif isinstance(self.symbol, Function): + return self.symbol.args[0] + + @property + def free_symbols(self): + if self.key.free_symbols: + free_syms = self.key.free_symbols + free_syms.add(self) + return free_syms + return {self} + + @property + def pspace(self): + return self.args[1] + +class RandomMatrixSymbol(RandomSymbol, MatrixSymbol): # type: ignore + def __new__(cls, symbol, n, m, pspace=None): + n, m = _sympify(n), _sympify(m) + symbol = _symbol_converter(symbol) + if pspace is None: + # Allow single arg, representing pspace == PSpace() + pspace = PSpace() + return Basic.__new__(cls, symbol, n, m, pspace) + + symbol = property(lambda self: self.args[0]) + pspace = property(lambda self: self.args[3]) + + +class ProductPSpace(PSpace): + """ + Abstract class for representing probability spaces with multiple random + variables. + + See Also + ======== + + sympy.stats.rv.IndependentProductPSpace + sympy.stats.joint_rv.JointPSpace + """ + pass + +class IndependentProductPSpace(ProductPSpace): + """ + A probability space resulting from the merger of two independent probability + spaces. + + Often created using the function, pspace. + """ + + def __new__(cls, *spaces): + rs_space_dict = {} + for space in spaces: + for value in space.values: + rs_space_dict[value] = space + + symbols = FiniteSet(*[val.symbol for val in rs_space_dict.keys()]) + + # Overlapping symbols + from sympy.stats.joint_rv import MarginalDistribution + from sympy.stats.compound_rv import CompoundDistribution + if len(symbols) < sum(len(space.symbols) for space in spaces if not + isinstance(space.distribution, ( + CompoundDistribution, MarginalDistribution))): + raise ValueError("Overlapping Random Variables") + + if all(space.is_Finite for space in spaces): + from sympy.stats.frv import ProductFinitePSpace + cls = ProductFinitePSpace + + obj = Basic.__new__(cls, *FiniteSet(*spaces)) + + return obj + + @property + def pdf(self): + p = Mul(*[space.pdf for space in self.spaces]) + return p.subs({rv: rv.symbol for rv in self.values}) + + @property + def rs_space_dict(self): + d = {} + for space in self.spaces: + for value in space.values: + d[value] = space + return d + + @property + def symbols(self): + return FiniteSet(*[val.symbol for val in self.rs_space_dict.keys()]) + + @property + def spaces(self): + return FiniteSet(*self.args) + + @property + def values(self): + return sumsets(space.values for space in self.spaces) + + def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs): + rvs = rvs or self.values + rvs = frozenset(rvs) + for space in self.spaces: + expr = space.compute_expectation(expr, rvs & space.values, evaluate=False, **kwargs) + if evaluate and hasattr(expr, 'doit'): + return expr.doit(**kwargs) + return expr + + @property + def domain(self): + return ProductDomain(*[space.domain for space in self.spaces]) + + @property + def density(self): + raise NotImplementedError("Density not available for ProductSpaces") + + def sample(self, size=(), library='scipy', seed=None): + return {k: v for space in self.spaces + for k, v in space.sample(size=size, library=library, seed=seed).items()} + + + def probability(self, condition, **kwargs): + cond_inv = False + if isinstance(condition, Ne): + condition = Eq(condition.args[0], condition.args[1]) + cond_inv = True + elif isinstance(condition, And): # they are independent + return Mul(*[self.probability(arg) for arg in condition.args]) + elif isinstance(condition, Or): # they are independent + return Add(*[self.probability(arg) for arg in condition.args]) + expr = condition.lhs - condition.rhs + rvs = random_symbols(expr) + dens = self.compute_density(expr) + if any(pspace(rv).is_Continuous for rv in rvs): + from sympy.stats.crv import SingleContinuousPSpace + from sympy.stats.crv_types import ContinuousDistributionHandmade + if expr in self.values: + # Marginalize all other random symbols out of the density + randomsymbols = tuple(set(self.values) - frozenset([expr])) + symbols = tuple(rs.symbol for rs in randomsymbols) + pdf = self.domain.integrate(self.pdf, symbols, **kwargs) + return Lambda(expr.symbol, pdf) + dens = ContinuousDistributionHandmade(dens) + z = Dummy('z', real=True) + space = SingleContinuousPSpace(z, dens) + result = space.probability(condition.__class__(space.value, 0)) + else: + from sympy.stats.drv import SingleDiscretePSpace + from sympy.stats.drv_types import DiscreteDistributionHandmade + dens = DiscreteDistributionHandmade(dens) + z = Dummy('z', integer=True) + space = SingleDiscretePSpace(z, dens) + result = space.probability(condition.__class__(space.value, 0)) + return result if not cond_inv else S.One - result + + def compute_density(self, expr, **kwargs): + rvs = random_symbols(expr) + if any(pspace(rv).is_Continuous for rv in rvs): + z = Dummy('z', real=True) + expr = self.compute_expectation(DiracDelta(expr - z), + **kwargs) + else: + z = Dummy('z', integer=True) + expr = self.compute_expectation(KroneckerDelta(expr, z), + **kwargs) + return Lambda(z, expr) + + def compute_cdf(self, expr, **kwargs): + raise ValueError("CDF not well defined on multivariate expressions") + + def conditional_space(self, condition, normalize=True, **kwargs): + rvs = random_symbols(condition) + condition = condition.xreplace({rv: rv.symbol for rv in self.values}) + pspaces = [pspace(rv) for rv in rvs] + if any(ps.is_Continuous for ps in pspaces): + from sympy.stats.crv import (ConditionalContinuousDomain, + ContinuousPSpace) + space = ContinuousPSpace + domain = ConditionalContinuousDomain(self.domain, condition) + elif any(ps.is_Discrete for ps in pspaces): + from sympy.stats.drv import (ConditionalDiscreteDomain, + DiscretePSpace) + space = DiscretePSpace + domain = ConditionalDiscreteDomain(self.domain, condition) + elif all(ps.is_Finite for ps in pspaces): + from sympy.stats.frv import FinitePSpace + return FinitePSpace.conditional_space(self, condition) + if normalize: + replacement = {rv: Dummy(str(rv)) for rv in self.symbols} + norm = domain.compute_expectation(self.pdf, **kwargs) + pdf = self.pdf / norm.xreplace(replacement) + # XXX: Converting symbols from set to tuple. The order matters to + # Lambda though so we shouldn't be starting with a set here... + density = Lambda(tuple(domain.symbols), pdf) + + return space(domain, density) + +class ProductDomain(RandomDomain): + """ + A domain resulting from the merger of two independent domains. + + See Also + ======== + sympy.stats.crv.ProductContinuousDomain + sympy.stats.frv.ProductFiniteDomain + """ + is_ProductDomain = True + + def __new__(cls, *domains): + # Flatten any product of products + domains2 = [] + for domain in domains: + if not domain.is_ProductDomain: + domains2.append(domain) + else: + domains2.extend(domain.domains) + domains2 = FiniteSet(*domains2) + + if all(domain.is_Finite for domain in domains2): + from sympy.stats.frv import ProductFiniteDomain + cls = ProductFiniteDomain + if all(domain.is_Continuous for domain in domains2): + from sympy.stats.crv import ProductContinuousDomain + cls = ProductContinuousDomain + if all(domain.is_Discrete for domain in domains2): + from sympy.stats.drv import ProductDiscreteDomain + cls = ProductDiscreteDomain + + return Basic.__new__(cls, *domains2) + + @property + def sym_domain_dict(self): + return {symbol: domain for domain in self.domains + for symbol in domain.symbols} + + @property + def symbols(self): + return FiniteSet(*[sym for domain in self.domains + for sym in domain.symbols]) + + @property + def domains(self): + return self.args + + @property + def set(self): + return ProductSet(*(domain.set for domain in self.domains)) + + def __contains__(self, other): + # Split event into each subdomain + for domain in self.domains: + # Collect the parts of this event which associate to this domain + elem = frozenset([item for item in other + if sympify(domain.symbols.contains(item[0])) + is S.true]) + # Test this sub-event + if elem not in domain: + return False + # All subevents passed + return True + + def as_boolean(self): + return And(*[domain.as_boolean() for domain in self.domains]) + + +def random_symbols(expr): + """ + Returns all RandomSymbols within a SymPy Expression. + """ + atoms = getattr(expr, 'atoms', None) + if atoms is not None: + comp = lambda rv: rv.symbol.name + l = list(atoms(RandomSymbol)) + return sorted(l, key=comp) + else: + return [] + + +def pspace(expr): + """ + Returns the underlying Probability Space of a random expression. + + For internal use. + + Examples + ======== + + >>> from sympy.stats import pspace, Normal + >>> X = Normal('X', 0, 1) + >>> pspace(2*X + 1) == X.pspace + True + """ + expr = sympify(expr) + if isinstance(expr, RandomSymbol) and expr.pspace is not None: + return expr.pspace + if expr.has(RandomMatrixSymbol): + rm = list(expr.atoms(RandomMatrixSymbol))[0] + return rm.pspace + + rvs = random_symbols(expr) + if not rvs: + raise ValueError("Expression containing Random Variable expected, not %s" % (expr)) + # If only one space present + if all(rv.pspace == rvs[0].pspace for rv in rvs): + return rvs[0].pspace + from sympy.stats.compound_rv import CompoundPSpace + from sympy.stats.stochastic_process import StochasticPSpace + for rv in rvs: + if isinstance(rv.pspace, (CompoundPSpace, StochasticPSpace)): + return rv.pspace + # Otherwise make a product space + return IndependentProductPSpace(*[rv.pspace for rv in rvs]) + + +def sumsets(sets): + """ + Union of sets + """ + return frozenset().union(*sets) + + +def rs_swap(a, b): + """ + Build a dictionary to swap RandomSymbols based on their underlying symbol. + + i.e. + if ``X = ('x', pspace1)`` + and ``Y = ('x', pspace2)`` + then ``X`` and ``Y`` match and the key, value pair + ``{X:Y}`` will appear in the result + + Inputs: collections a and b of random variables which share common symbols + Output: dict mapping RVs in a to RVs in b + """ + d = {} + for rsa in a: + d[rsa] = [rsb for rsb in b if rsa.symbol == rsb.symbol][0] + return d + + +def given(expr, condition=None, **kwargs): + r""" Conditional Random Expression. + + Explanation + =========== + + From a random expression and a condition on that expression creates a new + probability space from the condition and returns the same expression on that + conditional probability space. + + Examples + ======== + + >>> from sympy.stats import given, density, Die + >>> X = Die('X', 6) + >>> Y = given(X, X > 3) + >>> density(Y).dict + {4: 1/3, 5: 1/3, 6: 1/3} + + Following convention, if the condition is a random symbol then that symbol + is considered fixed. + + >>> from sympy.stats import Normal + >>> from sympy import pprint + >>> from sympy.abc import z + + >>> X = Normal('X', 0, 1) + >>> Y = Normal('Y', 0, 1) + >>> pprint(density(X + Y, Y)(z), use_unicode=False) + 2 + -(-Y + z) + ----------- + ___ 2 + \/ 2 *e + ------------------ + ____ + 2*\/ pi + """ + + if not is_random(condition) or pspace_independent(expr, condition): + return expr + + if isinstance(condition, RandomSymbol): + condition = Eq(condition, condition.symbol) + + condsymbols = random_symbols(condition) + if (isinstance(condition, Eq) and len(condsymbols) == 1 and + not isinstance(pspace(expr).domain, ConditionalDomain)): + rv = tuple(condsymbols)[0] + + results = solveset(condition, rv) + if isinstance(results, Intersection) and S.Reals in results.args: + results = list(results.args[1]) + + sums = 0 + for res in results: + temp = expr.subs(rv, res) + if temp == True: + return True + if temp != False: + # XXX: This seems nonsensical but preserves existing behaviour + # after the change that Relational is no longer a subclass of + # Expr. Here expr is sometimes Relational and sometimes Expr + # but we are trying to add them with +=. This needs to be + # fixed somehow. + if sums == 0 and isinstance(expr, Relational): + sums = expr.subs(rv, res) + else: + sums += expr.subs(rv, res) + if sums == 0: + return False + return sums + + # Get full probability space of both the expression and the condition + fullspace = pspace(Tuple(expr, condition)) + # Build new space given the condition + space = fullspace.conditional_space(condition, **kwargs) + # Dictionary to swap out RandomSymbols in expr with new RandomSymbols + # That point to the new conditional space + swapdict = rs_swap(fullspace.values, space.values) + # Swap random variables in the expression + expr = expr.xreplace(swapdict) + return expr + + +def expectation(expr, condition=None, numsamples=None, evaluate=True, **kwargs): + """ + Returns the expected value of a random expression. + + Parameters + ========== + + expr : Expr containing RandomSymbols + The expression of which you want to compute the expectation value + given : Expr containing RandomSymbols + A conditional expression. E(X, X>0) is expectation of X given X > 0 + numsamples : int + Enables sampling and approximates the expectation with this many samples + evalf : Bool (defaults to True) + If sampling return a number rather than a complex expression + evaluate : Bool (defaults to True) + In case of continuous systems return unevaluated integral + + Examples + ======== + + >>> from sympy.stats import E, Die + >>> X = Die('X', 6) + >>> E(X) + 7/2 + >>> E(2*X + 1) + 8 + + >>> E(X, X > 3) # Expectation of X given that it is above 3 + 5 + """ + + if not is_random(expr): # expr isn't random? + return expr + kwargs['numsamples'] = numsamples + from sympy.stats.symbolic_probability import Expectation + if evaluate: + return Expectation(expr, condition).doit(**kwargs) + return Expectation(expr, condition) + + +def probability(condition, given_condition=None, numsamples=None, + evaluate=True, **kwargs): + """ + Probability that a condition is true, optionally given a second condition. + + Parameters + ========== + + condition : Combination of Relationals containing RandomSymbols + The condition of which you want to compute the probability + given_condition : Combination of Relationals containing RandomSymbols + A conditional expression. P(X > 1, X > 0) is expectation of X > 1 + given X > 0 + numsamples : int + Enables sampling and approximates the probability with this many samples + evaluate : Bool (defaults to True) + In case of continuous systems return unevaluated integral + + Examples + ======== + + >>> from sympy.stats import P, Die + >>> from sympy import Eq + >>> X, Y = Die('X', 6), Die('Y', 6) + >>> P(X > 3) + 1/2 + >>> P(Eq(X, 5), X > 2) # Probability that X == 5 given that X > 2 + 1/4 + >>> P(X > Y) + 5/12 + """ + + kwargs['numsamples'] = numsamples + from sympy.stats.symbolic_probability import Probability + if evaluate: + return Probability(condition, given_condition).doit(**kwargs) + return Probability(condition, given_condition) + + +class Density(Basic): + expr = property(lambda self: self.args[0]) + + def __new__(cls, expr, condition = None): + expr = _sympify(expr) + if condition is None: + obj = Basic.__new__(cls, expr) + else: + condition = _sympify(condition) + obj = Basic.__new__(cls, expr, condition) + return obj + + @property + def condition(self): + if len(self.args) > 1: + return self.args[1] + else: + return None + + def doit(self, evaluate=True, **kwargs): + from sympy.stats.random_matrix import RandomMatrixPSpace + from sympy.stats.joint_rv import JointPSpace + from sympy.stats.matrix_distributions import MatrixPSpace + from sympy.stats.compound_rv import CompoundPSpace + from sympy.stats.frv import SingleFiniteDistribution + expr, condition = self.expr, self.condition + + if isinstance(expr, SingleFiniteDistribution): + return expr.dict + if condition is not None: + # Recompute on new conditional expr + expr = given(expr, condition, **kwargs) + if not random_symbols(expr): + return Lambda(x, DiracDelta(x - expr)) + if isinstance(expr, RandomSymbol): + if isinstance(expr.pspace, (SinglePSpace, JointPSpace, MatrixPSpace)) and \ + hasattr(expr.pspace, 'distribution'): + return expr.pspace.distribution + elif isinstance(expr.pspace, RandomMatrixPSpace): + return expr.pspace.model + if isinstance(pspace(expr), CompoundPSpace): + kwargs['compound_evaluate'] = evaluate + result = pspace(expr).compute_density(expr, **kwargs) + + if evaluate and hasattr(result, 'doit'): + return result.doit() + else: + return result + + +def density(expr, condition=None, evaluate=True, numsamples=None, **kwargs): + """ + Probability density of a random expression, optionally given a second + condition. + + Explanation + =========== + + This density will take on different forms for different types of + probability spaces. Discrete variables produce Dicts. Continuous + variables produce Lambdas. + + Parameters + ========== + + expr : Expr containing RandomSymbols + The expression of which you want to compute the density value + condition : Relational containing RandomSymbols + A conditional expression. density(X > 1, X > 0) is density of X > 1 + given X > 0 + numsamples : int + Enables sampling and approximates the density with this many samples + + Examples + ======== + + >>> from sympy.stats import density, Die, Normal + >>> from sympy import Symbol + + >>> x = Symbol('x') + >>> D = Die('D', 6) + >>> X = Normal(x, 0, 1) + + >>> density(D).dict + {1: 1/6, 2: 1/6, 3: 1/6, 4: 1/6, 5: 1/6, 6: 1/6} + >>> density(2*D).dict + {2: 1/6, 4: 1/6, 6: 1/6, 8: 1/6, 10: 1/6, 12: 1/6} + >>> density(X)(x) + sqrt(2)*exp(-x**2/2)/(2*sqrt(pi)) + """ + + if numsamples: + return sampling_density(expr, condition, numsamples=numsamples, + **kwargs) + + return Density(expr, condition).doit(evaluate=evaluate, **kwargs) + + +def cdf(expr, condition=None, evaluate=True, **kwargs): + """ + Cumulative Distribution Function of a random expression. + + optionally given a second condition. + + Explanation + =========== + + This density will take on different forms for different types of + probability spaces. + Discrete variables produce Dicts. + Continuous variables produce Lambdas. + + Examples + ======== + + >>> from sympy.stats import density, Die, Normal, cdf + + >>> D = Die('D', 6) + >>> X = Normal('X', 0, 1) + + >>> density(D).dict + {1: 1/6, 2: 1/6, 3: 1/6, 4: 1/6, 5: 1/6, 6: 1/6} + >>> cdf(D) + {1: 1/6, 2: 1/3, 3: 1/2, 4: 2/3, 5: 5/6, 6: 1} + >>> cdf(3*D, D > 2) + {9: 1/4, 12: 1/2, 15: 3/4, 18: 1} + + >>> cdf(X) + Lambda(_z, erf(sqrt(2)*_z/2)/2 + 1/2) + """ + if condition is not None: # If there is a condition + # Recompute on new conditional expr + return cdf(given(expr, condition, **kwargs), **kwargs) + + # Otherwise pass work off to the ProbabilitySpace + result = pspace(expr).compute_cdf(expr, **kwargs) + + if evaluate and hasattr(result, 'doit'): + return result.doit() + else: + return result + + +def characteristic_function(expr, condition=None, evaluate=True, **kwargs): + """ + Characteristic function of a random expression, optionally given a second condition. + + Returns a Lambda. + + Examples + ======== + + >>> from sympy.stats import Normal, DiscreteUniform, Poisson, characteristic_function + + >>> X = Normal('X', 0, 1) + >>> characteristic_function(X) + Lambda(_t, exp(-_t**2/2)) + + >>> Y = DiscreteUniform('Y', [1, 2, 7]) + >>> characteristic_function(Y) + Lambda(_t, exp(7*_t*I)/3 + exp(2*_t*I)/3 + exp(_t*I)/3) + + >>> Z = Poisson('Z', 2) + >>> characteristic_function(Z) + Lambda(_t, exp(2*exp(_t*I) - 2)) + """ + if condition is not None: + return characteristic_function(given(expr, condition, **kwargs), **kwargs) + + result = pspace(expr).compute_characteristic_function(expr, **kwargs) + + if evaluate and hasattr(result, 'doit'): + return result.doit() + else: + return result + +def moment_generating_function(expr, condition=None, evaluate=True, **kwargs): + if condition is not None: + return moment_generating_function(given(expr, condition, **kwargs), **kwargs) + + result = pspace(expr).compute_moment_generating_function(expr, **kwargs) + + if evaluate and hasattr(result, 'doit'): + return result.doit() + else: + return result + +def where(condition, given_condition=None, **kwargs): + """ + Returns the domain where a condition is True. + + Examples + ======== + + >>> from sympy.stats import where, Die, Normal + >>> from sympy import And + + >>> D1, D2 = Die('a', 6), Die('b', 6) + >>> a, b = D1.symbol, D2.symbol + >>> X = Normal('x', 0, 1) + + >>> where(X**2<1) + Domain: (-1 < x) & (x < 1) + + >>> where(X**2<1).set + Interval.open(-1, 1) + + >>> where(And(D1<=D2, D2<3)) + Domain: (Eq(a, 1) & Eq(b, 1)) | (Eq(a, 1) & Eq(b, 2)) | (Eq(a, 2) & Eq(b, 2)) + """ + if given_condition is not None: # If there is a condition + # Recompute on new conditional expr + return where(given(condition, given_condition, **kwargs), **kwargs) + + # Otherwise pass work off to the ProbabilitySpace + return pspace(condition).where(condition, **kwargs) + + +@doctest_depends_on(modules=('scipy',)) +def sample(expr, condition=None, size=(), library='scipy', + numsamples=1, seed=None, **kwargs): + """ + A realization of the random expression. + + Parameters + ========== + + expr : Expression of random variables + Expression from which sample is extracted + condition : Expr containing RandomSymbols + A conditional expression + size : int, tuple + Represents size of each sample in numsamples + library : str + - 'scipy' : Sample using scipy + - 'numpy' : Sample using numpy + - 'pymc' : Sample using PyMC + + Choose any of the available options to sample from as string, + by default is 'scipy' + numsamples : int + Number of samples, each with size as ``size``. + + .. deprecated:: 1.9 + + The ``numsamples`` parameter is deprecated and is only provided for + compatibility with v1.8. Use a list comprehension or an additional + dimension in ``size`` instead. See + :ref:`deprecated-sympy-stats-numsamples` for details. + + seed : + An object to be used as seed by the given external library for sampling `expr`. + Following is the list of possible types of object for the supported libraries, + + - 'scipy': int, numpy.random.RandomState, numpy.random.Generator + - 'numpy': int, numpy.random.RandomState, numpy.random.Generator + - 'pymc': int + + Optional, by default None, in which case seed settings + related to the given library will be used. + No modifications to environment's global seed settings + are done by this argument. + + Returns + ======= + + sample: float/list/numpy.ndarray + one sample or a collection of samples of the random expression. + + - sample(X) returns float/numpy.float64/numpy.int64 object. + - sample(X, size=int/tuple) returns numpy.ndarray object. + + Examples + ======== + + >>> from sympy.stats import Die, sample, Normal, Geometric + >>> X, Y, Z = Die('X', 6), Die('Y', 6), Die('Z', 6) # Finite Random Variable + >>> die_roll = sample(X + Y + Z) + >>> die_roll # doctest: +SKIP + 3 + >>> N = Normal('N', 3, 4) # Continuous Random Variable + >>> samp = sample(N) + >>> samp in N.pspace.domain.set + True + >>> samp = sample(N, N>0) + >>> samp > 0 + True + >>> samp_list = sample(N, size=4) + >>> [sam in N.pspace.domain.set for sam in samp_list] + [True, True, True, True] + >>> sample(N, size = (2,3)) # doctest: +SKIP + array([[5.42519758, 6.40207856, 4.94991743], + [1.85819627, 6.83403519, 1.9412172 ]]) + >>> G = Geometric('G', 0.5) # Discrete Random Variable + >>> samp_list = sample(G, size=3) + >>> samp_list # doctest: +SKIP + [1, 3, 2] + >>> [sam in G.pspace.domain.set for sam in samp_list] + [True, True, True] + >>> MN = Normal("MN", [3, 4], [[2, 1], [1, 2]]) # Joint Random Variable + >>> samp_list = sample(MN, size=4) + >>> samp_list # doctest: +SKIP + [array([2.85768055, 3.38954165]), + array([4.11163337, 4.3176591 ]), + array([0.79115232, 1.63232916]), + array([4.01747268, 3.96716083])] + >>> [tuple(sam) in MN.pspace.domain.set for sam in samp_list] + [True, True, True, True] + + .. versionchanged:: 1.7.0 + sample used to return an iterator containing the samples instead of value. + + .. versionchanged:: 1.9.0 + sample returns values or array of values instead of an iterator and numsamples is deprecated. + + """ + + iterator = sample_iter(expr, condition, size=size, library=library, + numsamples=numsamples, seed=seed) + + if numsamples != 1: + sympy_deprecation_warning( + f""" + The numsamples parameter to sympy.stats.sample() is deprecated. + Either use a list comprehension, like + + [sample(...) for i in range({numsamples})] + + or add a dimension to size, like + + sample(..., size={(numsamples,) + size}) + """, + deprecated_since_version="1.9", + active_deprecations_target="deprecated-sympy-stats-numsamples", + ) + return [next(iterator) for i in range(numsamples)] + + return next(iterator) + + +def quantile(expr, evaluate=True, **kwargs): + r""" + Return the :math:`p^{th}` order quantile of a probability distribution. + + Explanation + =========== + + Quantile is defined as the value at which the probability of the random + variable is less than or equal to the given probability. + + .. math:: + Q(p) = \inf\{x \in (-\infty, \infty) : p \le F(x)\} + + Examples + ======== + + >>> from sympy.stats import quantile, Die, Exponential + >>> from sympy import Symbol, pprint + >>> p = Symbol("p") + + >>> l = Symbol("lambda", positive=True) + >>> X = Exponential("x", l) + >>> quantile(X)(p) + -log(1 - p)/lambda + + >>> D = Die("d", 6) + >>> pprint(quantile(D)(p), use_unicode=False) + /nan for Or(p > 1, p < 0) + | + | 1 for p <= 1/6 + | + | 2 for p <= 1/3 + | + < 3 for p <= 1/2 + | + | 4 for p <= 2/3 + | + | 5 for p <= 5/6 + | + \ 6 for p <= 1 + + """ + result = pspace(expr).compute_quantile(expr, **kwargs) + + if evaluate and hasattr(result, 'doit'): + return result.doit() + else: + return result + +def sample_iter(expr, condition=None, size=(), library='scipy', + numsamples=S.Infinity, seed=None, **kwargs): + + """ + Returns an iterator of realizations from the expression given a condition. + + Parameters + ========== + + expr: Expr + Random expression to be realized + condition: Expr, optional + A conditional expression + size : int, tuple + Represents size of each sample in numsamples + numsamples: integer, optional + Length of the iterator (defaults to infinity) + seed : + An object to be used as seed by the given external library for sampling `expr`. + Following is the list of possible types of object for the supported libraries, + + - 'scipy': int, numpy.random.RandomState, numpy.random.Generator + - 'numpy': int, numpy.random.RandomState, numpy.random.Generator + - 'pymc': int + + Optional, by default None, in which case seed settings + related to the given library will be used. + No modifications to environment's global seed settings + are done by this argument. + + Examples + ======== + + >>> from sympy.stats import Normal, sample_iter + >>> X = Normal('X', 0, 1) + >>> expr = X*X + 3 + >>> iterator = sample_iter(expr, numsamples=3) # doctest: +SKIP + >>> list(iterator) # doctest: +SKIP + [12, 4, 7] + + Returns + ======= + + sample_iter: iterator object + iterator object containing the sample/samples of given expr + + See Also + ======== + + sample + sampling_P + sampling_E + + """ + from sympy.stats.joint_rv import JointRandomSymbol + if not import_module(library): + raise ValueError("Failed to import %s" % library) + + if condition is not None: + ps = pspace(Tuple(expr, condition)) + else: + ps = pspace(expr) + + rvs = list(ps.values) + if isinstance(expr, JointRandomSymbol): + expr = expr.subs({expr: RandomSymbol(expr.symbol, expr.pspace)}) + else: + sub = {} + for arg in expr.args: + if isinstance(arg, JointRandomSymbol): + sub[arg] = RandomSymbol(arg.symbol, arg.pspace) + expr = expr.subs(sub) + + def fn_subs(*args): + return expr.subs(dict(zip(rvs, args))) + + def given_fn_subs(*args): + if condition is not None: + return condition.subs(dict(zip(rvs, args))) + return False + + if library in ('pymc', 'pymc3'): + # Currently unable to lambdify in pymc + # TODO : Remove when lambdify accepts 'pymc' as module + fn = lambdify(rvs, expr, **kwargs) + else: + fn = lambdify(rvs, expr, modules=library, **kwargs) + + + if condition is not None: + given_fn = lambdify(rvs, condition, **kwargs) + + def return_generator_infinite(): + count = 0 + _size = (1,)+((size,) if isinstance(size, int) else size) + while count < numsamples: + d = ps.sample(size=_size, library=library, seed=seed) # a dictionary that maps RVs to values + args = [d[rv][0] for rv in rvs] + + if condition is not None: # Check that these values satisfy the condition + # TODO: Replace the try-except block with only given_fn(*args) + # once lambdify works with unevaluated SymPy objects. + try: + gd = given_fn(*args) + except (NameError, TypeError): + gd = given_fn_subs(*args) + if gd != True and gd != False: + raise ValueError( + "Conditions must not contain free symbols") + if not gd: # If the values don't satisfy then try again + continue + + yield fn(*args) + count += 1 + + def return_generator_finite(): + faulty = True + while faulty: + d = ps.sample(size=(numsamples,) + ((size,) if isinstance(size, int) else size), + library=library, seed=seed) # a dictionary that maps RVs to values + + faulty = False + count = 0 + while count < numsamples and not faulty: + args = [d[rv][count] for rv in rvs] + if condition is not None: # Check that these values satisfy the condition + # TODO: Replace the try-except block with only given_fn(*args) + # once lambdify works with unevaluated SymPy objects. + try: + gd = given_fn(*args) + except (NameError, TypeError): + gd = given_fn_subs(*args) + if gd != True and gd != False: + raise ValueError( + "Conditions must not contain free symbols") + if not gd: # If the values don't satisfy then try again + faulty = True + + count += 1 + + count = 0 + while count < numsamples: + args = [d[rv][count] for rv in rvs] + # TODO: Replace the try-except block with only fn(*args) + # once lambdify works with unevaluated SymPy objects. + try: + yield fn(*args) + except (NameError, TypeError): + yield fn_subs(*args) + count += 1 + + if numsamples is S.Infinity: + return return_generator_infinite() + + return return_generator_finite() + +def sample_iter_lambdify(expr, condition=None, size=(), + numsamples=S.Infinity, seed=None, **kwargs): + + return sample_iter(expr, condition=condition, size=size, + numsamples=numsamples, seed=seed, **kwargs) + +def sample_iter_subs(expr, condition=None, size=(), + numsamples=S.Infinity, seed=None, **kwargs): + + return sample_iter(expr, condition=condition, size=size, + numsamples=numsamples, seed=seed, **kwargs) + + +def sampling_P(condition, given_condition=None, library='scipy', numsamples=1, + evalf=True, seed=None, **kwargs): + """ + Sampling version of P. + + See Also + ======== + + P + sampling_E + sampling_density + + """ + + count_true = 0 + count_false = 0 + samples = sample_iter(condition, given_condition, library=library, + numsamples=numsamples, seed=seed, **kwargs) + + for sample in samples: + if sample: + count_true += 1 + else: + count_false += 1 + + result = S(count_true) / numsamples + if evalf: + return result.evalf() + else: + return result + + +def sampling_E(expr, given_condition=None, library='scipy', numsamples=1, + evalf=True, seed=None, **kwargs): + """ + Sampling version of E. + + See Also + ======== + + P + sampling_P + sampling_density + """ + samples = list(sample_iter(expr, given_condition, library=library, + numsamples=numsamples, seed=seed, **kwargs)) + result = Add(*samples) / numsamples + + if evalf: + return result.evalf() + else: + return result + +def sampling_density(expr, given_condition=None, library='scipy', + numsamples=1, seed=None, **kwargs): + """ + Sampling version of density. + + See Also + ======== + density + sampling_P + sampling_E + """ + + results = {} + for result in sample_iter(expr, given_condition, library=library, + numsamples=numsamples, seed=seed, **kwargs): + results[result] = results.get(result, 0) + 1 + + return results + + +def dependent(a, b): + """ + Dependence of two random expressions. + + Two expressions are independent if knowledge of one does not change + computations on the other. + + Examples + ======== + + >>> from sympy.stats import Normal, dependent, given + >>> from sympy import Tuple, Eq + + >>> X, Y = Normal('X', 0, 1), Normal('Y', 0, 1) + >>> dependent(X, Y) + False + >>> dependent(2*X + Y, -Y) + True + >>> X, Y = given(Tuple(X, Y), Eq(X + Y, 3)) + >>> dependent(X, Y) + True + + See Also + ======== + + independent + """ + if pspace_independent(a, b): + return False + + z = Symbol('z', real=True) + # Dependent if density is unchanged when one is given information about + # the other + return (density(a, Eq(b, z)) != density(a) or + density(b, Eq(a, z)) != density(b)) + + +def independent(a, b): + """ + Independence of two random expressions. + + Two expressions are independent if knowledge of one does not change + computations on the other. + + Examples + ======== + + >>> from sympy.stats import Normal, independent, given + >>> from sympy import Tuple, Eq + + >>> X, Y = Normal('X', 0, 1), Normal('Y', 0, 1) + >>> independent(X, Y) + True + >>> independent(2*X + Y, -Y) + False + >>> X, Y = given(Tuple(X, Y), Eq(X + Y, 3)) + >>> independent(X, Y) + False + + See Also + ======== + + dependent + """ + return not dependent(a, b) + + +def pspace_independent(a, b): + """ + Tests for independence between a and b by checking if their PSpaces have + overlapping symbols. This is a sufficient but not necessary condition for + independence and is intended to be used internally. + + Notes + ===== + + pspace_independent(a, b) implies independent(a, b) + independent(a, b) does not imply pspace_independent(a, b) + """ + a_symbols = set(pspace(b).symbols) + b_symbols = set(pspace(a).symbols) + + if len(set(random_symbols(a)).intersection(random_symbols(b))) != 0: + return False + + if len(a_symbols.intersection(b_symbols)) == 0: + return True + return None + + +def rv_subs(expr, symbols=None): + """ + Given a random expression replace all random variables with their symbols. + + If symbols keyword is given restrict the swap to only the symbols listed. + """ + if symbols is None: + symbols = random_symbols(expr) + if not symbols: + return expr + swapdict = {rv: rv.symbol for rv in symbols} + return expr.subs(swapdict) + + +class NamedArgsMixin: + _argnames: tuple[str, ...] = () + + def __getattr__(self, attr): + try: + return self.args[self._argnames.index(attr)] + except ValueError: + raise AttributeError("'%s' object has no attribute '%s'" % ( + type(self).__name__, attr)) + + +class Distribution(Basic): + + def sample(self, size=(), library='scipy', seed=None): + """ A random realization from the distribution """ + + module = import_module(library) + if library in {'scipy', 'numpy', 'pymc3', 'pymc'} and module is None: + raise ValueError("Failed to import %s" % library) + + if library == 'scipy': + # scipy does not require map as it can handle using custom distributions. + # However, we will still use a map where we can. + + # TODO: do this for drv.py and frv.py if necessary. + # TODO: add more distributions here if there are more + # See links below referring to sections beginning with "A common parametrization..." + # I will remove all these comments if everything is ok. + + from sympy.stats.sampling.sample_scipy import do_sample_scipy + import numpy + if seed is None or isinstance(seed, int): + rand_state = numpy.random.default_rng(seed=seed) + else: + rand_state = seed + samps = do_sample_scipy(self, size, rand_state) + + elif library == 'numpy': + from sympy.stats.sampling.sample_numpy import do_sample_numpy + import numpy + if seed is None or isinstance(seed, int): + rand_state = numpy.random.default_rng(seed=seed) + else: + rand_state = seed + _size = None if size == () else size + samps = do_sample_numpy(self, _size, rand_state) + elif library in ('pymc', 'pymc3'): + from sympy.stats.sampling.sample_pymc import do_sample_pymc + import logging + logging.getLogger("pymc").setLevel(logging.ERROR) + try: + import pymc + except ImportError: + import pymc3 as pymc + + with pymc.Model(): + if do_sample_pymc(self) is not None: + samps = pymc.sample(draws=prod(size), chains=1, compute_convergence_checks=False, + progressbar=False, random_seed=seed, return_inferencedata=False)[:]['X'] + samps = samps.reshape(size) + else: + samps = None + + else: + raise NotImplementedError("Sampling from %s is not supported yet." + % str(library)) + + if samps is not None: + return samps + raise NotImplementedError( + "Sampling for %s is not currently implemented from %s" + % (self, library)) + + +def _value_check(condition, message): + """ + Raise a ValueError with message if condition is False, else + return True if all conditions were True, else False. + + Examples + ======== + + >>> from sympy.stats.rv import _value_check + >>> from sympy.abc import a, b, c + >>> from sympy import And, Dummy + + >>> _value_check(2 < 3, '') + True + + Here, the condition is not False, but it does not evaluate to True + so False is returned (but no error is raised). So checking if the + return value is True or False will tell you if all conditions were + evaluated. + + >>> _value_check(a < b, '') + False + + In this case the condition is False so an error is raised: + + >>> r = Dummy(real=True) + >>> _value_check(r < r - 1, 'condition is not true') + Traceback (most recent call last): + ... + ValueError: condition is not true + + If no condition of many conditions must be False, they can be + checked by passing them as an iterable: + + >>> _value_check((a < 0, b < 0, c < 0), '') + False + + The iterable can be a generator, too: + + >>> _value_check((i < 0 for i in (a, b, c)), '') + False + + The following are equivalent to the above but do not pass + an iterable: + + >>> all(_value_check(i < 0, '') for i in (a, b, c)) + False + >>> _value_check(And(a < 0, b < 0, c < 0), '') + False + """ + if not iterable(condition): + condition = [condition] + truth = fuzzy_and(condition) + if truth == False: + raise ValueError(message) + return truth == True + +def _symbol_converter(sym): + """ + Casts the parameter to Symbol if it is 'str' + otherwise no operation is performed on it. + + Parameters + ========== + + sym + The parameter to be converted. + + Returns + ======= + + Symbol + the parameter converted to Symbol. + + Raises + ====== + + TypeError + If the parameter is not an instance of both str and + Symbol. + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.stats.rv import _symbol_converter + >>> s = _symbol_converter('s') + >>> isinstance(s, Symbol) + True + >>> _symbol_converter(1) + Traceback (most recent call last): + ... + TypeError: 1 is neither a Symbol nor a string + >>> r = Symbol('r') + >>> isinstance(r, Symbol) + True + """ + if isinstance(sym, str): + sym = Symbol(sym) + if not isinstance(sym, Symbol): + raise TypeError("%s is neither a Symbol nor a string"%(sym)) + return sym + +def sample_stochastic_process(process): + """ + This function is used to sample from stochastic process. + + Parameters + ========== + + process: StochasticProcess + Process used to extract the samples. It must be an instance of + StochasticProcess + + Examples + ======== + + >>> from sympy.stats import sample_stochastic_process, DiscreteMarkovChain + >>> from sympy import Matrix + >>> T = Matrix([[0.5, 0.2, 0.3],[0.2, 0.5, 0.3],[0.2, 0.3, 0.5]]) + >>> Y = DiscreteMarkovChain("Y", [0, 1, 2], T) + >>> next(sample_stochastic_process(Y)) in Y.state_space + True + >>> next(sample_stochastic_process(Y)) # doctest: +SKIP + 0 + >>> next(sample_stochastic_process(Y)) # doctest: +SKIP + 2 + + Returns + ======= + + sample: iterator object + iterator object containing the sample of given process + + """ + from sympy.stats.stochastic_process_types import StochasticProcess + if not isinstance(process, StochasticProcess): + raise ValueError("Process must be an instance of Stochastic Process") + return process.sample() diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/rv_interface.py b/.venv/lib/python3.13/site-packages/sympy/stats/rv_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..16d65b83634cdb04ef7e5046175848cdf380434b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/rv_interface.py @@ -0,0 +1,519 @@ +from sympy.sets import FiniteSet +from sympy.core.numbers import Rational +from sympy.core.relational import Eq +from sympy.core.symbol import Dummy +from sympy.functions.combinatorial.factorials import FallingFactorial +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import piecewise_fold +from sympy.integrals.integrals import Integral +from sympy.solvers.solveset import solveset +from .rv import (probability, expectation, density, where, given, pspace, cdf, PSpace, + characteristic_function, sample, sample_iter, random_symbols, independent, dependent, + sampling_density, moment_generating_function, quantile, is_random, + sample_stochastic_process) + + +__all__ = ['P', 'E', 'H', 'density', 'where', 'given', 'sample', 'cdf', + 'characteristic_function', 'pspace', 'sample_iter', 'variance', 'std', + 'skewness', 'kurtosis', 'covariance', 'dependent', 'entropy', 'median', + 'independent', 'random_symbols', 'correlation', 'factorial_moment', + 'moment', 'cmoment', 'sampling_density', 'moment_generating_function', + 'smoment', 'quantile', 'sample_stochastic_process'] + + + +def moment(X, n, c=0, condition=None, *, evaluate=True, **kwargs): + """ + Return the nth moment of a random expression about c. + + .. math:: + moment(X, c, n) = E((X-c)^{n}) + + Default value of c is 0. + + Examples + ======== + + >>> from sympy.stats import Die, moment, E + >>> X = Die('X', 6) + >>> moment(X, 1, 6) + -5/2 + >>> moment(X, 2) + 91/6 + >>> moment(X, 1) == E(X) + True + """ + from sympy.stats.symbolic_probability import Moment + if evaluate: + return Moment(X, n, c, condition).doit() + return Moment(X, n, c, condition).rewrite(Integral) + + +def variance(X, condition=None, **kwargs): + """ + Variance of a random expression. + + .. math:: + variance(X) = E((X-E(X))^{2}) + + Examples + ======== + + >>> from sympy.stats import Die, Bernoulli, variance + >>> from sympy import simplify, Symbol + + >>> X = Die('X', 6) + >>> p = Symbol('p') + >>> B = Bernoulli('B', p, 1, 0) + + >>> variance(2*X) + 35/3 + + >>> simplify(variance(B)) + p*(1 - p) + """ + if is_random(X) and pspace(X) == PSpace(): + from sympy.stats.symbolic_probability import Variance + return Variance(X, condition) + + return cmoment(X, 2, condition, **kwargs) + + +def standard_deviation(X, condition=None, **kwargs): + r""" + Standard Deviation of a random expression + + .. math:: + std(X) = \sqrt(E((X-E(X))^{2})) + + Examples + ======== + + >>> from sympy.stats import Bernoulli, std + >>> from sympy import Symbol, simplify + + >>> p = Symbol('p') + >>> B = Bernoulli('B', p, 1, 0) + + >>> simplify(std(B)) + sqrt(p*(1 - p)) + """ + return sqrt(variance(X, condition, **kwargs)) +std = standard_deviation + +def entropy(expr, condition=None, **kwargs): + """ + Calculates entropy of a probability distribution. + + Parameters + ========== + + expression : the random expression whose entropy is to be calculated + condition : optional, to specify conditions on random expression + b: base of the logarithm, optional + By default, it is taken as Euler's number + + Returns + ======= + + result : Entropy of the expression, a constant + + Examples + ======== + + >>> from sympy.stats import Normal, Die, entropy + >>> X = Normal('X', 0, 1) + >>> entropy(X) + log(2)/2 + 1/2 + log(pi)/2 + + >>> D = Die('D', 4) + >>> entropy(D) + log(4) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Entropy_%28information_theory%29 + .. [2] https://www.crmarsh.com/static/pdf/Charles_Marsh_Continuous_Entropy.pdf + .. [3] https://kconrad.math.uconn.edu/blurbs/analysis/entropypost.pdf + """ + pdf = density(expr, condition, **kwargs) + base = kwargs.get('b', exp(1)) + if isinstance(pdf, dict): + return sum(-prob*log(prob, base) for prob in pdf.values()) + return expectation(-log(pdf(expr), base)) + +def covariance(X, Y, condition=None, **kwargs): + """ + Covariance of two random expressions. + + Explanation + =========== + + The expectation that the two variables will rise and fall together + + .. math:: + covariance(X,Y) = E((X-E(X)) (Y-E(Y))) + + Examples + ======== + + >>> from sympy.stats import Exponential, covariance + >>> from sympy import Symbol + + >>> rate = Symbol('lambda', positive=True, real=True) + >>> X = Exponential('X', rate) + >>> Y = Exponential('Y', rate) + + >>> covariance(X, X) + lambda**(-2) + >>> covariance(X, Y) + 0 + >>> covariance(X, Y + rate*X) + 1/lambda + """ + if (is_random(X) and pspace(X) == PSpace()) or (is_random(Y) and pspace(Y) == PSpace()): + from sympy.stats.symbolic_probability import Covariance + return Covariance(X, Y, condition) + + return expectation( + (X - expectation(X, condition, **kwargs)) * + (Y - expectation(Y, condition, **kwargs)), + condition, **kwargs) + + +def correlation(X, Y, condition=None, **kwargs): + r""" + Correlation of two random expressions, also known as correlation + coefficient or Pearson's correlation. + + Explanation + =========== + + The normalized expectation that the two variables will rise + and fall together + + .. math:: + correlation(X,Y) = E((X-E(X))(Y-E(Y)) / (\sigma_x \sigma_y)) + + Examples + ======== + + >>> from sympy.stats import Exponential, correlation + >>> from sympy import Symbol + + >>> rate = Symbol('lambda', positive=True, real=True) + >>> X = Exponential('X', rate) + >>> Y = Exponential('Y', rate) + + >>> correlation(X, X) + 1 + >>> correlation(X, Y) + 0 + >>> correlation(X, Y + rate*X) + 1/sqrt(1 + lambda**(-2)) + """ + return covariance(X, Y, condition, **kwargs)/(std(X, condition, **kwargs) + * std(Y, condition, **kwargs)) + + +def cmoment(X, n, condition=None, *, evaluate=True, **kwargs): + """ + Return the nth central moment of a random expression about its mean. + + .. math:: + cmoment(X, n) = E((X - E(X))^{n}) + + Examples + ======== + + >>> from sympy.stats import Die, cmoment, variance + >>> X = Die('X', 6) + >>> cmoment(X, 3) + 0 + >>> cmoment(X, 2) + 35/12 + >>> cmoment(X, 2) == variance(X) + True + """ + from sympy.stats.symbolic_probability import CentralMoment + if evaluate: + return CentralMoment(X, n, condition).doit() + return CentralMoment(X, n, condition).rewrite(Integral) + + +def smoment(X, n, condition=None, **kwargs): + r""" + Return the nth Standardized moment of a random expression. + + .. math:: + smoment(X, n) = E(((X - \mu)/\sigma_X)^{n}) + + Examples + ======== + + >>> from sympy.stats import skewness, Exponential, smoment + >>> from sympy import Symbol + >>> rate = Symbol('lambda', positive=True, real=True) + >>> Y = Exponential('Y', rate) + >>> smoment(Y, 4) + 9 + >>> smoment(Y, 4) == smoment(3*Y, 4) + True + >>> smoment(Y, 3) == skewness(Y) + True + """ + sigma = std(X, condition, **kwargs) + return (1/sigma)**n*cmoment(X, n, condition, **kwargs) + +def skewness(X, condition=None, **kwargs): + r""" + Measure of the asymmetry of the probability distribution. + + Explanation + =========== + + Positive skew indicates that most of the values lie to the right of + the mean. + + .. math:: + skewness(X) = E(((X - E(X))/\sigma_X)^{3}) + + Parameters + ========== + + condition : Expr containing RandomSymbols + A conditional expression. skewness(X, X>0) is skewness of X given X > 0 + + Examples + ======== + + >>> from sympy.stats import skewness, Exponential, Normal + >>> from sympy import Symbol + >>> X = Normal('X', 0, 1) + >>> skewness(X) + 0 + >>> skewness(X, X > 0) # find skewness given X > 0 + (-sqrt(2)/sqrt(pi) + 4*sqrt(2)/pi**(3/2))/(1 - 2/pi)**(3/2) + + >>> rate = Symbol('lambda', positive=True, real=True) + >>> Y = Exponential('Y', rate) + >>> skewness(Y) + 2 + """ + return smoment(X, 3, condition=condition, **kwargs) + +def kurtosis(X, condition=None, **kwargs): + r""" + Characterizes the tails/outliers of a probability distribution. + + Explanation + =========== + + Kurtosis of any univariate normal distribution is 3. Kurtosis less than + 3 means that the distribution produces fewer and less extreme outliers + than the normal distribution. + + .. math:: + kurtosis(X) = E(((X - E(X))/\sigma_X)^{4}) + + Parameters + ========== + + condition : Expr containing RandomSymbols + A conditional expression. kurtosis(X, X>0) is kurtosis of X given X > 0 + + Examples + ======== + + >>> from sympy.stats import kurtosis, Exponential, Normal + >>> from sympy import Symbol + >>> X = Normal('X', 0, 1) + >>> kurtosis(X) + 3 + >>> kurtosis(X, X > 0) # find kurtosis given X > 0 + (-4/pi - 12/pi**2 + 3)/(1 - 2/pi)**2 + + >>> rate = Symbol('lamda', positive=True, real=True) + >>> Y = Exponential('Y', rate) + >>> kurtosis(Y) + 9 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Kurtosis + .. [2] https://mathworld.wolfram.com/Kurtosis.html + """ + return smoment(X, 4, condition=condition, **kwargs) + + +def factorial_moment(X, n, condition=None, **kwargs): + """ + The factorial moment is a mathematical quantity defined as the expectation + or average of the falling factorial of a random variable. + + .. math:: + factorial-moment(X, n) = E(X(X - 1)(X - 2)...(X - n + 1)) + + Parameters + ========== + + n: A natural number, n-th factorial moment. + + condition : Expr containing RandomSymbols + A conditional expression. + + Examples + ======== + + >>> from sympy.stats import factorial_moment, Poisson, Binomial + >>> from sympy import Symbol, S + >>> lamda = Symbol('lamda') + >>> X = Poisson('X', lamda) + >>> factorial_moment(X, 2) + lamda**2 + >>> Y = Binomial('Y', 2, S.Half) + >>> factorial_moment(Y, 2) + 1/2 + >>> factorial_moment(Y, 2, Y > 1) # find factorial moment for Y > 1 + 2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Factorial_moment + .. [2] https://mathworld.wolfram.com/FactorialMoment.html + """ + return expectation(FallingFactorial(X, n), condition=condition, **kwargs) + +def median(X, evaluate=True, **kwargs): + r""" + Calculates the median of the probability distribution. + + Explanation + =========== + + Mathematically, median of Probability distribution is defined as all those + values of `m` for which the following condition is satisfied + + .. math:: + P(X\leq m) \geq \frac{1}{2} \text{ and} \text{ } P(X\geq m)\geq \frac{1}{2} + + Parameters + ========== + + X: The random expression whose median is to be calculated. + + Returns + ======= + + The FiniteSet or an Interval which contains the median of the + random expression. + + Examples + ======== + + >>> from sympy.stats import Normal, Die, median + >>> N = Normal('N', 3, 1) + >>> median(N) + {3} + >>> D = Die('D') + >>> median(D) + {3, 4} + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Median#Probability_distributions + + """ + if not is_random(X): + return X + + from sympy.stats.crv import ContinuousPSpace + from sympy.stats.drv import DiscretePSpace + from sympy.stats.frv import FinitePSpace + + if isinstance(pspace(X), FinitePSpace): + cdf = pspace(X).compute_cdf(X) + result = [] + for key, value in cdf.items(): + if value>= Rational(1, 2) and (1 - value) + \ + pspace(X).probability(Eq(X, key)) >= Rational(1, 2): + result.append(key) + return FiniteSet(*result) + if isinstance(pspace(X), (ContinuousPSpace, DiscretePSpace)): + cdf = pspace(X).compute_cdf(X) + x = Dummy('x') + result = solveset(piecewise_fold(cdf(x) - Rational(1, 2)), x, pspace(X).set) + return result + raise NotImplementedError("The median of %s is not implemented."%str(pspace(X))) + + +def coskewness(X, Y, Z, condition=None, **kwargs): + r""" + Calculates the co-skewness of three random variables. + + Explanation + =========== + + Mathematically Coskewness is defined as + + .. math:: + coskewness(X,Y,Z)=\frac{E[(X-E[X]) * (Y-E[Y]) * (Z-E[Z])]} {\sigma_{X}\sigma_{Y}\sigma_{Z}} + + Parameters + ========== + + X : RandomSymbol + Random Variable used to calculate coskewness + Y : RandomSymbol + Random Variable used to calculate coskewness + Z : RandomSymbol + Random Variable used to calculate coskewness + condition : Expr containing RandomSymbols + A conditional expression + + Examples + ======== + + >>> from sympy.stats import coskewness, Exponential, skewness + >>> from sympy import symbols + >>> p = symbols('p', positive=True) + >>> X = Exponential('X', p) + >>> Y = Exponential('Y', 2*p) + >>> coskewness(X, Y, Y) + 0 + >>> coskewness(X, Y + X, Y + 2*X) + 16*sqrt(85)/85 + >>> coskewness(X + 2*Y, Y + X, Y + 2*X, X > 3) + 9*sqrt(170)/85 + >>> coskewness(Y, Y, Y) == skewness(Y) + True + >>> coskewness(X, Y + p*X, Y + 2*p*X) + 4/(sqrt(1 + 1/(4*p**2))*sqrt(4 + 1/(4*p**2))) + + Returns + ======= + + coskewness : The coskewness of the three random variables + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Coskewness + + """ + num = expectation((X - expectation(X, condition, **kwargs)) \ + * (Y - expectation(Y, condition, **kwargs)) \ + * (Z - expectation(Z, condition, **kwargs)), condition, **kwargs) + den = std(X, condition, **kwargs) * std(Y, condition, **kwargs) \ + * std(Z, condition, **kwargs) + return num/den + + +P = probability +E = expectation +H = entropy diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/sampling/__init__.py b/.venv/lib/python3.13/site-packages/sympy/stats/sampling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/sampling/sample_numpy.py b/.venv/lib/python3.13/site-packages/sympy/stats/sampling/sample_numpy.py new file mode 100644 index 0000000000000000000000000000000000000000..d65417945449ed8b62d2547215d8908f84820a9b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/sampling/sample_numpy.py @@ -0,0 +1,105 @@ +from functools import singledispatch + +from sympy.external import import_module +from sympy.stats.crv_types import BetaDistribution, ChiSquaredDistribution, ExponentialDistribution, GammaDistribution, \ + LogNormalDistribution, NormalDistribution, ParetoDistribution, UniformDistribution, FDistributionDistribution, GumbelDistribution, LaplaceDistribution, \ + LogisticDistribution, RayleighDistribution, TriangularDistribution +from sympy.stats.drv_types import GeometricDistribution, PoissonDistribution, ZetaDistribution +from sympy.stats.frv_types import BinomialDistribution, HypergeometricDistribution + + +numpy = import_module('numpy') + + +@singledispatch +def do_sample_numpy(dist, size, rand_state): + return None + + +# CRV: + +@do_sample_numpy.register(BetaDistribution) +def _(dist: BetaDistribution, size, rand_state): + return rand_state.beta(a=float(dist.alpha), b=float(dist.beta), size=size) + + +@do_sample_numpy.register(ChiSquaredDistribution) +def _(dist: ChiSquaredDistribution, size, rand_state): + return rand_state.chisquare(df=float(dist.k), size=size) + + +@do_sample_numpy.register(ExponentialDistribution) +def _(dist: ExponentialDistribution, size, rand_state): + return rand_state.exponential(1 / float(dist.rate), size=size) + +@do_sample_numpy.register(FDistributionDistribution) +def _(dist: FDistributionDistribution, size, rand_state): + return rand_state.f(dfnum = float(dist.d1), dfden = float(dist.d2), size=size) + +@do_sample_numpy.register(GammaDistribution) +def _(dist: GammaDistribution, size, rand_state): + return rand_state.gamma(shape = float(dist.k), scale = float(dist.theta), size=size) + +@do_sample_numpy.register(GumbelDistribution) +def _(dist: GumbelDistribution, size, rand_state): + return rand_state.gumbel(loc = float(dist.mu), scale = float(dist.beta), size=size) + +@do_sample_numpy.register(LaplaceDistribution) +def _(dist: LaplaceDistribution, size, rand_state): + return rand_state.laplace(loc = float(dist.mu), scale = float(dist.b), size=size) + +@do_sample_numpy.register(LogisticDistribution) +def _(dist: LogisticDistribution, size, rand_state): + return rand_state.logistic(loc = float(dist.mu), scale = float(dist.s), size=size) + +@do_sample_numpy.register(LogNormalDistribution) +def _(dist: LogNormalDistribution, size, rand_state): + return rand_state.lognormal(mean = float(dist.mean), sigma = float(dist.std), size=size) + +@do_sample_numpy.register(NormalDistribution) +def _(dist: NormalDistribution, size, rand_state): + return rand_state.normal(loc = float(dist.mean), scale = float(dist.std), size=size) + +@do_sample_numpy.register(RayleighDistribution) +def _(dist: RayleighDistribution, size, rand_state): + return rand_state.rayleigh(scale = float(dist.sigma), size=size) + +@do_sample_numpy.register(ParetoDistribution) +def _(dist: ParetoDistribution, size, rand_state): + return (numpy.random.pareto(a=float(dist.alpha), size=size) + 1) * float(dist.xm) + +@do_sample_numpy.register(TriangularDistribution) +def _(dist: TriangularDistribution, size, rand_state): + return rand_state.triangular(left = float(dist.a), mode = float(dist.b), right = float(dist.c), size=size) + +@do_sample_numpy.register(UniformDistribution) +def _(dist: UniformDistribution, size, rand_state): + return rand_state.uniform(low=float(dist.left), high=float(dist.right), size=size) + + +# DRV: + +@do_sample_numpy.register(GeometricDistribution) +def _(dist: GeometricDistribution, size, rand_state): + return rand_state.geometric(p=float(dist.p), size=size) + + +@do_sample_numpy.register(PoissonDistribution) +def _(dist: PoissonDistribution, size, rand_state): + return rand_state.poisson(lam=float(dist.lamda), size=size) + + +@do_sample_numpy.register(ZetaDistribution) +def _(dist: ZetaDistribution, size, rand_state): + return rand_state.zipf(a=float(dist.s), size=size) + + +# FRV: + +@do_sample_numpy.register(BinomialDistribution) +def _(dist: BinomialDistribution, size, rand_state): + return rand_state.binomial(n=int(dist.n), p=float(dist.p), size=size) + +@do_sample_numpy.register(HypergeometricDistribution) +def _(dist: HypergeometricDistribution, size, rand_state): + return rand_state.hypergeometric(ngood = int(dist.N), nbad = int(dist.m), nsample = int(dist.n), size=size) diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/sampling/sample_pymc.py b/.venv/lib/python3.13/site-packages/sympy/stats/sampling/sample_pymc.py new file mode 100644 index 0000000000000000000000000000000000000000..546f02a3092815af2e54b4a164463f62ece7a024 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/sampling/sample_pymc.py @@ -0,0 +1,99 @@ +from functools import singledispatch +from sympy.external import import_module +from sympy.stats.crv_types import BetaDistribution, CauchyDistribution, ChiSquaredDistribution, ExponentialDistribution, \ + GammaDistribution, LogNormalDistribution, NormalDistribution, ParetoDistribution, UniformDistribution, \ + GaussianInverseDistribution +from sympy.stats.drv_types import PoissonDistribution, GeometricDistribution, NegativeBinomialDistribution +from sympy.stats.frv_types import BinomialDistribution, BernoulliDistribution + + +try: + import pymc +except ImportError: + pymc = import_module('pymc3') + +@singledispatch +def do_sample_pymc(dist): + return None + + +# CRV: + +@do_sample_pymc.register(BetaDistribution) +def _(dist: BetaDistribution): + return pymc.Beta('X', alpha=float(dist.alpha), beta=float(dist.beta)) + + +@do_sample_pymc.register(CauchyDistribution) +def _(dist: CauchyDistribution): + return pymc.Cauchy('X', alpha=float(dist.x0), beta=float(dist.gamma)) + + +@do_sample_pymc.register(ChiSquaredDistribution) +def _(dist: ChiSquaredDistribution): + return pymc.ChiSquared('X', nu=float(dist.k)) + + +@do_sample_pymc.register(ExponentialDistribution) +def _(dist: ExponentialDistribution): + return pymc.Exponential('X', lam=float(dist.rate)) + + +@do_sample_pymc.register(GammaDistribution) +def _(dist: GammaDistribution): + return pymc.Gamma('X', alpha=float(dist.k), beta=1 / float(dist.theta)) + + +@do_sample_pymc.register(LogNormalDistribution) +def _(dist: LogNormalDistribution): + return pymc.Lognormal('X', mu=float(dist.mean), sigma=float(dist.std)) + + +@do_sample_pymc.register(NormalDistribution) +def _(dist: NormalDistribution): + return pymc.Normal('X', float(dist.mean), float(dist.std)) + + +@do_sample_pymc.register(GaussianInverseDistribution) +def _(dist: GaussianInverseDistribution): + return pymc.Wald('X', mu=float(dist.mean), lam=float(dist.shape)) + + +@do_sample_pymc.register(ParetoDistribution) +def _(dist: ParetoDistribution): + return pymc.Pareto('X', alpha=float(dist.alpha), m=float(dist.xm)) + + +@do_sample_pymc.register(UniformDistribution) +def _(dist: UniformDistribution): + return pymc.Uniform('X', lower=float(dist.left), upper=float(dist.right)) + + +# DRV: + +@do_sample_pymc.register(GeometricDistribution) +def _(dist: GeometricDistribution): + return pymc.Geometric('X', p=float(dist.p)) + + +@do_sample_pymc.register(NegativeBinomialDistribution) +def _(dist: NegativeBinomialDistribution): + return pymc.NegativeBinomial('X', mu=float((dist.p * dist.r) / (1 - dist.p)), + alpha=float(dist.r)) + + +@do_sample_pymc.register(PoissonDistribution) +def _(dist: PoissonDistribution): + return pymc.Poisson('X', mu=float(dist.lamda)) + + +# FRV: + +@do_sample_pymc.register(BernoulliDistribution) +def _(dist: BernoulliDistribution): + return pymc.Bernoulli('X', p=float(dist.p)) + + +@do_sample_pymc.register(BinomialDistribution) +def _(dist: BinomialDistribution): + return pymc.Binomial('X', n=int(dist.n), p=float(dist.p)) diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/sampling/sample_scipy.py b/.venv/lib/python3.13/site-packages/sympy/stats/sampling/sample_scipy.py new file mode 100644 index 0000000000000000000000000000000000000000..f12508f68844488e9a14b1476005164eb422796e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/sampling/sample_scipy.py @@ -0,0 +1,167 @@ +from functools import singledispatch + +from sympy.core.symbol import Dummy +from sympy.functions.elementary.exponential import exp +from sympy.utilities.lambdify import lambdify +from sympy.external import import_module +from sympy.stats import DiscreteDistributionHandmade +from sympy.stats.crv import SingleContinuousDistribution +from sympy.stats.crv_types import ChiSquaredDistribution, ExponentialDistribution, GammaDistribution, \ + LogNormalDistribution, NormalDistribution, ParetoDistribution, UniformDistribution, BetaDistribution, \ + StudentTDistribution, CauchyDistribution +from sympy.stats.drv_types import GeometricDistribution, LogarithmicDistribution, NegativeBinomialDistribution, \ + PoissonDistribution, SkellamDistribution, YuleSimonDistribution, ZetaDistribution +from sympy.stats.frv import SingleFiniteDistribution + + +scipy = import_module("scipy", import_kwargs={'fromlist':['stats']}) + + +@singledispatch +def do_sample_scipy(dist, size, seed): + return None + + +# CRV + +@do_sample_scipy.register(SingleContinuousDistribution) +def _(dist: SingleContinuousDistribution, size, seed): + # if we don't need to make a handmade pdf, we won't + import scipy.stats + + z = Dummy('z') + handmade_pdf = lambdify(z, dist.pdf(z), ['numpy', 'scipy']) + + class scipy_pdf(scipy.stats.rv_continuous): + def _pdf(dist, x): + return handmade_pdf(x) + + scipy_rv = scipy_pdf(a=float(dist.set._inf), + b=float(dist.set._sup), name='scipy_pdf') + return scipy_rv.rvs(size=size, random_state=seed) + + +@do_sample_scipy.register(ChiSquaredDistribution) +def _(dist: ChiSquaredDistribution, size, seed): + # same parametrisation + return scipy.stats.chi2.rvs(df=float(dist.k), size=size, random_state=seed) + + +@do_sample_scipy.register(ExponentialDistribution) +def _(dist: ExponentialDistribution, size, seed): + # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.expon.html#scipy.stats.expon + return scipy.stats.expon.rvs(scale=1 / float(dist.rate), size=size, random_state=seed) + + +@do_sample_scipy.register(GammaDistribution) +def _(dist: GammaDistribution, size, seed): + # https://stackoverflow.com/questions/42150965/how-to-plot-gamma-distribution-with-alpha-and-beta-parameters-in-python + return scipy.stats.gamma.rvs(a=float(dist.k), scale=float(dist.theta), size=size, random_state=seed) + + +@do_sample_scipy.register(LogNormalDistribution) +def _(dist: LogNormalDistribution, size, seed): + # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.lognorm.html + return scipy.stats.lognorm.rvs(scale=float(exp(dist.mean)), s=float(dist.std), size=size, random_state=seed) + + +@do_sample_scipy.register(NormalDistribution) +def _(dist: NormalDistribution, size, seed): + return scipy.stats.norm.rvs(loc=float(dist.mean), scale=float(dist.std), size=size, random_state=seed) + + +@do_sample_scipy.register(ParetoDistribution) +def _(dist: ParetoDistribution, size, seed): + # https://stackoverflow.com/questions/42260519/defining-pareto-distribution-in-python-scipy + return scipy.stats.pareto.rvs(b=float(dist.alpha), scale=float(dist.xm), size=size, random_state=seed) + + +@do_sample_scipy.register(StudentTDistribution) +def _(dist: StudentTDistribution, size, seed): + return scipy.stats.t.rvs(df=float(dist.nu), size=size, random_state=seed) + + +@do_sample_scipy.register(UniformDistribution) +def _(dist: UniformDistribution, size, seed): + # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.uniform.html + return scipy.stats.uniform.rvs(loc=float(dist.left), scale=float(dist.right - dist.left), size=size, random_state=seed) + + +@do_sample_scipy.register(BetaDistribution) +def _(dist: BetaDistribution, size, seed): + # same parametrisation + return scipy.stats.beta.rvs(a=float(dist.alpha), b=float(dist.beta), size=size, random_state=seed) + + +@do_sample_scipy.register(CauchyDistribution) +def _(dist: CauchyDistribution, size, seed): + return scipy.stats.cauchy.rvs(loc=float(dist.x0), scale=float(dist.gamma), size=size, random_state=seed) + + +# DRV: + +@do_sample_scipy.register(DiscreteDistributionHandmade) +def _(dist: DiscreteDistributionHandmade, size, seed): + from scipy.stats import rv_discrete + + z = Dummy('z') + handmade_pmf = lambdify(z, dist.pdf(z), ['numpy', 'scipy']) + + class scipy_pmf(rv_discrete): + def _pmf(dist, x): + return handmade_pmf(x) + + scipy_rv = scipy_pmf(a=float(dist.set._inf), b=float(dist.set._sup), + name='scipy_pmf') + return scipy_rv.rvs(size=size, random_state=seed) + + +@do_sample_scipy.register(GeometricDistribution) +def _(dist: GeometricDistribution, size, seed): + return scipy.stats.geom.rvs(p=float(dist.p), size=size, random_state=seed) + + +@do_sample_scipy.register(LogarithmicDistribution) +def _(dist: LogarithmicDistribution, size, seed): + return scipy.stats.logser.rvs(p=float(dist.p), size=size, random_state=seed) + + +@do_sample_scipy.register(NegativeBinomialDistribution) +def _(dist: NegativeBinomialDistribution, size, seed): + return scipy.stats.nbinom.rvs(n=float(dist.r), p=float(dist.p), size=size, random_state=seed) + + +@do_sample_scipy.register(PoissonDistribution) +def _(dist: PoissonDistribution, size, seed): + return scipy.stats.poisson.rvs(mu=float(dist.lamda), size=size, random_state=seed) + + +@do_sample_scipy.register(SkellamDistribution) +def _(dist: SkellamDistribution, size, seed): + return scipy.stats.skellam.rvs(mu1=float(dist.mu1), mu2=float(dist.mu2), size=size, random_state=seed) + + +@do_sample_scipy.register(YuleSimonDistribution) +def _(dist: YuleSimonDistribution, size, seed): + return scipy.stats.yulesimon.rvs(alpha=float(dist.rho), size=size, random_state=seed) + + +@do_sample_scipy.register(ZetaDistribution) +def _(dist: ZetaDistribution, size, seed): + return scipy.stats.zipf.rvs(a=float(dist.s), size=size, random_state=seed) + + +# FRV: + +@do_sample_scipy.register(SingleFiniteDistribution) +def _(dist: SingleFiniteDistribution, size, seed): + # scipy can handle with custom distributions + + from scipy.stats import rv_discrete + density_ = dist.dict + x, y = [], [] + for k, v in density_.items(): + x.append(int(k)) + y.append(float(v)) + scipy_rv = rv_discrete(name='scipy_rv', values=(x, y)) + return scipy_rv.rvs(size=size, random_state=seed) diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/sampling/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/stats/sampling/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/sampling/tests/test_sample_continuous_rv.py b/.venv/lib/python3.13/site-packages/sympy/stats/sampling/tests/test_sample_continuous_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..953bb602df5e63da2882ee118de9dbf24b6f7804 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/sampling/tests/test_sample_continuous_rv.py @@ -0,0 +1,181 @@ +from sympy.core.numbers import oo +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import exp +from sympy.sets.sets import Interval +from sympy.external import import_module +from sympy.stats import Beta, Chi, Normal, Gamma, Exponential, LogNormal, Pareto, ChiSquared, Uniform, sample, \ + BetaPrime, Cauchy, GammaInverse, GaussianInverse, StudentT, Weibull, density, ContinuousRV, FDistribution, \ + Gumbel, Laplace, Logistic, Rayleigh, Triangular +from sympy.testing.pytest import skip, raises + + +def test_sample_numpy(): + distribs_numpy = [ + Beta("B", 1, 1), + Normal("N", 0, 1), + Gamma("G", 2, 7), + Exponential("E", 2), + LogNormal("LN", 0, 1), + Pareto("P", 1, 1), + ChiSquared("CS", 2), + Uniform("U", 0, 1), + FDistribution("FD", 1, 2), + Gumbel("GB", 1, 2), + Laplace("L", 1, 2), + Logistic("LO", 1, 2), + Rayleigh("R", 1), + Triangular("T", 1, 2, 2), + ] + size = 3 + numpy = import_module('numpy') + if not numpy: + skip('Numpy is not installed. Abort tests for _sample_numpy.') + else: + for X in distribs_numpy: + samps = sample(X, size=size, library='numpy') + for sam in samps: + assert sam in X.pspace.domain.set + raises(NotImplementedError, + lambda: sample(Chi("C", 1), library='numpy')) + raises(NotImplementedError, + lambda: Chi("C", 1).pspace.distribution.sample(library='tensorflow')) + + +def test_sample_scipy(): + distribs_scipy = [ + Beta("B", 1, 1), + BetaPrime("BP", 1, 1), + Cauchy("C", 1, 1), + Chi("C", 1), + Normal("N", 0, 1), + Gamma("G", 2, 7), + GammaInverse("GI", 1, 1), + GaussianInverse("GUI", 1, 1), + Exponential("E", 2), + LogNormal("LN", 0, 1), + Pareto("P", 1, 1), + StudentT("S", 2), + ChiSquared("CS", 2), + Uniform("U", 0, 1) + ] + size = 3 + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests for _sample_scipy.') + else: + for X in distribs_scipy: + samps = sample(X, size=size, library='scipy') + samps2 = sample(X, size=(2, 2), library='scipy') + for sam in samps: + assert sam in X.pspace.domain.set + for i in range(2): + for j in range(2): + assert samps2[i][j] in X.pspace.domain.set + + +def test_sample_pymc(): + distribs_pymc = [ + Beta("B", 1, 1), + Cauchy("C", 1, 1), + Normal("N", 0, 1), + Gamma("G", 2, 7), + GaussianInverse("GI", 1, 1), + Exponential("E", 2), + LogNormal("LN", 0, 1), + Pareto("P", 1, 1), + ChiSquared("CS", 2), + Uniform("U", 0, 1) + ] + size = 3 + pymc = import_module('pymc') + if not pymc: + skip('PyMC is not installed. Abort tests for _sample_pymc.') + else: + for X in distribs_pymc: + samps = sample(X, size=size, library='pymc') + for sam in samps: + assert sam in X.pspace.domain.set + raises(NotImplementedError, + lambda: sample(Chi("C", 1), library='pymc')) + + +def test_sampling_gamma_inverse(): + scipy = import_module('scipy') + if not scipy: + skip('Scipy not installed. Abort tests for sampling of gamma inverse.') + X = GammaInverse("x", 1, 1) + assert sample(X) in X.pspace.domain.set + + +def test_lognormal_sampling(): + # Right now, only density function and sampling works + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests') + for i in range(3): + X = LogNormal('x', i, 1) + assert sample(X) in X.pspace.domain.set + + size = 5 + samps = sample(X, size=size) + for samp in samps: + assert samp in X.pspace.domain.set + + +def test_sampling_gaussian_inverse(): + scipy = import_module('scipy') + if not scipy: + skip('Scipy not installed. Abort tests for sampling of Gaussian inverse.') + X = GaussianInverse("x", 1, 1) + assert sample(X, library='scipy') in X.pspace.domain.set + + +def test_prefab_sampling(): + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests') + N = Normal('X', 0, 1) + L = LogNormal('L', 0, 1) + E = Exponential('Ex', 1) + P = Pareto('P', 1, 3) + W = Weibull('W', 1, 1) + U = Uniform('U', 0, 1) + B = Beta('B', 2, 5) + G = Gamma('G', 1, 3) + + variables = [N, L, E, P, W, U, B, G] + niter = 10 + size = 5 + for var in variables: + for _ in range(niter): + assert sample(var) in var.pspace.domain.set + samps = sample(var, size=size) + for samp in samps: + assert samp in var.pspace.domain.set + + +def test_sample_continuous(): + z = Symbol('z') + Z = ContinuousRV(z, exp(-z), set=Interval(0, oo)) + assert density(Z)(-1) == 0 + + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests') + assert sample(Z) in Z.pspace.domain.set + sym, val = list(Z.pspace.sample().items())[0] + assert sym == Z and val in Interval(0, oo) + + libraries = ['scipy', 'numpy', 'pymc'] + for lib in libraries: + try: + imported_lib = import_module(lib) + if imported_lib: + s0, s1, s2 = [], [], [] + s0 = sample(Z, size=10, library=lib, seed=0) + s1 = sample(Z, size=10, library=lib, seed=0) + s2 = sample(Z, size=10, library=lib, seed=1) + assert all(s0 == s1) + assert all(s1 != s2) + except NotImplementedError: + continue diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/sampling/tests/test_sample_discrete_rv.py b/.venv/lib/python3.13/site-packages/sympy/stats/sampling/tests/test_sample_discrete_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..90d385cd599222fd7da7c1559b619bafbeb01831 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/sampling/tests/test_sample_discrete_rv.py @@ -0,0 +1,109 @@ +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.external import import_module +from sympy.stats import ( + Geometric, + Poisson, + Zeta, + sample, + Skellam, + Logarithmic, + NegativeBinomial, + YuleSimon, + DiscreteRV, +) +from sympy.testing.pytest import skip, raises, slow + + +def test_sample_numpy(): + distribs_numpy = [ + Geometric('G', 0.5), + Poisson('P', 1), + Zeta('Z', 2) + ] + size = 3 + numpy = import_module('numpy') + if not numpy: + skip('Numpy is not installed. Abort tests for _sample_numpy.') + else: + for X in distribs_numpy: + samps = sample(X, size=size, library='numpy') + for sam in samps: + assert sam in X.pspace.domain.set + raises(NotImplementedError, + lambda: sample(Skellam('S', 1, 1), library='numpy')) + raises(NotImplementedError, + lambda: Skellam('S', 1, 1).pspace.distribution.sample(library='tensorflow')) + + +def test_sample_scipy(): + p = S(2)/3 + x = Symbol('x', integer=True, positive=True) + pdf = p*(1 - p)**(x - 1) # pdf of Geometric Distribution + distribs_scipy = [ + DiscreteRV(x, pdf, set=S.Naturals), + Geometric('G', 0.5), + Logarithmic('L', 0.5), + NegativeBinomial('N', 5, 0.4), + Poisson('P', 1), + Skellam('S', 1, 1), + YuleSimon('Y', 1), + Zeta('Z', 2) + ] + size = 3 + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests for _sample_scipy.') + else: + for X in distribs_scipy: + samps = sample(X, size=size, library='scipy') + samps2 = sample(X, size=(2, 2), library='scipy') + for sam in samps: + assert sam in X.pspace.domain.set + for i in range(2): + for j in range(2): + assert samps2[i][j] in X.pspace.domain.set + + +def test_sample_pymc(): + distribs_pymc = [ + Geometric('G', 0.5), + Poisson('P', 1), + NegativeBinomial('N', 5, 0.4) + ] + size = 3 + pymc = import_module('pymc') + if not pymc: + skip('PyMC is not installed. Abort tests for _sample_pymc.') + else: + for X in distribs_pymc: + samps = sample(X, size=size, library='pymc') + for sam in samps: + assert sam in X.pspace.domain.set + raises(NotImplementedError, + lambda: sample(Skellam('S', 1, 1), library='pymc')) + +@slow +def test_sample_discrete(): + X = Geometric('X', S.Half) + scipy = import_module('scipy') + if not scipy: + skip('Scipy not installed. Abort tests') + assert sample(X) in X.pspace.domain.set + samps = sample(X, size=2) # This takes long time if ran without scipy + for samp in samps: + assert samp in X.pspace.domain.set + + libraries = ['scipy', 'numpy', 'pymc'] + for lib in libraries: + try: + imported_lib = import_module(lib) + if imported_lib: + s0, s1, s2 = [], [], [] + s0 = sample(X, size=10, library=lib, seed=0) + s1 = sample(X, size=10, library=lib, seed=0) + s2 = sample(X, size=10, library=lib, seed=1) + assert all(s0 == s1) + assert not all(s1 == s2) + except NotImplementedError: + continue diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/sampling/tests/test_sample_finite_rv.py b/.venv/lib/python3.13/site-packages/sympy/stats/sampling/tests/test_sample_finite_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..96cabe0ff4aaa5977e16600217fbbdeb08b962ae --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/sampling/tests/test_sample_finite_rv.py @@ -0,0 +1,94 @@ +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.external import import_module +from sympy.stats import Binomial, sample, Die, FiniteRV, DiscreteUniform, Bernoulli, BetaBinomial, Hypergeometric, \ + Rademacher +from sympy.testing.pytest import skip, raises + +def test_given_sample(): + X = Die('X', 6) + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests') + assert sample(X, X > 5) == 6 + +def test_sample_numpy(): + distribs_numpy = [ + Binomial("B", 5, 0.4), + Hypergeometric("H", 2, 1, 1) + ] + size = 3 + numpy = import_module('numpy') + if not numpy: + skip('Numpy is not installed. Abort tests for _sample_numpy.') + else: + for X in distribs_numpy: + samps = sample(X, size=size, library='numpy') + for sam in samps: + assert sam in X.pspace.domain.set + raises(NotImplementedError, + lambda: sample(Die("D"), library='numpy')) + raises(NotImplementedError, + lambda: Die("D").pspace.sample(library='tensorflow')) + + +def test_sample_scipy(): + distribs_scipy = [ + FiniteRV('F', {1: S.Half, 2: Rational(1, 4), 3: Rational(1, 4)}), + DiscreteUniform("Y", list(range(5))), + Die("D"), + Bernoulli("Be", 0.3), + Binomial("Bi", 5, 0.4), + BetaBinomial("Bb", 2, 1, 1), + Hypergeometric("H", 1, 1, 1), + Rademacher("R") + ] + + size = 3 + scipy = import_module('scipy') + if not scipy: + skip('Scipy not installed. Abort tests for _sample_scipy.') + else: + for X in distribs_scipy: + samps = sample(X, size=size) + samps2 = sample(X, size=(2, 2)) + for sam in samps: + assert sam in X.pspace.domain.set + for i in range(2): + for j in range(2): + assert samps2[i][j] in X.pspace.domain.set + + +def test_sample_pymc(): + distribs_pymc = [ + Bernoulli('B', 0.2), + Binomial('N', 5, 0.4) + ] + size = 3 + pymc = import_module('pymc') + if not pymc: + skip('PyMC is not installed. Abort tests for _sample_pymc.') + else: + for X in distribs_pymc: + samps = sample(X, size=size, library='pymc') + for sam in samps: + assert sam in X.pspace.domain.set + raises(NotImplementedError, + lambda: (sample(Die("D"), library='pymc'))) + + +def test_sample_seed(): + F = FiniteRV('F', {1: S.Half, 2: Rational(1, 4), 3: Rational(1, 4)}) + size = 10 + libraries = ['scipy', 'numpy', 'pymc'] + for lib in libraries: + try: + imported_lib = import_module(lib) + if imported_lib: + s0 = sample(F, size=size, library=lib, seed=0) + s1 = sample(F, size=size, library=lib, seed=0) + s2 = sample(F, size=size, library=lib, seed=1) + assert all(s0 == s1) + assert not all(s1 == s2) + except NotImplementedError: + continue diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/stochastic_process.py b/.venv/lib/python3.13/site-packages/sympy/stats/stochastic_process.py new file mode 100644 index 0000000000000000000000000000000000000000..bfb0e759c66be892ae38ddda004dfe928f683fee --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/stochastic_process.py @@ -0,0 +1,66 @@ +from sympy.core.basic import Basic +from sympy.stats.joint_rv import ProductPSpace +from sympy.stats.rv import ProductDomain, _symbol_converter, Distribution + + +class StochasticPSpace(ProductPSpace): + """ + Represents probability space of stochastic processes + and their random variables. Contains mechanics to do + computations for queries of stochastic processes. + + Explanation + =========== + + Initialized by symbol, the specific process and + distribution(optional) if the random indexed symbols + of the process follows any specific distribution, like, + in Bernoulli Process, each random indexed symbol follows + Bernoulli distribution. For processes with memory, this + parameter should not be passed. + """ + + def __new__(cls, sym, process, distribution=None): + sym = _symbol_converter(sym) + from sympy.stats.stochastic_process_types import StochasticProcess + if not isinstance(process, StochasticProcess): + raise TypeError("`process` must be an instance of StochasticProcess.") + if distribution is None: + distribution = Distribution() + return Basic.__new__(cls, sym, process, distribution) + + @property + def process(self): + """ + The associated stochastic process. + """ + return self.args[1] + + @property + def domain(self): + return ProductDomain(self.process.index_set, + self.process.state_space) + + @property + def symbol(self): + return self.args[0] + + @property + def distribution(self): + return self.args[2] + + def probability(self, condition, given_condition=None, evaluate=True, **kwargs): + """ + Transfers the task of handling queries to the specific stochastic + process because every process has their own logic of handling such + queries. + """ + return self.process.probability(condition, given_condition, evaluate, **kwargs) + + def compute_expectation(self, expr, condition=None, evaluate=True, **kwargs): + """ + Transfers the task of handling queries to the specific stochastic + process because every process has their own logic of handling such + queries. + """ + return self.process.expectation(expr, condition, evaluate, **kwargs) diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/stochastic_process_types.py b/.venv/lib/python3.13/site-packages/sympy/stats/stochastic_process_types.py new file mode 100644 index 0000000000000000000000000000000000000000..7387cd3dbcf6defb3b7e475f542d04ecef6fecf6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/stochastic_process_types.py @@ -0,0 +1,2383 @@ +from __future__ import annotations +import random +import itertools +from typing import Sequence as tSequence +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.cache import cacheit +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.function import (Function, Lambda) +from sympy.core.mul import Mul +from sympy.core.intfunc import igcd +from sympy.core.numbers import (Integer, Rational, oo, pi) +from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol) +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.integers import ceiling +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.gamma_functions import gamma +from sympy.logic.boolalg import (And, Not, Or) +from sympy.matrices.exceptions import NonSquareMatrixError +from sympy.matrices.dense import (Matrix, eye, ones, zeros) +from sympy.matrices.expressions.blockmatrix import BlockMatrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import Identity +from sympy.matrices.immutable import ImmutableMatrix +from sympy.sets.conditionset import ConditionSet +from sympy.sets.contains import Contains +from sympy.sets.fancysets import Range +from sympy.sets.sets import (FiniteSet, Intersection, Interval, Set, Union) +from sympy.solvers.solveset import linsolve +from sympy.tensor.indexed import (Indexed, IndexedBase) +from sympy.core.relational import Relational +from sympy.logic.boolalg import Boolean +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import strongly_connected_components +from sympy.stats.joint_rv import JointDistribution +from sympy.stats.joint_rv_types import JointDistributionHandmade +from sympy.stats.rv import (RandomIndexedSymbol, random_symbols, RandomSymbol, + _symbol_converter, _value_check, pspace, given, + dependent, is_random, sample_iter, Distribution, + Density) +from sympy.stats.stochastic_process import StochasticPSpace +from sympy.stats.symbolic_probability import Probability, Expectation +from sympy.stats.frv_types import Bernoulli, BernoulliDistribution, FiniteRV +from sympy.stats.drv_types import Poisson, PoissonDistribution +from sympy.stats.crv_types import Normal, NormalDistribution, Gamma, GammaDistribution +from sympy.core.sympify import _sympify, sympify + +EmptySet = S.EmptySet + +__all__ = [ + 'StochasticProcess', + 'DiscreteTimeStochasticProcess', + 'DiscreteMarkovChain', + 'TransitionMatrixOf', + 'StochasticStateSpaceOf', + 'GeneratorMatrixOf', + 'ContinuousMarkovChain', + 'BernoulliProcess', + 'PoissonProcess', + 'WienerProcess', + 'GammaProcess' +] + + +@is_random.register(Indexed) +def _(x): + return is_random(x.base) + +@is_random.register(RandomIndexedSymbol) # type: ignore +def _(x): + return True + +def _set_converter(itr): + """ + Helper function for converting list/tuple/set to Set. + If parameter is not an instance of list/tuple/set then + no operation is performed. + + Returns + ======= + + Set + The argument converted to Set. + + + Raises + ====== + + TypeError + If the argument is not an instance of list/tuple/set. + """ + if isinstance(itr, (list, tuple, set)): + itr = FiniteSet(*itr) + if not isinstance(itr, Set): + raise TypeError("%s is not an instance of list/tuple/set."%(itr)) + return itr + +def _state_converter(itr: tSequence) -> Tuple | Range: + """ + Helper function for converting list/tuple/set/Range/Tuple/FiniteSet + to tuple/Range. + """ + itr_ret: Tuple | Range + + if isinstance(itr, (Tuple, set, FiniteSet)): + itr_ret = Tuple(*(sympify(i) if isinstance(i, str) else i for i in itr)) + + elif isinstance(itr, (list, tuple)): + # check if states are unique + if len(set(itr)) != len(itr): + raise ValueError('The state space must have unique elements.') + itr_ret = Tuple(*(sympify(i) if isinstance(i, str) else i for i in itr)) + + elif isinstance(itr, Range): + # the only ordered set in SymPy I know of + # try to convert to tuple + try: + itr_ret = Tuple(*(sympify(i) if isinstance(i, str) else i for i in itr)) + except (TypeError, ValueError): + itr_ret = itr + + else: + raise TypeError("%s is not an instance of list/tuple/set/Range/Tuple/FiniteSet." % (itr)) + return itr_ret + +def _sym_sympify(arg): + """ + Converts an arbitrary expression to a type that can be used inside SymPy. + As generally strings are unwise to use in the expressions, + it returns the Symbol of argument if the string type argument is passed. + + Parameters + ========= + + arg: The parameter to be converted to be used in SymPy. + + Returns + ======= + + The converted parameter. + + """ + if isinstance(arg, str): + return Symbol(arg) + else: + return _sympify(arg) + +def _matrix_checks(matrix): + if not isinstance(matrix, (Matrix, MatrixSymbol, ImmutableMatrix)): + raise TypeError("Transition probabilities either should " + "be a Matrix or a MatrixSymbol.") + if matrix.shape[0] != matrix.shape[1]: + raise NonSquareMatrixError("%s is not a square matrix"%(matrix)) + if isinstance(matrix, Matrix): + matrix = ImmutableMatrix(matrix.tolist()) + return matrix + +class StochasticProcess(Basic): + """ + Base class for all the stochastic processes whether + discrete or continuous. + + Parameters + ========== + + sym: Symbol or str + state_space: Set + The state space of the stochastic process, by default S.Reals. + For discrete sets it is zero indexed. + + See Also + ======== + + DiscreteTimeStochasticProcess + """ + + index_set = S.Reals + + def __new__(cls, sym, state_space=S.Reals, **kwargs): + sym = _symbol_converter(sym) + state_space = _set_converter(state_space) + return Basic.__new__(cls, sym, state_space) + + @property + def symbol(self): + return self.args[0] + + @property + def state_space(self) -> FiniteSet | Range: + if not isinstance(self.args[1], (FiniteSet, Range)): + assert isinstance(self.args[1], Tuple) + return FiniteSet(*self.args[1]) + return self.args[1] + + def _deprecation_warn_distribution(self): + sympy_deprecation_warning( + """ + Calling the distribution method with a RandomIndexedSymbol + argument, like X.distribution(X(t)) is deprecated. Instead, call + distribution() with the given timestamp, like + + X.distribution(t) + """, + deprecated_since_version="1.7.1", + active_deprecations_target="deprecated-distribution-randomindexedsymbol", + stacklevel=4, + ) + + def distribution(self, key=None): + if key is None: + self._deprecation_warn_distribution() + return Distribution() + + def density(self, x): + return Density() + + def __call__(self, time): + """ + Overridden in ContinuousTimeStochasticProcess. + """ + raise NotImplementedError("Use [] for indexing discrete time stochastic process.") + + def __getitem__(self, time): + """ + Overridden in DiscreteTimeStochasticProcess. + """ + raise NotImplementedError("Use () for indexing continuous time stochastic process.") + + def probability(self, condition): + raise NotImplementedError() + + def joint_distribution(self, *args): + """ + Computes the joint distribution of the random indexed variables. + + Parameters + ========== + + args: iterable + The finite list of random indexed variables/the key of a stochastic + process whose joint distribution has to be computed. + + Returns + ======= + + JointDistribution + The joint distribution of the list of random indexed variables. + An unevaluated object is returned if it is not possible to + compute the joint distribution. + + Raises + ====== + + ValueError: When the arguments passed are not of type RandomIndexSymbol + or Number. + """ + args = list(args) + for i, arg in enumerate(args): + if S(arg).is_Number: + if self.index_set.is_subset(S.Integers): + args[i] = self.__getitem__(arg) + else: + args[i] = self.__call__(arg) + elif not isinstance(arg, RandomIndexedSymbol): + raise ValueError("Expected a RandomIndexedSymbol or " + "key not %s"%(type(arg))) + + if args[0].pspace.distribution == Distribution(): + return JointDistribution(*args) + density = Lambda(tuple(args), + expr=Mul.fromiter(arg.pspace.process.density(arg) for arg in args)) + return JointDistributionHandmade(density) + + def expectation(self, condition, given_condition): + raise NotImplementedError("Abstract method for expectation queries.") + + def sample(self): + raise NotImplementedError("Abstract method for sampling queries.") + +class DiscreteTimeStochasticProcess(StochasticProcess): + """ + Base class for all discrete stochastic processes. + """ + def __getitem__(self, time): + """ + For indexing discrete time stochastic processes. + + Returns + ======= + + RandomIndexedSymbol + """ + time = sympify(time) + if not time.is_symbol and time not in self.index_set: + raise IndexError("%s is not in the index set of %s"%(time, self.symbol)) + idx_obj = Indexed(self.symbol, time) + pspace_obj = StochasticPSpace(self.symbol, self, self.distribution(time)) + return RandomIndexedSymbol(idx_obj, pspace_obj) + +class ContinuousTimeStochasticProcess(StochasticProcess): + """ + Base class for all continuous time stochastic process. + """ + def __call__(self, time): + """ + For indexing continuous time stochastic processes. + + Returns + ======= + + RandomIndexedSymbol + """ + time = sympify(time) + if not time.is_symbol and time not in self.index_set: + raise IndexError("%s is not in the index set of %s"%(time, self.symbol)) + func_obj = Function(self.symbol)(time) + pspace_obj = StochasticPSpace(self.symbol, self, self.distribution(time)) + return RandomIndexedSymbol(func_obj, pspace_obj) + +class TransitionMatrixOf(Boolean): + """ + Assumes that the matrix is the transition matrix + of the process. + """ + + def __new__(cls, process, matrix): + if not isinstance(process, DiscreteMarkovChain): + raise ValueError("Currently only DiscreteMarkovChain " + "support TransitionMatrixOf.") + matrix = _matrix_checks(matrix) + return Basic.__new__(cls, process, matrix) + + process = property(lambda self: self.args[0]) + matrix = property(lambda self: self.args[1]) + +class GeneratorMatrixOf(TransitionMatrixOf): + """ + Assumes that the matrix is the generator matrix + of the process. + """ + + def __new__(cls, process, matrix): + if not isinstance(process, ContinuousMarkovChain): + raise ValueError("Currently only ContinuousMarkovChain " + "support GeneratorMatrixOf.") + matrix = _matrix_checks(matrix) + return Basic.__new__(cls, process, matrix) + +class StochasticStateSpaceOf(Boolean): + + def __new__(cls, process, state_space): + if not isinstance(process, (DiscreteMarkovChain, ContinuousMarkovChain)): + raise ValueError("Currently only DiscreteMarkovChain and ContinuousMarkovChain " + "support StochasticStateSpaceOf.") + state_space = _state_converter(state_space) + if isinstance(state_space, Range): + ss_size = ceiling((state_space.stop - state_space.start) / state_space.step) + else: + ss_size = len(state_space) + state_index = Range(ss_size) + return Basic.__new__(cls, process, state_index) + + process = property(lambda self: self.args[0]) + state_index = property(lambda self: self.args[1]) + +class MarkovProcess(StochasticProcess): + """ + Contains methods that handle queries + common to Markov processes. + """ + + @property + def number_of_states(self) -> Integer | Symbol: + """ + The number of states in the Markov Chain. + """ + return _sympify(self.args[2].shape[0]) # type: ignore + + @property + def _state_index(self): + """ + Returns state index as Range. + """ + return self.args[1] + + @classmethod + def _sanity_checks(cls, state_space, trans_probs): + # Try to never have None as state_space or trans_probs. + # This helps a lot if we get it done at the start. + if (state_space is None) and (trans_probs is None): + _n = Dummy('n', integer=True, nonnegative=True) + state_space = _state_converter(Range(_n)) + trans_probs = _matrix_checks(MatrixSymbol('_T', _n, _n)) + + elif state_space is None: + trans_probs = _matrix_checks(trans_probs) + state_space = _state_converter(Range(trans_probs.shape[0])) + + elif trans_probs is None: + state_space = _state_converter(state_space) + if isinstance(state_space, Range): + _n = ceiling((state_space.stop - state_space.start) / state_space.step) + else: + _n = len(state_space) + trans_probs = MatrixSymbol('_T', _n, _n) + + else: + state_space = _state_converter(state_space) + trans_probs = _matrix_checks(trans_probs) + # Range object doesn't want to give a symbolic size + # so we do it ourselves. + if isinstance(state_space, Range): + ss_size = ceiling((state_space.stop - state_space.start) / state_space.step) + else: + ss_size = len(state_space) + if ss_size != trans_probs.shape[0]: + raise ValueError('The size of the state space and the number of ' + 'rows of the transition matrix must be the same.') + + return state_space, trans_probs + + def _extract_information(self, given_condition): + """ + Helper function to extract information, like, + transition matrix/generator matrix, state space, etc. + """ + if isinstance(self, DiscreteMarkovChain): + trans_probs = self.transition_probabilities + state_index = self._state_index + elif isinstance(self, ContinuousMarkovChain): + trans_probs = self.generator_matrix + state_index = self._state_index + if isinstance(given_condition, And): + gcs = given_condition.args + given_condition = S.true + for gc in gcs: + if isinstance(gc, TransitionMatrixOf): + trans_probs = gc.matrix + if isinstance(gc, StochasticStateSpaceOf): + state_index = gc.state_index + if isinstance(gc, Relational): + given_condition = given_condition & gc + if isinstance(given_condition, TransitionMatrixOf): + trans_probs = given_condition.matrix + given_condition = S.true + if isinstance(given_condition, StochasticStateSpaceOf): + state_index = given_condition.state_index + given_condition = S.true + return trans_probs, state_index, given_condition + + def _check_trans_probs(self, trans_probs, row_sum=1): + """ + Helper function for checking the validity of transition + probabilities. + """ + if not isinstance(trans_probs, MatrixSymbol): + rows = trans_probs.tolist() + for row in rows: + if (sum(row) - row_sum) != 0: + raise ValueError("Values in a row must sum to %s. " + "If you are using Float or floats then please use Rational."%(row_sum)) + + def _work_out_state_index(self, state_index, given_condition, trans_probs): + """ + Helper function to extract state space if there + is a random symbol in the given condition. + """ + # if given condition is None, then there is no need to work out + # state_space from random variables + if given_condition != None: + rand_var = list(given_condition.atoms(RandomSymbol) - + given_condition.atoms(RandomIndexedSymbol)) + if len(rand_var) == 1: + state_index = rand_var[0].pspace.set + + # `not None` is `True`. So the old test fails for symbolic sizes. + # Need to build the statement differently. + sym_cond = not self.number_of_states.is_Integer + cond1 = not sym_cond and len(state_index) != trans_probs.shape[0] + if cond1: + raise ValueError("state space is not compatible with the transition probabilities.") + if not isinstance(trans_probs.shape[0], Symbol): + state_index = FiniteSet(*range(trans_probs.shape[0])) + return state_index + + @cacheit + def _preprocess(self, given_condition, evaluate): + """ + Helper function for pre-processing the information. + """ + is_insufficient = False + + if not evaluate: # avoid pre-processing if the result is not to be evaluated + return (True, None, None, None) + + # extracting transition matrix and state space + trans_probs, state_index, given_condition = self._extract_information(given_condition) + + # given_condition does not have sufficient information + # for computations + if trans_probs is None or \ + given_condition is None: + is_insufficient = True + else: + # checking transition probabilities + if isinstance(self, DiscreteMarkovChain): + self._check_trans_probs(trans_probs, row_sum=1) + elif isinstance(self, ContinuousMarkovChain): + self._check_trans_probs(trans_probs, row_sum=0) + + # working out state space + state_index = self._work_out_state_index(state_index, given_condition, trans_probs) + + return is_insufficient, trans_probs, state_index, given_condition + + def replace_with_index(self, condition): + if isinstance(condition, Relational): + lhs, rhs = condition.lhs, condition.rhs + if not isinstance(lhs, RandomIndexedSymbol): + lhs, rhs = rhs, lhs + condition = type(condition)(self.index_of.get(lhs, lhs), + self.index_of.get(rhs, rhs)) + return condition + + def probability(self, condition, given_condition=None, evaluate=True, **kwargs): + """ + Handles probability queries for Markov process. + + Parameters + ========== + + condition: Relational + given_condition: Relational/And + + Returns + ======= + Probability + If the information is not sufficient. + Expr + In all other cases. + + Note + ==== + Any information passed at the time of query overrides + any information passed at the time of object creation like + transition probabilities, state space. + Pass the transition matrix using TransitionMatrixOf, + generator matrix using GeneratorMatrixOf and state space + using StochasticStateSpaceOf in given_condition using & or And. + """ + check, mat, state_index, new_given_condition = \ + self._preprocess(given_condition, evaluate) + + rv = list(condition.atoms(RandomIndexedSymbol)) + symbolic = False + for sym in rv: + if sym.key.is_symbol: + symbolic = True + break + + if check: + return Probability(condition, new_given_condition) + + if isinstance(self, ContinuousMarkovChain): + trans_probs = self.transition_probabilities(mat) + elif isinstance(self, DiscreteMarkovChain): + trans_probs = mat + condition = self.replace_with_index(condition) + given_condition = self.replace_with_index(given_condition) + new_given_condition = self.replace_with_index(new_given_condition) + + if isinstance(condition, Relational): + if isinstance(new_given_condition, And): + gcs = new_given_condition.args + else: + gcs = (new_given_condition, ) + min_key_rv = list(new_given_condition.atoms(RandomIndexedSymbol)) + + if len(min_key_rv): + min_key_rv = min_key_rv[0] + for r in rv: + if min_key_rv.key.is_symbol or r.key.is_symbol: + continue + if min_key_rv.key > r.key: + return Probability(condition) + else: + min_key_rv = None + return Probability(condition) + + if symbolic: + return self._symbolic_probability(condition, new_given_condition, rv, min_key_rv) + + if len(rv) > 1: + rv[0] = condition.lhs + rv[1] = condition.rhs + if rv[0].key < rv[1].key: + rv[0], rv[1] = rv[1], rv[0] + if isinstance(condition, Gt): + condition = Lt(condition.lhs, condition.rhs) + elif isinstance(condition, Lt): + condition = Gt(condition.lhs, condition.rhs) + elif isinstance(condition, Ge): + condition = Le(condition.lhs, condition.rhs) + elif isinstance(condition, Le): + condition = Ge(condition.lhs, condition.rhs) + s = Rational(0, 1) + n = len(self.state_space) + + if isinstance(condition, (Eq, Ne)): + for i in range(0, n): + s += self.probability(Eq(rv[0], i), Eq(rv[1], i)) * self.probability(Eq(rv[1], i), new_given_condition) + return s if isinstance(condition, Eq) else 1 - s + else: + upper = 0 + greater = False + if isinstance(condition, (Ge, Lt)): + upper = 1 + if isinstance(condition, (Ge, Gt)): + greater = True + + for i in range(0, n): + if i <= n//2: + for j in range(0, i + upper): + s += self.probability(Eq(rv[0], i), Eq(rv[1], j)) * self.probability(Eq(rv[1], j), new_given_condition) + else: + s += self.probability(Eq(rv[0], i), new_given_condition) + for j in range(i + upper, n): + s -= self.probability(Eq(rv[0], i), Eq(rv[1], j)) * self.probability(Eq(rv[1], j), new_given_condition) + return s if greater else 1 - s + + rv = rv[0] + states = condition.as_set() + prob, gstate = {}, None + for gc in gcs: + if gc.has(min_key_rv): + if gc.has(Probability): + p, gp = (gc.rhs, gc.lhs) if isinstance(gc.lhs, Probability) \ + else (gc.lhs, gc.rhs) + gr = gp.args[0] + gset = Intersection(gr.as_set(), state_index) + gstate = list(gset)[0] + prob[gset] = p + else: + _, gstate = (gc.lhs.key, gc.rhs) if isinstance(gc.lhs, RandomIndexedSymbol) \ + else (gc.rhs.key, gc.lhs) + + if not all(k in self.index_set for k in (rv.key, min_key_rv.key)): + raise IndexError("The timestamps of the process are not in it's index set.") + states = Intersection(states, state_index) if not isinstance(self.number_of_states, Symbol) else states + for state in Union(states, FiniteSet(gstate)): + if not state.is_Integer or Ge(state, mat.shape[0]) is True: + raise IndexError("No information is available for (%s, %s) in " + "transition probabilities of shape, (%s, %s). " + "State space is zero indexed." + %(gstate, state, mat.shape[0], mat.shape[1])) + if prob: + gstates = Union(*prob.keys()) + if len(gstates) == 1: + gstate = list(gstates)[0] + gprob = list(prob.values())[0] + prob[gstates] = gprob + elif len(gstates) == len(state_index) - 1: + gstate = list(state_index - gstates)[0] + gprob = S.One - sum(prob.values()) + prob[state_index - gstates] = gprob + else: + raise ValueError("Conflicting information.") + else: + gprob = S.One + + if min_key_rv == rv: + return sum(prob[FiniteSet(state)] for state in states) + if isinstance(self, ContinuousMarkovChain): + return gprob * sum(trans_probs(rv.key - min_key_rv.key).__getitem__((gstate, state)) + for state in states) + if isinstance(self, DiscreteMarkovChain): + return gprob * sum((trans_probs**(rv.key - min_key_rv.key)).__getitem__((gstate, state)) + for state in states) + + if isinstance(condition, Not): + expr = condition.args[0] + return S.One - self.probability(expr, given_condition, evaluate, **kwargs) + + if isinstance(condition, And): + compute_later, state2cond, conds = [], {}, condition.args + for expr in conds: + if isinstance(expr, Relational): + ris = list(expr.atoms(RandomIndexedSymbol))[0] + if state2cond.get(ris, None) is None: + state2cond[ris] = S.true + state2cond[ris] &= expr + else: + compute_later.append(expr) + ris = [] + for ri in state2cond: + ris.append(ri) + cset = Intersection(state2cond[ri].as_set(), state_index) + if len(cset) == 0: + return S.Zero + state2cond[ri] = cset.as_relational(ri) + sorted_ris = sorted(ris, key=lambda ri: ri.key) + prod = self.probability(state2cond[sorted_ris[0]], given_condition, evaluate, **kwargs) + for i in range(1, len(sorted_ris)): + ri, prev_ri = sorted_ris[i], sorted_ris[i-1] + if not isinstance(state2cond[ri], Eq): + raise ValueError("The process is in multiple states at %s, unable to determine the probability."%(ri)) + mat_of = TransitionMatrixOf(self, mat) if isinstance(self, DiscreteMarkovChain) else GeneratorMatrixOf(self, mat) + prod *= self.probability(state2cond[ri], state2cond[prev_ri] + & mat_of + & StochasticStateSpaceOf(self, state_index), + evaluate, **kwargs) + for expr in compute_later: + prod *= self.probability(expr, given_condition, evaluate, **kwargs) + return prod + + if isinstance(condition, Or): + return sum(self.probability(expr, given_condition, evaluate, **kwargs) + for expr in condition.args) + + raise NotImplementedError("Mechanism for handling (%s, %s) queries hasn't been " + "implemented yet."%(condition, given_condition)) + + def _symbolic_probability(self, condition, new_given_condition, rv, min_key_rv): + #Function to calculate probability for queries with symbols + if isinstance(condition, Relational): + curr_state = new_given_condition.rhs if isinstance(new_given_condition.lhs, RandomIndexedSymbol) \ + else new_given_condition.lhs + next_state = condition.rhs if isinstance(condition.lhs, RandomIndexedSymbol) \ + else condition.lhs + + if isinstance(condition, (Eq, Ne)): + if isinstance(self, DiscreteMarkovChain): + P = self.transition_probabilities**(rv[0].key - min_key_rv.key) + else: + P = exp(self.generator_matrix*(rv[0].key - min_key_rv.key)) + prob = P[curr_state, next_state] if isinstance(condition, Eq) else 1 - P[curr_state, next_state] + return Piecewise((prob, rv[0].key > min_key_rv.key), (Probability(condition), True)) + else: + upper = 1 + greater = False + if isinstance(condition, (Ge, Lt)): + upper = 0 + if isinstance(condition, (Ge, Gt)): + greater = True + k = Dummy('k') + condition = Eq(condition.lhs, k) if isinstance(condition.lhs, RandomIndexedSymbol)\ + else Eq(condition.rhs, k) + total = Sum(self.probability(condition, new_given_condition), (k, next_state + upper, self.state_space._sup)) + return Piecewise((total, rv[0].key > min_key_rv.key), (Probability(condition), True)) if greater\ + else Piecewise((1 - total, rv[0].key > min_key_rv.key), (Probability(condition), True)) + else: + return Probability(condition, new_given_condition) + + def expectation(self, expr, condition=None, evaluate=True, **kwargs): + """ + Handles expectation queries for markov process. + + Parameters + ========== + + expr: RandomIndexedSymbol, Relational, Logic + Condition for which expectation has to be computed. Must + contain a RandomIndexedSymbol of the process. + condition: Relational, Logic + The given conditions under which computations should be done. + + Returns + ======= + + Expectation + Unevaluated object if computations cannot be done due to + insufficient information. + Expr + In all other cases when the computations are successful. + + Note + ==== + + Any information passed at the time of query overrides + any information passed at the time of object creation like + transition probabilities, state space. + + Pass the transition matrix using TransitionMatrixOf, + generator matrix using GeneratorMatrixOf and state space + using StochasticStateSpaceOf in given_condition using & or And. + """ + + check, mat, state_index, condition = \ + self._preprocess(condition, evaluate) + + if check: + return Expectation(expr, condition) + + rvs = random_symbols(expr) + if isinstance(expr, Expr) and isinstance(condition, Eq) \ + and len(rvs) == 1: + # handle queries similar to E(f(X[i]), Eq(X[i-m], )) + condition=self.replace_with_index(condition) + state_index=self.replace_with_index(state_index) + rv = list(rvs)[0] + lhsg, rhsg = condition.lhs, condition.rhs + if not isinstance(lhsg, RandomIndexedSymbol): + lhsg, rhsg = (rhsg, lhsg) + if rhsg not in state_index: + raise ValueError("%s state is not in the state space."%(rhsg)) + if rv.key < lhsg.key: + raise ValueError("Incorrect given condition is given, expectation " + "time %s < time %s"%(rv.key, rv.key)) + mat_of = TransitionMatrixOf(self, mat) if isinstance(self, DiscreteMarkovChain) else GeneratorMatrixOf(self, mat) + cond = condition & mat_of & \ + StochasticStateSpaceOf(self, state_index) + func = lambda s: self.probability(Eq(rv, s), cond) * expr.subs(rv, self._state_index[s]) + return sum(func(s) for s in state_index) + + raise NotImplementedError("Mechanism for handling (%s, %s) queries hasn't been " + "implemented yet."%(expr, condition)) + +class DiscreteMarkovChain(DiscreteTimeStochasticProcess, MarkovProcess): + """ + Represents a finite discrete time-homogeneous Markov chain. + + This type of Markov Chain can be uniquely characterised by + its (ordered) state space and its one-step transition probability + matrix. + + Parameters + ========== + + sym: + The name given to the Markov Chain + state_space: + Optional, by default, Range(n) + trans_probs: + Optional, by default, MatrixSymbol('_T', n, n) + + Examples + ======== + + >>> from sympy.stats import DiscreteMarkovChain, TransitionMatrixOf, P, E + >>> from sympy import Matrix, MatrixSymbol, Eq, symbols + >>> T = Matrix([[0.5, 0.2, 0.3],[0.2, 0.5, 0.3],[0.2, 0.3, 0.5]]) + >>> Y = DiscreteMarkovChain("Y", [0, 1, 2], T) + >>> YS = DiscreteMarkovChain("Y") + + >>> Y.state_space + {0, 1, 2} + >>> Y.transition_probabilities + Matrix([ + [0.5, 0.2, 0.3], + [0.2, 0.5, 0.3], + [0.2, 0.3, 0.5]]) + >>> TS = MatrixSymbol('T', 3, 3) + >>> P(Eq(YS[3], 2), Eq(YS[1], 1) & TransitionMatrixOf(YS, TS)) + T[0, 2]*T[1, 0] + T[1, 1]*T[1, 2] + T[1, 2]*T[2, 2] + >>> P(Eq(Y[3], 2), Eq(Y[1], 1)).round(2) + 0.36 + + Probabilities will be calculated based on indexes rather + than state names. For example, with the Sunny-Cloudy-Rainy + model with string state names: + + >>> from sympy.core.symbol import Str + >>> Y = DiscreteMarkovChain("Y", [Str('Sunny'), Str('Cloudy'), Str('Rainy')], T) + >>> P(Eq(Y[3], 2), Eq(Y[1], 1)).round(2) + 0.36 + + This gives the same answer as the ``[0, 1, 2]`` state space. + Currently, there is no support for state names within probability + and expectation statements. Here is a work-around using ``Str``: + + >>> P(Eq(Str('Rainy'), Y[3]), Eq(Y[1], Str('Cloudy'))).round(2) + 0.36 + + Symbol state names can also be used: + + >>> sunny, cloudy, rainy = symbols('Sunny, Cloudy, Rainy') + >>> Y = DiscreteMarkovChain("Y", [sunny, cloudy, rainy], T) + >>> P(Eq(Y[3], rainy), Eq(Y[1], cloudy)).round(2) + 0.36 + + Expectations will be calculated as follows: + + >>> E(Y[3], Eq(Y[1], cloudy)) + 0.38*Cloudy + 0.36*Rainy + 0.26*Sunny + + Probability of expressions with multiple RandomIndexedSymbols + can also be calculated provided there is only 1 RandomIndexedSymbol + in the given condition. It is always better to use Rational instead + of floating point numbers for the probabilities in the + transition matrix to avoid errors. + + >>> from sympy import Gt, Le, Rational + >>> T = Matrix([[Rational(5, 10), Rational(3, 10), Rational(2, 10)], [Rational(2, 10), Rational(7, 10), Rational(1, 10)], [Rational(3, 10), Rational(3, 10), Rational(4, 10)]]) + >>> Y = DiscreteMarkovChain("Y", [0, 1, 2], T) + >>> P(Eq(Y[3], Y[1]), Eq(Y[0], 0)).round(3) + 0.409 + >>> P(Gt(Y[3], Y[1]), Eq(Y[0], 0)).round(2) + 0.36 + >>> P(Le(Y[15], Y[10]), Eq(Y[8], 2)).round(7) + 0.6963328 + + Symbolic probability queries are also supported + + >>> a, b, c, d = symbols('a b c d') + >>> T = Matrix([[Rational(1, 10), Rational(4, 10), Rational(5, 10)], [Rational(3, 10), Rational(4, 10), Rational(3, 10)], [Rational(7, 10), Rational(2, 10), Rational(1, 10)]]) + >>> Y = DiscreteMarkovChain("Y", [0, 1, 2], T) + >>> query = P(Eq(Y[a], b), Eq(Y[c], d)) + >>> query.subs({a:10, b:2, c:5, d:1}).round(4) + 0.3096 + >>> P(Eq(Y[10], 2), Eq(Y[5], 1)).evalf().round(4) + 0.3096 + >>> query_gt = P(Gt(Y[a], b), Eq(Y[c], d)) + >>> query_gt.subs({a:21, b:0, c:5, d:0}).evalf().round(5) + 0.64705 + >>> P(Gt(Y[21], 0), Eq(Y[5], 0)).round(5) + 0.64705 + + There is limited support for arbitrarily sized states: + + >>> n = symbols('n', nonnegative=True, integer=True) + >>> T = MatrixSymbol('T', n, n) + >>> Y = DiscreteMarkovChain("Y", trans_probs=T) + >>> Y.state_space + Range(0, n, 1) + >>> query = P(Eq(Y[a], b), Eq(Y[c], d)) + >>> query.subs({a:10, b:2, c:5, d:1}) + (T**5)[1, 2] + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Markov_chain#Discrete-time_Markov_chain + .. [2] https://web.archive.org/web/20201230182007/https://www.dartmouth.edu/~chance/teaching_aids/books_articles/probability_book/Chapter11.pdf + """ + index_set = S.Naturals0 + + def __new__(cls, sym, state_space=None, trans_probs=None): + sym = _symbol_converter(sym) + + state_space, trans_probs = MarkovProcess._sanity_checks(state_space, trans_probs) + + obj = Basic.__new__(cls, sym, state_space, trans_probs) # type: ignore + indices = {} + if isinstance(obj.number_of_states, Integer): + for index, state in enumerate(obj._state_index): + indices[state] = index + obj.index_of = indices + return obj + + @property + def transition_probabilities(self): + """ + Transition probabilities of discrete Markov chain, + either an instance of Matrix or MatrixSymbol. + """ + return self.args[2] + + def communication_classes(self) -> list[tuple[list[Basic], Boolean, Integer]]: + """ + Returns the list of communication classes that partition + the states of the markov chain. + + A communication class is defined to be a set of states + such that every state in that set is reachable from + every other state in that set. Due to its properties + this forms a class in the mathematical sense. + Communication classes are also known as recurrence + classes. + + Returns + ======= + + classes + The ``classes`` are a list of tuples. Each + tuple represents a single communication class + with its properties. The first element in the + tuple is the list of states in the class, the + second element is whether the class is recurrent + and the third element is the period of the + communication class. + + Examples + ======== + + >>> from sympy.stats import DiscreteMarkovChain + >>> from sympy import Matrix + >>> T = Matrix([[0, 1, 0], + ... [1, 0, 0], + ... [1, 0, 0]]) + >>> X = DiscreteMarkovChain('X', [1, 2, 3], T) + >>> classes = X.communication_classes() + >>> for states, is_recurrent, period in classes: + ... states, is_recurrent, period + ([1, 2], True, 2) + ([3], False, 1) + + From this we can see that states ``1`` and ``2`` + communicate, are recurrent and have a period + of 2. We can also see state ``3`` is transient + with a period of 1. + + Notes + ===== + + The algorithm used is of order ``O(n**2)`` where + ``n`` is the number of states in the markov chain. + It uses Tarjan's algorithm to find the classes + themselves and then it uses a breadth-first search + algorithm to find each class's periodicity. + Most of the algorithm's components approach ``O(n)`` + as the matrix becomes more and more sparse. + + References + ========== + + .. [1] https://web.archive.org/web/20220207032113/https://www.columbia.edu/~ww2040/4701Sum07/4701-06-Notes-MCII.pdf + .. [2] https://cecas.clemson.edu/~shierd/Shier/markov.pdf + .. [3] https://www.proquest.com/openview/4adc6a51d8371be5b0e4c7dff287fc70/1?pq-origsite=gscholar&cbl=2026366&diss=y + .. [4] https://www.mathworks.com/help/econ/dtmc.classify.html + """ + n = self.number_of_states + T = self.transition_probabilities + + if isinstance(T, MatrixSymbol): + raise NotImplementedError("Cannot perform the operation with a symbolic matrix.") + + # begin Tarjan's algorithm + V = Range(n) + # don't use state names. Rather use state + # indexes since we use them for matrix + # indexing here and later onward + E = [(i, j) for i in V for j in V if T[i, j] != 0] + classes = strongly_connected_components((V, E)) + # end Tarjan's algorithm + + recurrence = [] + periods = [] + for class_ in classes: + # begin recurrent check (similar to self._check_trans_probs()) + submatrix = T[class_, class_] # get the submatrix with those states + is_recurrent = S.true + rows = submatrix.tolist() + for row in rows: + if (sum(row) - 1) != 0: + is_recurrent = S.false + break + recurrence.append(is_recurrent) + # end recurrent check + + # begin breadth-first search + non_tree_edge_values: set[int] = set() + visited = {class_[0]} + newly_visited = {class_[0]} + level = {class_[0]: 0} + current_level = 0 + done = False # imitate a do-while loop + while not done: # runs at most len(class_) times + done = len(visited) == len(class_) + current_level += 1 + + # this loop and the while loop above run a combined len(class_) number of times. + # so this triple nested loop runs through each of the n states once. + for i in newly_visited: + + # the loop below runs len(class_) number of times + # complexity is around about O(n * avg(len(class_))) + newly_visited = {j for j in class_ if T[i, j] != 0} + + new_tree_edges = newly_visited.difference(visited) + for j in new_tree_edges: + level[j] = current_level + + new_non_tree_edges = newly_visited.intersection(visited) + new_non_tree_edge_values = {level[i]-level[j]+1 for j in new_non_tree_edges} + + non_tree_edge_values = non_tree_edge_values.union(new_non_tree_edge_values) + visited = visited.union(new_tree_edges) + + # igcd needs at least 2 arguments + positive_ntev = {val_e for val_e in non_tree_edge_values if val_e > 0} + if len(positive_ntev) == 0: + periods.append(len(class_)) + elif len(positive_ntev) == 1: + periods.append(positive_ntev.pop()) + else: + periods.append(igcd(*positive_ntev)) + # end breadth-first search + + # convert back to the user's state names + classes = [[_sympify(self._state_index[i]) for i in class_] for class_ in classes] + return list(zip(classes, recurrence, map(Integer,periods))) + + def fundamental_matrix(self): + """ + Each entry fundamental matrix can be interpreted as + the expected number of times the chains is in state j + if it started in state i. + + References + ========== + + .. [1] https://lips.cs.princeton.edu/the-fundamental-matrix-of-a-finite-markov-chain/ + + """ + _, _, _, Q = self.decompose() + + if Q.shape[0] > 0: # if non-ergodic + I = eye(Q.shape[0]) + if (I - Q).det() == 0: + raise ValueError("The fundamental matrix doesn't exist.") + return (I - Q).inv().as_immutable() + else: # if ergodic + P = self.transition_probabilities + I = eye(P.shape[0]) + w = self.fixed_row_vector() + W = Matrix([list(w) for i in range(0, P.shape[0])]) + if (I - P + W).det() == 0: + raise ValueError("The fundamental matrix doesn't exist.") + return (I - P + W).inv().as_immutable() + + def absorbing_probabilities(self): + """ + Computes the absorbing probabilities, i.e. + the ij-th entry of the matrix denotes the + probability of Markov chain being absorbed + in state j starting from state i. + """ + _, _, R, _ = self.decompose() + N = self.fundamental_matrix() + if R is None or N is None: + return None + return N*R + + def absorbing_probabilites(self): + sympy_deprecation_warning( + """ + DiscreteMarkovChain.absorbing_probabilites() is deprecated. Use + absorbing_probabilities() instead (note the spelling difference). + """, + deprecated_since_version="1.7", + active_deprecations_target="deprecated-absorbing_probabilites", + ) + return self.absorbing_probabilities() + + def is_regular(self): + tuples = self.communication_classes() + if len(tuples) == 0: + return S.false # not defined for a 0x0 matrix + classes, _, periods = list(zip(*tuples)) + return And(len(classes) == 1, periods[0] == 1) + + def is_ergodic(self): + tuples = self.communication_classes() + if len(tuples) == 0: + return S.false # not defined for a 0x0 matrix + classes, _, _ = list(zip(*tuples)) + return S(len(classes) == 1) + + def is_absorbing_state(self, state): + trans_probs = self.transition_probabilities + if isinstance(trans_probs, ImmutableMatrix) and \ + state < trans_probs.shape[0]: + return S(trans_probs[state, state]) is S.One + + def is_absorbing_chain(self): + states, A, B, C = self.decompose() + r = A.shape[0] + return And(r > 0, A == Identity(r).as_explicit()) + + def stationary_distribution(self, condition_set=False) -> ImmutableMatrix | ConditionSet | Lambda: + r""" + The stationary distribution is any row vector, p, that solves p = pP, + is row stochastic and each element in p must be nonnegative. + That means in matrix form: :math:`(P-I)^T p^T = 0` and + :math:`(1, \dots, 1) p = 1` + where ``P`` is the one-step transition matrix. + + All time-homogeneous Markov Chains with a finite state space + have at least one stationary distribution. In addition, if + a finite time-homogeneous Markov Chain is irreducible, the + stationary distribution is unique. + + Parameters + ========== + + condition_set : bool + If the chain has a symbolic size or transition matrix, + it will return a ``Lambda`` if ``False`` and return a + ``ConditionSet`` if ``True``. + + Examples + ======== + + >>> from sympy.stats import DiscreteMarkovChain + >>> from sympy import Matrix, S + + An irreducible Markov Chain + + >>> T = Matrix([[S(1)/2, S(1)/2, 0], + ... [S(4)/5, S(1)/5, 0], + ... [1, 0, 0]]) + >>> X = DiscreteMarkovChain('X', trans_probs=T) + >>> X.stationary_distribution() + Matrix([[8/13, 5/13, 0]]) + + A reducible Markov Chain + + >>> T = Matrix([[S(1)/2, S(1)/2, 0], + ... [S(4)/5, S(1)/5, 0], + ... [0, 0, 1]]) + >>> X = DiscreteMarkovChain('X', trans_probs=T) + >>> X.stationary_distribution() + Matrix([[8/13 - 8*tau0/13, 5/13 - 5*tau0/13, tau0]]) + + >>> Y = DiscreteMarkovChain('Y') + >>> Y.stationary_distribution() + Lambda((wm, _T), Eq(wm*_T, wm)) + + >>> Y.stationary_distribution(condition_set=True) + ConditionSet(wm, Eq(wm*_T, wm)) + + References + ========== + + .. [1] https://www.probabilitycourse.com/chapter11/11_2_6_stationary_and_limiting_distributions.php + .. [2] https://web.archive.org/web/20210508104430/https://galton.uchicago.edu/~yibi/teaching/stat317/2014/Lectures/Lecture4_6up.pdf + + See Also + ======== + + sympy.stats.DiscreteMarkovChain.limiting_distribution + """ + trans_probs = self.transition_probabilities + n = self.number_of_states + + if n == 0: + return ImmutableMatrix(Matrix([[]])) + + # symbolic matrix version + if isinstance(trans_probs, MatrixSymbol): + wm = MatrixSymbol('wm', 1, n) + if condition_set: + return ConditionSet(wm, Eq(wm * trans_probs, wm)) + else: + return Lambda((wm, trans_probs), Eq(wm * trans_probs, wm)) + + # numeric matrix version + a = Matrix(trans_probs - Identity(n)).T + a[0, 0:n] = ones(1, n) # type: ignore + b = zeros(n, 1) + b[0, 0] = 1 + + soln = list(linsolve((a, b)))[0] + return ImmutableMatrix([soln]) + + def fixed_row_vector(self): + """ + A wrapper for ``stationary_distribution()``. + """ + return self.stationary_distribution() + + @property + def limiting_distribution(self): + """ + The fixed row vector is the limiting + distribution of a discrete Markov chain. + """ + return self.fixed_row_vector() + + def decompose(self) -> tuple[list[Basic], ImmutableMatrix, ImmutableMatrix, ImmutableMatrix]: + """ + Decomposes the transition matrix into submatrices with + special properties. + + The transition matrix can be decomposed into 4 submatrices: + - A - the submatrix from recurrent states to recurrent states. + - B - the submatrix from transient to recurrent states. + - C - the submatrix from transient to transient states. + - O - the submatrix of zeros for recurrent to transient states. + + Returns + ======= + + states, A, B, C + ``states`` - a list of state names with the first being + the recurrent states and the last being + the transient states in the order + of the row names of A and then the row names of C. + ``A`` - the submatrix from recurrent states to recurrent states. + ``B`` - the submatrix from transient to recurrent states. + ``C`` - the submatrix from transient to transient states. + + Examples + ======== + + >>> from sympy.stats import DiscreteMarkovChain + >>> from sympy import Matrix, S + + One can decompose this chain for example: + + >>> T = Matrix([[S(1)/2, S(1)/2, 0, 0, 0], + ... [S(2)/5, S(1)/5, S(2)/5, 0, 0], + ... [0, 0, 1, 0, 0], + ... [0, 0, S(1)/2, S(1)/2, 0], + ... [S(1)/2, 0, 0, 0, S(1)/2]]) + >>> X = DiscreteMarkovChain('X', trans_probs=T) + >>> states, A, B, C = X.decompose() + >>> states + [2, 0, 1, 3, 4] + + >>> A # recurrent to recurrent + Matrix([[1]]) + + >>> B # transient to recurrent + Matrix([ + [ 0], + [2/5], + [1/2], + [ 0]]) + + >>> C # transient to transient + Matrix([ + [1/2, 1/2, 0, 0], + [2/5, 1/5, 0, 0], + [ 0, 0, 1/2, 0], + [1/2, 0, 0, 1/2]]) + + This means that state 2 is the only absorbing state + (since A is a 1x1 matrix). B is a 4x1 matrix since + the 4 remaining transient states all merge into recurrent + state 2. And C is the 4x4 matrix that shows how the + transient states 0, 1, 3, 4 all interact. + + See Also + ======== + + sympy.stats.DiscreteMarkovChain.communication_classes + sympy.stats.DiscreteMarkovChain.canonical_form + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Absorbing_Markov_chain + .. [2] https://people.brandeis.edu/~igusa/Math56aS08/Math56a_S08_notes015.pdf + """ + trans_probs = self.transition_probabilities + + classes = self.communication_classes() + r_states = [] + t_states = [] + + for states, recurrent, period in classes: + if recurrent: + r_states += states + else: + t_states += states + + states = r_states + t_states + indexes = [self.index_of[state] for state in states] # type: ignore + + A = Matrix(len(r_states), len(r_states), + lambda i, j: trans_probs[indexes[i], indexes[j]]) + + B = Matrix(len(t_states), len(r_states), + lambda i, j: trans_probs[indexes[len(r_states) + i], indexes[j]]) + + C = Matrix(len(t_states), len(t_states), + lambda i, j: trans_probs[indexes[len(r_states) + i], indexes[len(r_states) + j]]) + + return states, A.as_immutable(), B.as_immutable(), C.as_immutable() + + def canonical_form(self) -> tuple[list[Basic], ImmutableMatrix]: + """ + Reorders the one-step transition matrix + so that recurrent states appear first and transient + states appear last. Other representations include inserting + transient states first and recurrent states last. + + Returns + ======= + + states, P_new + ``states`` is the list that describes the order of the + new states in the matrix + so that the ith element in ``states`` is the state of the + ith row of A. + ``P_new`` is the new transition matrix in canonical form. + + Examples + ======== + + >>> from sympy.stats import DiscreteMarkovChain + >>> from sympy import Matrix, S + + You can convert your chain into canonical form: + + >>> T = Matrix([[S(1)/2, S(1)/2, 0, 0, 0], + ... [S(2)/5, S(1)/5, S(2)/5, 0, 0], + ... [0, 0, 1, 0, 0], + ... [0, 0, S(1)/2, S(1)/2, 0], + ... [S(1)/2, 0, 0, 0, S(1)/2]]) + >>> X = DiscreteMarkovChain('X', list(range(1, 6)), trans_probs=T) + >>> states, new_matrix = X.canonical_form() + >>> states + [3, 1, 2, 4, 5] + + >>> new_matrix + Matrix([ + [ 1, 0, 0, 0, 0], + [ 0, 1/2, 1/2, 0, 0], + [2/5, 2/5, 1/5, 0, 0], + [1/2, 0, 0, 1/2, 0], + [ 0, 1/2, 0, 0, 1/2]]) + + The new states are [3, 1, 2, 4, 5] and you can + create a new chain with this and its canonical + form will remain the same (since it is already + in canonical form). + + >>> X = DiscreteMarkovChain('X', states, new_matrix) + >>> states, new_matrix = X.canonical_form() + >>> states + [3, 1, 2, 4, 5] + + >>> new_matrix + Matrix([ + [ 1, 0, 0, 0, 0], + [ 0, 1/2, 1/2, 0, 0], + [2/5, 2/5, 1/5, 0, 0], + [1/2, 0, 0, 1/2, 0], + [ 0, 1/2, 0, 0, 1/2]]) + + This is not limited to absorbing chains: + + >>> T = Matrix([[0, 5, 5, 0, 0], + ... [0, 0, 0, 10, 0], + ... [5, 0, 5, 0, 0], + ... [0, 10, 0, 0, 0], + ... [0, 3, 0, 3, 4]])/10 + >>> X = DiscreteMarkovChain('X', trans_probs=T) + >>> states, new_matrix = X.canonical_form() + >>> states + [1, 3, 0, 2, 4] + + >>> new_matrix + Matrix([ + [ 0, 1, 0, 0, 0], + [ 1, 0, 0, 0, 0], + [ 1/2, 0, 0, 1/2, 0], + [ 0, 0, 1/2, 1/2, 0], + [3/10, 3/10, 0, 0, 2/5]]) + + See Also + ======== + + sympy.stats.DiscreteMarkovChain.communication_classes + sympy.stats.DiscreteMarkovChain.decompose + + References + ========== + + .. [1] https://onlinelibrary.wiley.com/doi/pdf/10.1002/9780470316887.app1 + .. [2] http://www.columbia.edu/~ww2040/6711F12/lect1023big.pdf + """ + states, A, B, C = self.decompose() + O = zeros(A.shape[0], C.shape[1]) + return states, BlockMatrix([[A, O], [B, C]]).as_explicit() + + def sample(self): + """ + Returns + ======= + + sample: iterator object + iterator object containing the sample + + """ + if not isinstance(self.transition_probabilities, (Matrix, ImmutableMatrix)): + raise ValueError("Transition Matrix must be provided for sampling") + Tlist = self.transition_probabilities.tolist() + samps = [random.choice(list(self.state_space))] + yield samps[0] + time = 1 + densities = {} + for state in self.state_space: + states = list(self.state_space) + densities[state] = {states[i]: Tlist[state][i] + for i in range(len(states))} + while time < S.Infinity: + samps.append((next(sample_iter(FiniteRV("_", densities[samps[time - 1]]))))) + yield samps[time] + time += 1 + +class ContinuousMarkovChain(ContinuousTimeStochasticProcess, MarkovProcess): + """ + Represents continuous time Markov chain. + + Parameters + ========== + + sym : Symbol/str + state_space : Set + Optional, by default, S.Reals + gen_mat : Matrix/ImmutableMatrix/MatrixSymbol + Optional, by default, None + + Examples + ======== + + >>> from sympy.stats import ContinuousMarkovChain, P + >>> from sympy import Matrix, S, Eq, Gt + >>> G = Matrix([[-S(1), S(1)], [S(1), -S(1)]]) + >>> C = ContinuousMarkovChain('C', state_space=[0, 1], gen_mat=G) + >>> C.limiting_distribution() + Matrix([[1/2, 1/2]]) + >>> C.state_space + {0, 1} + >>> C.generator_matrix + Matrix([ + [-1, 1], + [ 1, -1]]) + + Probability queries are supported + + >>> P(Eq(C(1.96), 0), Eq(C(0.78), 1)).round(5) + 0.45279 + >>> P(Gt(C(1.7), 0), Eq(C(0.82), 1)).round(5) + 0.58602 + + Probability of expressions with multiple RandomIndexedSymbols + can also be calculated provided there is only 1 RandomIndexedSymbol + in the given condition. It is always better to use Rational instead + of floating point numbers for the probabilities in the + generator matrix to avoid errors. + + >>> from sympy import Gt, Le, Rational + >>> G = Matrix([[-S(1), Rational(1, 10), Rational(9, 10)], [Rational(2, 5), -S(1), Rational(3, 5)], [Rational(1, 2), Rational(1, 2), -S(1)]]) + >>> C = ContinuousMarkovChain('C', state_space=[0, 1, 2], gen_mat=G) + >>> P(Eq(C(3.92), C(1.75)), Eq(C(0.46), 0)).round(5) + 0.37933 + >>> P(Gt(C(3.92), C(1.75)), Eq(C(0.46), 0)).round(5) + 0.34211 + >>> P(Le(C(1.57), C(3.14)), Eq(C(1.22), 1)).round(4) + 0.7143 + + Symbolic probability queries are also supported + + >>> from sympy import symbols + >>> a,b,c,d = symbols('a b c d') + >>> G = Matrix([[-S(1), Rational(1, 10), Rational(9, 10)], [Rational(2, 5), -S(1), Rational(3, 5)], [Rational(1, 2), Rational(1, 2), -S(1)]]) + >>> C = ContinuousMarkovChain('C', state_space=[0, 1, 2], gen_mat=G) + >>> query = P(Eq(C(a), b), Eq(C(c), d)) + >>> query.subs({a:3.65, b:2, c:1.78, d:1}).evalf().round(10) + 0.4002723175 + >>> P(Eq(C(3.65), 2), Eq(C(1.78), 1)).round(10) + 0.4002723175 + >>> query_gt = P(Gt(C(a), b), Eq(C(c), d)) + >>> query_gt.subs({a:43.2, b:0, c:3.29, d:2}).evalf().round(10) + 0.6832579186 + >>> P(Gt(C(43.2), 0), Eq(C(3.29), 2)).round(10) + 0.6832579186 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Markov_chain#Continuous-time_Markov_chain + .. [2] https://u.math.biu.ac.il/~amirgi/CTMCnotes.pdf + """ + index_set = S.Reals + + def __new__(cls, sym, state_space=None, gen_mat=None): + sym = _symbol_converter(sym) + state_space, gen_mat = MarkovProcess._sanity_checks(state_space, gen_mat) + obj = Basic.__new__(cls, sym, state_space, gen_mat) + indices = {} + if isinstance(obj.number_of_states, Integer): + for index, state in enumerate(obj.state_space): + indices[state] = index + obj.index_of = indices + return obj + + @property + def generator_matrix(self): + return self.args[2] + + @cacheit + def transition_probabilities(self, gen_mat=None): + t = Dummy('t') + if isinstance(gen_mat, (Matrix, ImmutableMatrix)) and \ + gen_mat.is_diagonalizable(): + # for faster computation use diagonalized generator matrix + Q, D = gen_mat.diagonalize() + return Lambda(t, Q*exp(t*D)*Q.inv()) + if gen_mat != None: + return Lambda(t, exp(t*gen_mat)) + + def limiting_distribution(self): + gen_mat = self.generator_matrix + if gen_mat is None: + return None + if isinstance(gen_mat, MatrixSymbol): + wm = MatrixSymbol('wm', 1, gen_mat.shape[0]) + return Lambda((wm, gen_mat), Eq(wm*gen_mat, wm)) + w = IndexedBase('w') + wi = [w[i] for i in range(gen_mat.shape[0])] + wm = Matrix([wi]) + eqs = (wm*gen_mat).tolist()[0] + eqs.append(sum(wi) - 1) + soln = list(linsolve(eqs, wi))[0] + return ImmutableMatrix([soln]) + + +class BernoulliProcess(DiscreteTimeStochasticProcess): + """ + The Bernoulli process consists of repeated + independent Bernoulli process trials with the same parameter `p`. + It's assumed that the probability `p` applies to every + trial and that the outcomes of each trial + are independent of all the rest. Therefore Bernoulli Process + is Discrete State and Discrete Time Stochastic Process. + + Parameters + ========== + + sym : Symbol/str + success : Integer/str + The event which is considered to be success. Default: 1. + failure: Integer/str + The event which is considered to be failure. Default: 0. + p : Real Number between 0 and 1 + Represents the probability of getting success. + + Examples + ======== + + >>> from sympy.stats import BernoulliProcess, P, E + >>> from sympy import Eq, Gt + >>> B = BernoulliProcess("B", p=0.7, success=1, failure=0) + >>> B.state_space + {0, 1} + >>> B.p.round(2) + 0.70 + >>> B.success + 1 + >>> B.failure + 0 + >>> X = B[1] + B[2] + B[3] + >>> P(Eq(X, 0)).round(2) + 0.03 + >>> P(Eq(X, 2)).round(2) + 0.44 + >>> P(Eq(X, 4)).round(2) + 0 + >>> P(Gt(X, 1)).round(2) + 0.78 + >>> P(Eq(B[1], 0) & Eq(B[2], 1) & Eq(B[3], 0) & Eq(B[4], 1)).round(2) + 0.04 + >>> B.joint_distribution(B[1], B[2]) + JointDistributionHandmade(Lambda((B[1], B[2]), Piecewise((0.7, Eq(B[1], 1)), + (0.3, Eq(B[1], 0)), (0, True))*Piecewise((0.7, Eq(B[2], 1)), (0.3, Eq(B[2], 0)), + (0, True)))) + >>> E(2*B[1] + B[2]).round(2) + 2.10 + >>> P(B[1] < 1).round(2) + 0.30 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Bernoulli_process + .. [2] https://mathcs.clarku.edu/~djoyce/ma217/bernoulli.pdf + + """ + + index_set = S.Naturals0 + + def __new__(cls, sym, p, success=1, failure=0): + _value_check(p >= 0 and p <= 1, 'Value of p must be between 0 and 1.') + sym = _symbol_converter(sym) + p = _sympify(p) + success = _sym_sympify(success) + failure = _sym_sympify(failure) + return Basic.__new__(cls, sym, p, success, failure) + + @property + def symbol(self): + return self.args[0] + + @property + def p(self): + return self.args[1] + + @property + def success(self): + return self.args[2] + + @property + def failure(self): + return self.args[3] + + @property + def state_space(self): + return _set_converter([self.success, self.failure]) + + def distribution(self, key=None): + if key is None: + self._deprecation_warn_distribution() + return BernoulliDistribution(self.p) + return BernoulliDistribution(self.p, self.success, self.failure) + + def simple_rv(self, rv): + return Bernoulli(rv.name, p=self.p, + succ=self.success, fail=self.failure) + + def expectation(self, expr, condition=None, evaluate=True, **kwargs): + """ + Computes expectation. + + Parameters + ========== + + expr : RandomIndexedSymbol, Relational, Logic + Condition for which expectation has to be computed. Must + contain a RandomIndexedSymbol of the process. + condition : Relational, Logic + The given conditions under which computations should be done. + + Returns + ======= + + Expectation of the RandomIndexedSymbol. + + """ + + return _SubstituteRV._expectation(expr, condition, evaluate, **kwargs) + + def probability(self, condition, given_condition=None, evaluate=True, **kwargs): + """ + Computes probability. + + Parameters + ========== + + condition : Relational + Condition for which probability has to be computed. Must + contain a RandomIndexedSymbol of the process. + given_condition : Relational, Logic + The given conditions under which computations should be done. + + Returns + ======= + + Probability of the condition. + + """ + + return _SubstituteRV._probability(condition, given_condition, evaluate, **kwargs) + + def density(self, x): + return Piecewise((self.p, Eq(x, self.success)), + (1 - self.p, Eq(x, self.failure)), + (S.Zero, True)) + +class _SubstituteRV: + """ + Internal class to handle the queries of expectation and probability + by substitution. + """ + + @staticmethod + def _rvindexed_subs(expr, condition=None): + """ + Substitutes the RandomIndexedSymbol with the RandomSymbol with + same name, distribution and probability as RandomIndexedSymbol. + + Parameters + ========== + + expr: RandomIndexedSymbol, Relational, Logic + Condition for which expectation has to be computed. Must + contain a RandomIndexedSymbol of the process. + condition: Relational, Logic + The given conditions under which computations should be done. + + """ + + rvs_expr = random_symbols(expr) + if len(rvs_expr) != 0: + swapdict_expr = {} + for rv in rvs_expr: + if isinstance(rv, RandomIndexedSymbol): + newrv = rv.pspace.process.simple_rv(rv) # substitute with equivalent simple rv + swapdict_expr[rv] = newrv + expr = expr.subs(swapdict_expr) + rvs_cond = random_symbols(condition) + if len(rvs_cond)!=0: + swapdict_cond = {} + for rv in rvs_cond: + if isinstance(rv, RandomIndexedSymbol): + newrv = rv.pspace.process.simple_rv(rv) + swapdict_cond[rv] = newrv + condition = condition.subs(swapdict_cond) + return expr, condition + + @classmethod + def _expectation(self, expr, condition=None, evaluate=True, **kwargs): + """ + Internal method for computing expectation of indexed RV. + + Parameters + ========== + + expr: RandomIndexedSymbol, Relational, Logic + Condition for which expectation has to be computed. Must + contain a RandomIndexedSymbol of the process. + condition: Relational, Logic + The given conditions under which computations should be done. + + Returns + ======= + + Expectation of the RandomIndexedSymbol. + + """ + new_expr, new_condition = self._rvindexed_subs(expr, condition) + + if not is_random(new_expr): + return new_expr + new_pspace = pspace(new_expr) + if new_condition is not None: + new_expr = given(new_expr, new_condition) + if new_expr.is_Add: # As E is Linear + return Add(*[new_pspace.compute_expectation( + expr=arg, evaluate=evaluate, **kwargs) + for arg in new_expr.args]) + return new_pspace.compute_expectation( + new_expr, evaluate=evaluate, **kwargs) + + @classmethod + def _probability(self, condition, given_condition=None, evaluate=True, **kwargs): + """ + Internal method for computing probability of indexed RV + + Parameters + ========== + + condition: Relational + Condition for which probability has to be computed. Must + contain a RandomIndexedSymbol of the process. + given_condition: Relational/And + The given conditions under which computations should be done. + + Returns + ======= + + Probability of the condition. + + """ + new_condition, new_givencondition = self._rvindexed_subs(condition, given_condition) + + if isinstance(new_givencondition, RandomSymbol): + condrv = random_symbols(new_condition) + if len(condrv) == 1 and condrv[0] == new_givencondition: + return BernoulliDistribution(self._probability(new_condition), 0, 1) + + if any(dependent(rv, new_givencondition) for rv in condrv): + return Probability(new_condition, new_givencondition) + else: + return self._probability(new_condition) + + if new_givencondition is not None and \ + not isinstance(new_givencondition, (Relational, Boolean)): + raise ValueError("%s is not a relational or combination of relationals" + % (new_givencondition)) + if new_givencondition == False or new_condition == False: + return S.Zero + if new_condition == True: + return S.One + if not isinstance(new_condition, (Relational, Boolean)): + raise ValueError("%s is not a relational or combination of relationals" + % (new_condition)) + + if new_givencondition is not None: # If there is a condition + # Recompute on new conditional expr + return self._probability(given(new_condition, new_givencondition, **kwargs), **kwargs) + result = pspace(new_condition).probability(new_condition, **kwargs) + if evaluate and hasattr(result, 'doit'): + return result.doit() + else: + return result + +def get_timerv_swaps(expr, condition): + """ + Finds the appropriate interval for each time stamp in expr by parsing + the given condition and returns intervals for each timestamp and + dictionary that maps variable time-stamped Random Indexed Symbol to its + corresponding Random Indexed variable with fixed time stamp. + + Parameters + ========== + + expr: SymPy Expression + Expression containing Random Indexed Symbols with variable time stamps + condition: Relational/Boolean Expression + Expression containing time bounds of variable time stamps in expr + + Examples + ======== + + >>> from sympy.stats.stochastic_process_types import get_timerv_swaps, PoissonProcess + >>> from sympy import symbols, Contains, Interval + >>> x, t, d = symbols('x t d', positive=True) + >>> X = PoissonProcess("X", 3) + >>> get_timerv_swaps(x*X(t), Contains(t, Interval.Lopen(0, 1))) + ([Interval.Lopen(0, 1)], {X(t): X(1)}) + >>> get_timerv_swaps((X(t)**2 + X(d)**2), Contains(t, Interval.Lopen(0, 1)) + ... & Contains(d, Interval.Ropen(1, 4))) # doctest: +SKIP + ([Interval.Ropen(1, 4), Interval.Lopen(0, 1)], {X(d): X(3), X(t): X(1)}) + + Returns + ======= + + intervals: list + List of Intervals/FiniteSet on which each time stamp is defined + rv_swap: dict + Dictionary mapping variable time Random Indexed Symbol to constant time + Random Indexed Variable + + """ + + if not isinstance(condition, (Relational, Boolean)): + raise ValueError("%s is not a relational or combination of relationals" + % (condition)) + expr_syms = list(expr.atoms(RandomIndexedSymbol)) + if isinstance(condition, (And, Or)): + given_cond_args = condition.args + else: # single condition + given_cond_args = (condition, ) + rv_swap = {} + intervals = [] + for expr_sym in expr_syms: + for arg in given_cond_args: + if arg.has(expr_sym.key) and isinstance(expr_sym.key, Symbol): + intv = _set_converter(arg.args[1]) + diff_key = intv._sup - intv._inf + if diff_key == oo: + raise ValueError("%s should have finite bounds" % str(expr_sym.name)) + elif diff_key == S.Zero: # has singleton set + diff_key = intv._sup + rv_swap[expr_sym] = expr_sym.subs({expr_sym.key: diff_key}) + intervals.append(intv) + return intervals, rv_swap + + +class CountingProcess(ContinuousTimeStochasticProcess): + """ + This class handles the common methods of the Counting Processes + such as Poisson, Wiener and Gamma Processes + """ + index_set = _set_converter(Interval(0, oo)) + + @property + def symbol(self): + return self.args[0] + + def expectation(self, expr, condition=None, evaluate=True, **kwargs): + """ + Computes expectation + + Parameters + ========== + + expr: RandomIndexedSymbol, Relational, Logic + Condition for which expectation has to be computed. Must + contain a RandomIndexedSymbol of the process. + condition: Relational, Boolean + The given conditions under which computations should be done, i.e, + the intervals on which each variable time stamp in expr is defined + + Returns + ======= + + Expectation of the given expr + + """ + if condition is not None: + intervals, rv_swap = get_timerv_swaps(expr, condition) + # they are independent when they have non-overlapping intervals + if len(intervals) == 1 or all(Intersection(*intv_comb) == EmptySet + for intv_comb in itertools.combinations(intervals, 2)): + if expr.is_Add: + return Add.fromiter(self.expectation(arg, condition) + for arg in expr.args) + expr = expr.subs(rv_swap) + else: + return Expectation(expr, condition) + + return _SubstituteRV._expectation(expr, evaluate=evaluate, **kwargs) + + def _solve_argwith_tworvs(self, arg): + if arg.args[0].key >= arg.args[1].key or isinstance(arg, Eq): + diff_key = abs(arg.args[0].key - arg.args[1].key) + rv = arg.args[0] + arg = arg.__class__(rv.pspace.process(diff_key), 0) + else: + diff_key = arg.args[1].key - arg.args[0].key + rv = arg.args[1] + arg = arg.__class__(rv.pspace.process(diff_key), 0) + return arg + + def _solve_numerical(self, condition, given_condition=None): + if isinstance(condition, And): + args_list = list(condition.args) + else: + args_list = [condition] + if given_condition is not None: + if isinstance(given_condition, And): + args_list.extend(list(given_condition.args)) + else: + args_list.extend([given_condition]) + # sort the args based on timestamp to get the independent increments in + # each segment using all the condition args as well as given_condition args + args_list = sorted(args_list, key=lambda x: x.args[0].key) + result = [] + cond_args = list(condition.args) if isinstance(condition, And) else [condition] + if args_list[0] in cond_args and not (is_random(args_list[0].args[0]) + and is_random(args_list[0].args[1])): + result.append(_SubstituteRV._probability(args_list[0])) + + if is_random(args_list[0].args[0]) and is_random(args_list[0].args[1]): + arg = self._solve_argwith_tworvs(args_list[0]) + result.append(_SubstituteRV._probability(arg)) + + for i in range(len(args_list) - 1): + curr, nex = args_list[i], args_list[i + 1] + diff_key = nex.args[0].key - curr.args[0].key + working_set = curr.args[0].pspace.process.state_space + if curr.args[1] > nex.args[1]: #impossible condition so return 0 + result.append(0) + break + if isinstance(curr, Eq): + working_set = Intersection(working_set, Interval.Lopen(curr.args[1], oo)) + else: + working_set = Intersection(working_set, curr.as_set()) + if isinstance(nex, Eq): + working_set = Intersection(working_set, Interval(-oo, nex.args[1])) + else: + working_set = Intersection(working_set, nex.as_set()) + if working_set == EmptySet: + rv = Eq(curr.args[0].pspace.process(diff_key), 0) + result.append(_SubstituteRV._probability(rv)) + else: + if working_set.is_finite_set: + if isinstance(curr, Eq) and isinstance(nex, Eq): + rv = Eq(curr.args[0].pspace.process(diff_key), len(working_set)) + result.append(_SubstituteRV._probability(rv)) + elif isinstance(curr, Eq) ^ isinstance(nex, Eq): + result.append(Add.fromiter(_SubstituteRV._probability(Eq( + curr.args[0].pspace.process(diff_key), x)) + for x in range(len(working_set)))) + else: + n = len(working_set) + result.append(Add.fromiter((n - x)*_SubstituteRV._probability(Eq( + curr.args[0].pspace.process(diff_key), x)) for x in range(n))) + else: + result.append(_SubstituteRV._probability( + curr.args[0].pspace.process(diff_key) <= working_set._sup - working_set._inf)) + return Mul.fromiter(result) + + + def probability(self, condition, given_condition=None, evaluate=True, **kwargs): + """ + Computes probability. + + Parameters + ========== + + condition: Relational + Condition for which probability has to be computed. Must + contain a RandomIndexedSymbol of the process. + given_condition: Relational, Boolean + The given conditions under which computations should be done, i.e, + the intervals on which each variable time stamp in expr is defined + + Returns + ======= + + Probability of the condition + + """ + check_numeric = True + if isinstance(condition, (And, Or)): + cond_args = condition.args + else: + cond_args = (condition, ) + # check that condition args are numeric or not + if not all(arg.args[0].key.is_number for arg in cond_args): + check_numeric = False + if given_condition is not None: + check_given_numeric = True + if isinstance(given_condition, (And, Or)): + given_cond_args = given_condition.args + else: + given_cond_args = (given_condition, ) + # check that given condition args are numeric or not + if given_condition.has(Contains): + check_given_numeric = False + # Handle numerical queries + if check_numeric and check_given_numeric: + res = [] + if isinstance(condition, Or): + res.append(Add.fromiter(self._solve_numerical(arg, given_condition) + for arg in condition.args)) + if isinstance(given_condition, Or): + res.append(Add.fromiter(self._solve_numerical(condition, arg) + for arg in given_condition.args)) + if res: + return Add.fromiter(res) + return self._solve_numerical(condition, given_condition) + + # No numeric queries, go by Contains?... then check that all the + # given condition are in form of `Contains` + if not all(arg.has(Contains) for arg in given_cond_args): + raise ValueError("If given condition is passed with `Contains`, then " + "please pass the evaluated condition with its corresponding information " + "in terms of intervals of each time stamp to be passed in given condition.") + + intervals, rv_swap = get_timerv_swaps(condition, given_condition) + # they are independent when they have non-overlapping intervals + if len(intervals) == 1 or all(Intersection(*intv_comb) == EmptySet + for intv_comb in itertools.combinations(intervals, 2)): + if isinstance(condition, And): + return Mul.fromiter(self.probability(arg, given_condition) + for arg in condition.args) + elif isinstance(condition, Or): + return Add.fromiter(self.probability(arg, given_condition) + for arg in condition.args) + condition = condition.subs(rv_swap) + else: + return Probability(condition, given_condition) + if check_numeric: + return self._solve_numerical(condition) + return _SubstituteRV._probability(condition, evaluate=evaluate, **kwargs) + +class PoissonProcess(CountingProcess): + """ + The Poisson process is a counting process. It is usually used in scenarios + where we are counting the occurrences of certain events that appear + to happen at a certain rate, but completely at random. + + Parameters + ========== + + sym : Symbol/str + lamda : Positive number + Rate of the process, ``lambda > 0`` + + Examples + ======== + + >>> from sympy.stats import PoissonProcess, P, E + >>> from sympy import symbols, Eq, Ne, Contains, Interval + >>> X = PoissonProcess("X", lamda=3) + >>> X.state_space + Naturals0 + >>> X.lamda + 3 + >>> t1, t2 = symbols('t1 t2', positive=True) + >>> P(X(t1) < 4) + (9*t1**3/2 + 9*t1**2/2 + 3*t1 + 1)*exp(-3*t1) + >>> P(Eq(X(t1), 2) | Ne(X(t1), 4), Contains(t1, Interval.Ropen(2, 4))) + 1 - 36*exp(-6) + >>> P(Eq(X(t1), 2) & Eq(X(t2), 3), Contains(t1, Interval.Lopen(0, 2)) + ... & Contains(t2, Interval.Lopen(2, 4))) + 648*exp(-12) + >>> E(X(t1)) + 3*t1 + >>> E(X(t1)**2 + 2*X(t2), Contains(t1, Interval.Lopen(0, 1)) + ... & Contains(t2, Interval.Lopen(1, 2))) + 18 + >>> P(X(3) < 1, Eq(X(1), 0)) + exp(-6) + >>> P(Eq(X(4), 3), Eq(X(2), 3)) + exp(-6) + >>> P(X(2) <= 3, X(1) > 1) + 5*exp(-3) + + Merging two Poisson Processes + + >>> Y = PoissonProcess("Y", lamda=4) + >>> Z = X + Y + >>> Z.lamda + 7 + + Splitting a Poisson Process into two independent Poisson Processes + + >>> N, M = Z.split(l1=2, l2=5) + >>> N.lamda, M.lamda + (2, 5) + + References + ========== + + .. [1] https://www.probabilitycourse.com/chapter11/11_0_0_intro.php + .. [2] https://en.wikipedia.org/wiki/Poisson_point_process + + """ + + def __new__(cls, sym, lamda): + _value_check(lamda > 0, 'lamda should be a positive number.') + sym = _symbol_converter(sym) + lamda = _sympify(lamda) + return Basic.__new__(cls, sym, lamda) + + @property + def lamda(self): + return self.args[1] + + @property + def state_space(self): + return S.Naturals0 + + def distribution(self, key): + if isinstance(key, RandomIndexedSymbol): + self._deprecation_warn_distribution() + return PoissonDistribution(self.lamda*key.key) + return PoissonDistribution(self.lamda*key) + + def density(self, x): + return (self.lamda*x.key)**x / factorial(x) * exp(-(self.lamda*x.key)) + + def simple_rv(self, rv): + return Poisson(rv.name, lamda=self.lamda*rv.key) + + def __add__(self, other): + if not isinstance(other, PoissonProcess): + raise ValueError("Only instances of Poisson Process can be merged") + return PoissonProcess(Dummy(self.symbol.name + other.symbol.name), + self.lamda + other.lamda) + + def split(self, l1, l2): + if _sympify(l1 + l2) != self.lamda: + raise ValueError("Sum of l1 and l2 should be %s" % str(self.lamda)) + return PoissonProcess(Dummy("l1"), l1), PoissonProcess(Dummy("l2"), l2) + +class WienerProcess(CountingProcess): + """ + The Wiener process is a real valued continuous-time stochastic process. + In physics it is used to study Brownian motion and it is often also called + Brownian motion due to its historical connection with physical process of the + same name originally observed by Scottish botanist Robert Brown. + + Parameters + ========== + + sym : Symbol/str + + Examples + ======== + + >>> from sympy.stats import WienerProcess, P, E + >>> from sympy import symbols, Contains, Interval + >>> X = WienerProcess("X") + >>> X.state_space + Reals + >>> t1, t2 = symbols('t1 t2', positive=True) + >>> P(X(t1) < 7).simplify() + erf(7*sqrt(2)/(2*sqrt(t1)))/2 + 1/2 + >>> P((X(t1) > 2) | (X(t1) < 4), Contains(t1, Interval.Ropen(2, 4))).simplify() + -erf(1)/2 + erf(2)/2 + 1 + >>> E(X(t1)) + 0 + >>> E(X(t1) + 2*X(t2), Contains(t1, Interval.Lopen(0, 1)) + ... & Contains(t2, Interval.Lopen(1, 2))) + 0 + + References + ========== + + .. [1] https://www.probabilitycourse.com/chapter11/11_4_0_brownian_motion_wiener_process.php + .. [2] https://en.wikipedia.org/wiki/Wiener_process + + """ + def __new__(cls, sym): + sym = _symbol_converter(sym) + return Basic.__new__(cls, sym) + + @property + def state_space(self): + return S.Reals + + def distribution(self, key): + if isinstance(key, RandomIndexedSymbol): + self._deprecation_warn_distribution() + return NormalDistribution(0, sqrt(key.key)) + return NormalDistribution(0, sqrt(key)) + + def density(self, x): + return exp(-x**2/(2*x.key)) / (sqrt(2*pi)*sqrt(x.key)) + + def simple_rv(self, rv): + return Normal(rv.name, 0, sqrt(rv.key)) + + +class GammaProcess(CountingProcess): + r""" + A Gamma process is a random process with independent gamma distributed + increments. It is a pure-jump increasing Levy process. + + Parameters + ========== + + sym : Symbol/str + lamda : Positive number + Jump size of the process, ``lamda > 0`` + gamma : Positive number + Rate of jump arrivals, `\gamma > 0` + + Examples + ======== + + >>> from sympy.stats import GammaProcess, E, P, variance + >>> from sympy import symbols, Contains, Interval, Not + >>> t, d, x, l, g = symbols('t d x l g', positive=True) + >>> X = GammaProcess("X", l, g) + >>> E(X(t)) + g*t/l + >>> variance(X(t)).simplify() + g*t/l**2 + >>> X = GammaProcess('X', 1, 2) + >>> P(X(t) < 1).simplify() + lowergamma(2*t, 1)/gamma(2*t) + >>> P(Not((X(t) < 5) & (X(d) > 3)), Contains(t, Interval.Ropen(2, 4)) & + ... Contains(d, Interval.Lopen(7, 8))).simplify() + -4*exp(-3) + 472*exp(-8)/3 + 1 + >>> E(X(2) + x*E(X(5))) + 10*x + 4 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gamma_process + + """ + def __new__(cls, sym, lamda, gamma): + _value_check(lamda > 0, 'lamda should be a positive number') + _value_check(gamma > 0, 'gamma should be a positive number') + sym = _symbol_converter(sym) + gamma = _sympify(gamma) + lamda = _sympify(lamda) + return Basic.__new__(cls, sym, lamda, gamma) + + @property + def lamda(self): + return self.args[1] + + @property + def gamma(self): + return self.args[2] + + @property + def state_space(self): + return _set_converter(Interval(0, oo)) + + def distribution(self, key): + if isinstance(key, RandomIndexedSymbol): + self._deprecation_warn_distribution() + return GammaDistribution(self.gamma*key.key, 1/self.lamda) + return GammaDistribution(self.gamma*key, 1/self.lamda) + + def density(self, x): + k = self.gamma*x.key + theta = 1/self.lamda + return x**(k - 1) * exp(-x/theta) / (gamma(k)*theta**k) + + def simple_rv(self, rv): + return Gamma(rv.name, self.gamma*rv.key, 1/self.lamda) diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/symbolic_multivariate_probability.py b/.venv/lib/python3.13/site-packages/sympy/stats/symbolic_multivariate_probability.py new file mode 100644 index 0000000000000000000000000000000000000000..bbe8776e58e82489e29734cea48c9138bc512f34 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/symbolic_multivariate_probability.py @@ -0,0 +1,308 @@ +import itertools + +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.function import expand as _expand +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.matrices.exceptions import ShapeError +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.matrices.expressions.matmul import MatMul +from sympy.matrices.expressions.special import ZeroMatrix +from sympy.stats.rv import RandomSymbol, is_random +from sympy.core.sympify import _sympify +from sympy.stats.symbolic_probability import Variance, Covariance, Expectation + + +class ExpectationMatrix(Expectation, MatrixExpr): + """ + Expectation of a random matrix expression. + + Examples + ======== + + >>> from sympy.stats import ExpectationMatrix, Normal + >>> from sympy.stats.rv import RandomMatrixSymbol + >>> from sympy import symbols, MatrixSymbol, Matrix + >>> k = symbols("k") + >>> A, B = MatrixSymbol("A", k, k), MatrixSymbol("B", k, k) + >>> X, Y = RandomMatrixSymbol("X", k, 1), RandomMatrixSymbol("Y", k, 1) + >>> ExpectationMatrix(X) + ExpectationMatrix(X) + >>> ExpectationMatrix(A*X).shape + (k, 1) + + To expand the expectation in its expression, use ``expand()``: + + >>> ExpectationMatrix(A*X + B*Y).expand() + A*ExpectationMatrix(X) + B*ExpectationMatrix(Y) + >>> ExpectationMatrix((X + Y)*(X - Y).T).expand() + ExpectationMatrix(X*X.T) - ExpectationMatrix(X*Y.T) + ExpectationMatrix(Y*X.T) - ExpectationMatrix(Y*Y.T) + + To evaluate the ``ExpectationMatrix``, use ``doit()``: + + >>> N11, N12 = Normal('N11', 11, 1), Normal('N12', 12, 1) + >>> N21, N22 = Normal('N21', 21, 1), Normal('N22', 22, 1) + >>> M11, M12 = Normal('M11', 1, 1), Normal('M12', 2, 1) + >>> M21, M22 = Normal('M21', 3, 1), Normal('M22', 4, 1) + >>> x1 = Matrix([[N11, N12], [N21, N22]]) + >>> x2 = Matrix([[M11, M12], [M21, M22]]) + >>> ExpectationMatrix(x1 + x2).doit() + Matrix([ + [12, 14], + [24, 26]]) + + """ + def __new__(cls, expr, condition=None): + expr = _sympify(expr) + if condition is None: + if not is_random(expr): + return expr + obj = Expr.__new__(cls, expr) + else: + condition = _sympify(condition) + obj = Expr.__new__(cls, expr, condition) + + obj._shape = expr.shape + obj._condition = condition + return obj + + @property + def shape(self): + return self._shape + + def expand(self, **hints): + expr = self.args[0] + condition = self._condition + if not is_random(expr): + return expr + + if isinstance(expr, Add): + return Add.fromiter(Expectation(a, condition=condition).expand() + for a in expr.args) + + expand_expr = _expand(expr) + if isinstance(expand_expr, Add): + return Add.fromiter(Expectation(a, condition=condition).expand() + for a in expand_expr.args) + + elif isinstance(expr, (Mul, MatMul)): + rv = [] + nonrv = [] + postnon = [] + + for a in expr.args: + if is_random(a): + if rv: + rv.extend(postnon) + else: + nonrv.extend(postnon) + postnon = [] + rv.append(a) + elif a.is_Matrix: + postnon.append(a) + else: + nonrv.append(a) + + # In order to avoid infinite-looping (MatMul may call .doit() again), + # do not rebuild + if len(nonrv) == 0: + return self + return Mul.fromiter(nonrv)*Expectation(Mul.fromiter(rv), + condition=condition)*Mul.fromiter(postnon) + + return self + +class VarianceMatrix(Variance, MatrixExpr): + """ + Variance of a random matrix probability expression. Also known as + Covariance matrix, auto-covariance matrix, dispersion matrix, + or variance-covariance matrix. + + Examples + ======== + + >>> from sympy.stats import VarianceMatrix + >>> from sympy.stats.rv import RandomMatrixSymbol + >>> from sympy import symbols, MatrixSymbol + >>> k = symbols("k") + >>> A, B = MatrixSymbol("A", k, k), MatrixSymbol("B", k, k) + >>> X, Y = RandomMatrixSymbol("X", k, 1), RandomMatrixSymbol("Y", k, 1) + >>> VarianceMatrix(X) + VarianceMatrix(X) + >>> VarianceMatrix(X).shape + (k, k) + + To expand the variance in its expression, use ``expand()``: + + >>> VarianceMatrix(A*X).expand() + A*VarianceMatrix(X)*A.T + >>> VarianceMatrix(A*X + B*Y).expand() + 2*A*CrossCovarianceMatrix(X, Y)*B.T + A*VarianceMatrix(X)*A.T + B*VarianceMatrix(Y)*B.T + """ + def __new__(cls, arg, condition=None): + arg = _sympify(arg) + + if 1 not in arg.shape: + raise ShapeError("Expression is not a vector") + + shape = (arg.shape[0], arg.shape[0]) if arg.shape[1] == 1 else (arg.shape[1], arg.shape[1]) + + if condition: + obj = Expr.__new__(cls, arg, condition) + else: + obj = Expr.__new__(cls, arg) + + obj._shape = shape + obj._condition = condition + return obj + + @property + def shape(self): + return self._shape + + def expand(self, **hints): + arg = self.args[0] + condition = self._condition + + if not is_random(arg): + return ZeroMatrix(*self.shape) + + if isinstance(arg, RandomSymbol): + return self + elif isinstance(arg, Add): + rv = [] + for a in arg.args: + if is_random(a): + rv.append(a) + variances = Add(*(Variance(xv, condition).expand() for xv in rv)) + map_to_covar = lambda x: 2*Covariance(*x, condition=condition).expand() + covariances = Add(*map(map_to_covar, itertools.combinations(rv, 2))) + return variances + covariances + elif isinstance(arg, (Mul, MatMul)): + nonrv = [] + rv = [] + for a in arg.args: + if is_random(a): + rv.append(a) + else: + nonrv.append(a) + if len(rv) == 0: + return ZeroMatrix(*self.shape) + # Avoid possible infinite loops with MatMul: + if len(nonrv) == 0: + return self + # Variance of many multiple matrix products is not implemented: + if len(rv) > 1: + return self + return Mul.fromiter(nonrv)*Variance(Mul.fromiter(rv), + condition)*(Mul.fromiter(nonrv)).transpose() + + # this expression contains a RandomSymbol somehow: + return self + +class CrossCovarianceMatrix(Covariance, MatrixExpr): + """ + Covariance of a random matrix probability expression. + + Examples + ======== + + >>> from sympy.stats import CrossCovarianceMatrix + >>> from sympy.stats.rv import RandomMatrixSymbol + >>> from sympy import symbols, MatrixSymbol + >>> k = symbols("k") + >>> A, B = MatrixSymbol("A", k, k), MatrixSymbol("B", k, k) + >>> C, D = MatrixSymbol("C", k, k), MatrixSymbol("D", k, k) + >>> X, Y = RandomMatrixSymbol("X", k, 1), RandomMatrixSymbol("Y", k, 1) + >>> Z, W = RandomMatrixSymbol("Z", k, 1), RandomMatrixSymbol("W", k, 1) + >>> CrossCovarianceMatrix(X, Y) + CrossCovarianceMatrix(X, Y) + >>> CrossCovarianceMatrix(X, Y).shape + (k, k) + + To expand the covariance in its expression, use ``expand()``: + + >>> CrossCovarianceMatrix(X + Y, Z).expand() + CrossCovarianceMatrix(X, Z) + CrossCovarianceMatrix(Y, Z) + >>> CrossCovarianceMatrix(A*X, Y).expand() + A*CrossCovarianceMatrix(X, Y) + >>> CrossCovarianceMatrix(A*X, B.T*Y).expand() + A*CrossCovarianceMatrix(X, Y)*B + >>> CrossCovarianceMatrix(A*X + B*Y, C.T*Z + D.T*W).expand() + A*CrossCovarianceMatrix(X, W)*D + A*CrossCovarianceMatrix(X, Z)*C + B*CrossCovarianceMatrix(Y, W)*D + B*CrossCovarianceMatrix(Y, Z)*C + + """ + def __new__(cls, arg1, arg2, condition=None): + arg1 = _sympify(arg1) + arg2 = _sympify(arg2) + + if (1 not in arg1.shape) or (1 not in arg2.shape) or (arg1.shape[1] != arg2.shape[1]): + raise ShapeError("Expression is not a vector") + + shape = (arg1.shape[0], arg2.shape[0]) if arg1.shape[1] == 1 and arg2.shape[1] == 1 \ + else (1, 1) + + if condition: + obj = Expr.__new__(cls, arg1, arg2, condition) + else: + obj = Expr.__new__(cls, arg1, arg2) + + obj._shape = shape + obj._condition = condition + return obj + + @property + def shape(self): + return self._shape + + def expand(self, **hints): + arg1 = self.args[0] + arg2 = self.args[1] + condition = self._condition + + if arg1 == arg2: + return VarianceMatrix(arg1, condition).expand() + + if not is_random(arg1) or not is_random(arg2): + return ZeroMatrix(*self.shape) + + if isinstance(arg1, RandomSymbol) and isinstance(arg2, RandomSymbol): + return CrossCovarianceMatrix(arg1, arg2, condition) + + coeff_rv_list1 = self._expand_single_argument(arg1.expand()) + coeff_rv_list2 = self._expand_single_argument(arg2.expand()) + + addends = [a*CrossCovarianceMatrix(r1, r2, condition=condition)*b.transpose() + for (a, r1) in coeff_rv_list1 for (b, r2) in coeff_rv_list2] + return Add.fromiter(addends) + + @classmethod + def _expand_single_argument(cls, expr): + # return (coefficient, random_symbol) pairs: + if isinstance(expr, RandomSymbol): + return [(S.One, expr)] + elif isinstance(expr, Add): + outval = [] + for a in expr.args: + if isinstance(a, (Mul, MatMul)): + outval.append(cls._get_mul_nonrv_rv_tuple(a)) + elif is_random(a): + outval.append((S.One, a)) + + return outval + elif isinstance(expr, (Mul, MatMul)): + return [cls._get_mul_nonrv_rv_tuple(expr)] + elif is_random(expr): + return [(S.One, expr)] + + @classmethod + def _get_mul_nonrv_rv_tuple(cls, m): + rv = [] + nonrv = [] + for a in m.args: + if is_random(a): + rv.append(a) + else: + nonrv.append(a) + return (Mul.fromiter(nonrv), Mul.fromiter(rv)) diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/symbolic_probability.py b/.venv/lib/python3.13/site-packages/sympy/stats/symbolic_probability.py new file mode 100644 index 0000000000000000000000000000000000000000..5d0b971a8691f82de15258d4c460129059eaf436 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/symbolic_probability.py @@ -0,0 +1,698 @@ +import itertools +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.function import expand as _expand +from sympy.core.mul import Mul +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import Not +from sympy.core.parameters import global_parameters +from sympy.core.sorting import default_sort_key +from sympy.core.sympify import _sympify +from sympy.core.relational import Relational +from sympy.logic.boolalg import Boolean +from sympy.stats import variance, covariance +from sympy.stats.rv import (RandomSymbol, pspace, dependent, + given, sampling_E, RandomIndexedSymbol, is_random, + PSpace, sampling_P, random_symbols) + +__all__ = ['Probability', 'Expectation', 'Variance', 'Covariance'] + + +@is_random.register(Expr) +def _(x): + atoms = x.free_symbols + if len(atoms) == 1 and next(iter(atoms)) == x: + return False + return any(is_random(i) for i in atoms) + +@is_random.register(RandomSymbol) # type: ignore +def _(x): + return True + + +class Probability(Expr): + """ + Symbolic expression for the probability. + + Examples + ======== + + >>> from sympy.stats import Probability, Normal + >>> from sympy import Integral + >>> X = Normal("X", 0, 1) + >>> prob = Probability(X > 1) + >>> prob + Probability(X > 1) + + Integral representation: + + >>> prob.rewrite(Integral) + Integral(sqrt(2)*exp(-_z**2/2)/(2*sqrt(pi)), (_z, 1, oo)) + + Evaluation of the integral: + + >>> prob.evaluate_integral() + sqrt(2)*(-sqrt(2)*sqrt(pi)*erf(sqrt(2)/2) + sqrt(2)*sqrt(pi))/(4*sqrt(pi)) + """ + + is_commutative = True + + def __new__(cls, prob, condition=None, **kwargs): + prob = _sympify(prob) + if condition is None: + obj = Expr.__new__(cls, prob) + else: + condition = _sympify(condition) + obj = Expr.__new__(cls, prob, condition) + obj._condition = condition + return obj + + def doit(self, **hints): + condition = self.args[0] + given_condition = self._condition + numsamples = hints.get('numsamples', False) + evaluate = hints.get('evaluate', True) + + if isinstance(condition, Not): + return S.One - self.func(condition.args[0], given_condition, + evaluate=evaluate).doit(**hints) + + if condition.has(RandomIndexedSymbol): + return pspace(condition).probability(condition, given_condition, + evaluate=evaluate) + + if isinstance(given_condition, RandomSymbol): + condrv = random_symbols(condition) + if len(condrv) == 1 and condrv[0] == given_condition: + from sympy.stats.frv_types import BernoulliDistribution + return BernoulliDistribution(self.func(condition).doit(**hints), 0, 1) + if any(dependent(rv, given_condition) for rv in condrv): + return Probability(condition, given_condition) + else: + return Probability(condition).doit() + + if given_condition is not None and \ + not isinstance(given_condition, (Relational, Boolean)): + raise ValueError("%s is not a relational or combination of relationals" + % (given_condition)) + + if given_condition == False or condition is S.false: + return S.Zero + if not isinstance(condition, (Relational, Boolean)): + raise ValueError("%s is not a relational or combination of relationals" + % (condition)) + if condition is S.true: + return S.One + + if numsamples: + return sampling_P(condition, given_condition, numsamples=numsamples) + if given_condition is not None: # If there is a condition + # Recompute on new conditional expr + return Probability(given(condition, given_condition)).doit() + + # Otherwise pass work off to the ProbabilitySpace + if pspace(condition) == PSpace(): + return Probability(condition, given_condition) + + result = pspace(condition).probability(condition) + if hasattr(result, 'doit') and evaluate: + return result.doit() + else: + return result + + def _eval_rewrite_as_Integral(self, arg, condition=None, **kwargs): + return self.func(arg, condition=condition).doit(evaluate=False) + + _eval_rewrite_as_Sum = _eval_rewrite_as_Integral + + def evaluate_integral(self): + return self.rewrite(Integral).doit() + + +class Expectation(Expr): + """ + Symbolic expression for the expectation. + + Examples + ======== + + >>> from sympy.stats import Expectation, Normal, Probability, Poisson + >>> from sympy import symbols, Integral, Sum + >>> mu = symbols("mu") + >>> sigma = symbols("sigma", positive=True) + >>> X = Normal("X", mu, sigma) + >>> Expectation(X) + Expectation(X) + >>> Expectation(X).evaluate_integral().simplify() + mu + + To get the integral expression of the expectation: + + >>> Expectation(X).rewrite(Integral) + Integral(sqrt(2)*X*exp(-(X - mu)**2/(2*sigma**2))/(2*sqrt(pi)*sigma), (X, -oo, oo)) + + The same integral expression, in more abstract terms: + + >>> Expectation(X).rewrite(Probability) + Integral(x*Probability(Eq(X, x)), (x, -oo, oo)) + + To get the Summation expression of the expectation for discrete random variables: + + >>> lamda = symbols('lamda', positive=True) + >>> Z = Poisson('Z', lamda) + >>> Expectation(Z).rewrite(Sum) + Sum(Z*lamda**Z*exp(-lamda)/factorial(Z), (Z, 0, oo)) + + This class is aware of some properties of the expectation: + + >>> from sympy.abc import a + >>> Expectation(a*X) + Expectation(a*X) + >>> Y = Normal("Y", 1, 2) + >>> Expectation(X + Y) + Expectation(X + Y) + + To expand the ``Expectation`` into its expression, use ``expand()``: + + >>> Expectation(X + Y).expand() + Expectation(X) + Expectation(Y) + >>> Expectation(a*X + Y).expand() + a*Expectation(X) + Expectation(Y) + >>> Expectation(a*X + Y) + Expectation(a*X + Y) + >>> Expectation((X + Y)*(X - Y)).expand() + Expectation(X**2) - Expectation(Y**2) + + To evaluate the ``Expectation``, use ``doit()``: + + >>> Expectation(X + Y).doit() + mu + 1 + >>> Expectation(X + Expectation(Y + Expectation(2*X))).doit() + 3*mu + 1 + + To prevent evaluating nested ``Expectation``, use ``doit(deep=False)`` + + >>> Expectation(X + Expectation(Y)).doit(deep=False) + mu + Expectation(Expectation(Y)) + >>> Expectation(X + Expectation(Y + Expectation(2*X))).doit(deep=False) + mu + Expectation(Expectation(Expectation(2*X) + Y)) + + """ + + def __new__(cls, expr, condition=None, **kwargs): + expr = _sympify(expr) + if expr.is_Matrix: + from sympy.stats.symbolic_multivariate_probability import ExpectationMatrix + return ExpectationMatrix(expr, condition) + if condition is None: + if not is_random(expr): + return expr + obj = Expr.__new__(cls, expr) + else: + condition = _sympify(condition) + obj = Expr.__new__(cls, expr, condition) + obj._condition = condition + return obj + + def _eval_is_commutative(self): + return(self.args[0].is_commutative) + + def expand(self, **hints): + expr = self.args[0] + condition = self._condition + + if not is_random(expr): + return expr + + if isinstance(expr, Add): + return Add.fromiter(Expectation(a, condition=condition).expand() + for a in expr.args) + + expand_expr = _expand(expr) + if isinstance(expand_expr, Add): + return Add.fromiter(Expectation(a, condition=condition).expand() + for a in expand_expr.args) + + elif isinstance(expr, Mul): + rv = [] + nonrv = [] + for a in expr.args: + if is_random(a): + rv.append(a) + else: + nonrv.append(a) + return Mul.fromiter(nonrv)*Expectation(Mul.fromiter(rv), condition=condition) + + return self + + def doit(self, **hints): + deep = hints.get('deep', True) + condition = self._condition + expr = self.args[0] + numsamples = hints.get('numsamples', False) + evaluate = hints.get('evaluate', True) + + if deep: + expr = expr.doit(**hints) + + if not is_random(expr) or isinstance(expr, Expectation): # expr isn't random? + return expr + if numsamples: # Computing by monte carlo sampling? + evalf = hints.get('evalf', True) + return sampling_E(expr, condition, numsamples=numsamples, evalf=evalf) + + if expr.has(RandomIndexedSymbol): + return pspace(expr).compute_expectation(expr, condition) + + # Create new expr and recompute E + if condition is not None: # If there is a condition + return self.func(given(expr, condition)).doit(**hints) + + # A few known statements for efficiency + + if expr.is_Add: # We know that E is Linear + return Add(*[self.func(arg, condition).doit(**hints) + if not isinstance(arg, Expectation) else self.func(arg, condition) + for arg in expr.args]) + if expr.is_Mul: + if expr.atoms(Expectation): + return expr + + if pspace(expr) == PSpace(): + return self.func(expr) + # Otherwise case is simple, pass work off to the ProbabilitySpace + result = pspace(expr).compute_expectation(expr, evaluate=evaluate) + if hasattr(result, 'doit') and evaluate: + return result.doit(**hints) + else: + return result + + + def _eval_rewrite_as_Probability(self, arg, condition=None, **kwargs): + rvs = arg.atoms(RandomSymbol) + if len(rvs) > 1: + raise NotImplementedError() + if len(rvs) == 0: + return arg + + rv = rvs.pop() + if rv.pspace is None: + raise ValueError("Probability space not known") + + symbol = rv.symbol + if symbol.name[0].isupper(): + symbol = Symbol(symbol.name.lower()) + else : + symbol = Symbol(symbol.name + "_1") + + if rv.pspace.is_Continuous: + return Integral(arg.replace(rv, symbol)*Probability(Eq(rv, symbol), condition), (symbol, rv.pspace.domain.set.inf, rv.pspace.domain.set.sup)) + else: + if rv.pspace.is_Finite: + raise NotImplementedError + else: + return Sum(arg.replace(rv, symbol)*Probability(Eq(rv, symbol), condition), (symbol, rv.pspace.domain.set.inf, rv.pspace.set.sup)) + + def _eval_rewrite_as_Integral(self, arg, condition=None, evaluate=False, **kwargs): + return self.func(arg, condition=condition).doit(deep=False, evaluate=evaluate) + + _eval_rewrite_as_Sum = _eval_rewrite_as_Integral # For discrete this will be Sum + + def evaluate_integral(self): + return self.rewrite(Integral).doit() + + evaluate_sum = evaluate_integral + +class Variance(Expr): + """ + Symbolic expression for the variance. + + Examples + ======== + + >>> from sympy import symbols, Integral + >>> from sympy.stats import Normal, Expectation, Variance, Probability + >>> mu = symbols("mu", positive=True) + >>> sigma = symbols("sigma", positive=True) + >>> X = Normal("X", mu, sigma) + >>> Variance(X) + Variance(X) + >>> Variance(X).evaluate_integral() + sigma**2 + + Integral representation of the underlying calculations: + + >>> Variance(X).rewrite(Integral) + Integral(sqrt(2)*(X - Integral(sqrt(2)*X*exp(-(X - mu)**2/(2*sigma**2))/(2*sqrt(pi)*sigma), (X, -oo, oo)))**2*exp(-(X - mu)**2/(2*sigma**2))/(2*sqrt(pi)*sigma), (X, -oo, oo)) + + Integral representation, without expanding the PDF: + + >>> Variance(X).rewrite(Probability) + -Integral(x*Probability(Eq(X, x)), (x, -oo, oo))**2 + Integral(x**2*Probability(Eq(X, x)), (x, -oo, oo)) + + Rewrite the variance in terms of the expectation + + >>> Variance(X).rewrite(Expectation) + -Expectation(X)**2 + Expectation(X**2) + + Some transformations based on the properties of the variance may happen: + + >>> from sympy.abc import a + >>> Y = Normal("Y", 0, 1) + >>> Variance(a*X) + Variance(a*X) + + To expand the variance in its expression, use ``expand()``: + + >>> Variance(a*X).expand() + a**2*Variance(X) + >>> Variance(X + Y) + Variance(X + Y) + >>> Variance(X + Y).expand() + 2*Covariance(X, Y) + Variance(X) + Variance(Y) + + """ + def __new__(cls, arg, condition=None, **kwargs): + arg = _sympify(arg) + + if arg.is_Matrix: + from sympy.stats.symbolic_multivariate_probability import VarianceMatrix + return VarianceMatrix(arg, condition) + if condition is None: + obj = Expr.__new__(cls, arg) + else: + condition = _sympify(condition) + obj = Expr.__new__(cls, arg, condition) + obj._condition = condition + return obj + + def _eval_is_commutative(self): + return self.args[0].is_commutative + + def expand(self, **hints): + arg = self.args[0] + condition = self._condition + + if not is_random(arg): + return S.Zero + + if isinstance(arg, RandomSymbol): + return self + elif isinstance(arg, Add): + rv = [] + for a in arg.args: + if is_random(a): + rv.append(a) + variances = Add(*(Variance(xv, condition).expand() for xv in rv)) + map_to_covar = lambda x: 2*Covariance(*x, condition=condition).expand() + covariances = Add(*map(map_to_covar, itertools.combinations(rv, 2))) + return variances + covariances + elif isinstance(arg, Mul): + nonrv = [] + rv = [] + for a in arg.args: + if is_random(a): + rv.append(a) + else: + nonrv.append(a**2) + if len(rv) == 0: + return S.Zero + return Mul.fromiter(nonrv)*Variance(Mul.fromiter(rv), condition) + + # this expression contains a RandomSymbol somehow: + return self + + def _eval_rewrite_as_Expectation(self, arg, condition=None, **kwargs): + e1 = Expectation(arg**2, condition) + e2 = Expectation(arg, condition)**2 + return e1 - e2 + + def _eval_rewrite_as_Probability(self, arg, condition=None, **kwargs): + return self.rewrite(Expectation).rewrite(Probability) + + def _eval_rewrite_as_Integral(self, arg, condition=None, **kwargs): + return variance(self.args[0], self._condition, evaluate=False) + + _eval_rewrite_as_Sum = _eval_rewrite_as_Integral + + def evaluate_integral(self): + return self.rewrite(Integral).doit() + + +class Covariance(Expr): + """ + Symbolic expression for the covariance. + + Examples + ======== + + >>> from sympy.stats import Covariance + >>> from sympy.stats import Normal + >>> X = Normal("X", 3, 2) + >>> Y = Normal("Y", 0, 1) + >>> Z = Normal("Z", 0, 1) + >>> W = Normal("W", 0, 1) + >>> cexpr = Covariance(X, Y) + >>> cexpr + Covariance(X, Y) + + Evaluate the covariance, `X` and `Y` are independent, + therefore zero is the result: + + >>> cexpr.evaluate_integral() + 0 + + Rewrite the covariance expression in terms of expectations: + + >>> from sympy.stats import Expectation + >>> cexpr.rewrite(Expectation) + Expectation(X*Y) - Expectation(X)*Expectation(Y) + + In order to expand the argument, use ``expand()``: + + >>> from sympy.abc import a, b, c, d + >>> Covariance(a*X + b*Y, c*Z + d*W) + Covariance(a*X + b*Y, c*Z + d*W) + >>> Covariance(a*X + b*Y, c*Z + d*W).expand() + a*c*Covariance(X, Z) + a*d*Covariance(W, X) + b*c*Covariance(Y, Z) + b*d*Covariance(W, Y) + + This class is aware of some properties of the covariance: + + >>> Covariance(X, X).expand() + Variance(X) + >>> Covariance(a*X, b*Y).expand() + a*b*Covariance(X, Y) + """ + + def __new__(cls, arg1, arg2, condition=None, **kwargs): + arg1 = _sympify(arg1) + arg2 = _sympify(arg2) + + if arg1.is_Matrix or arg2.is_Matrix: + from sympy.stats.symbolic_multivariate_probability import CrossCovarianceMatrix + return CrossCovarianceMatrix(arg1, arg2, condition) + + if kwargs.pop('evaluate', global_parameters.evaluate): + arg1, arg2 = sorted([arg1, arg2], key=default_sort_key) + + if condition is None: + obj = Expr.__new__(cls, arg1, arg2) + else: + condition = _sympify(condition) + obj = Expr.__new__(cls, arg1, arg2, condition) + obj._condition = condition + return obj + + def _eval_is_commutative(self): + return self.args[0].is_commutative + + def expand(self, **hints): + arg1 = self.args[0] + arg2 = self.args[1] + condition = self._condition + + if arg1 == arg2: + return Variance(arg1, condition).expand() + + if not is_random(arg1): + return S.Zero + if not is_random(arg2): + return S.Zero + + arg1, arg2 = sorted([arg1, arg2], key=default_sort_key) + + if isinstance(arg1, RandomSymbol) and isinstance(arg2, RandomSymbol): + return Covariance(arg1, arg2, condition) + + coeff_rv_list1 = self._expand_single_argument(arg1.expand()) + coeff_rv_list2 = self._expand_single_argument(arg2.expand()) + + addends = [a*b*Covariance(*sorted([r1, r2], key=default_sort_key), condition=condition) + for (a, r1) in coeff_rv_list1 for (b, r2) in coeff_rv_list2] + return Add.fromiter(addends) + + @classmethod + def _expand_single_argument(cls, expr): + # return (coefficient, random_symbol) pairs: + if isinstance(expr, RandomSymbol): + return [(S.One, expr)] + elif isinstance(expr, Add): + outval = [] + for a in expr.args: + if isinstance(a, Mul): + outval.append(cls._get_mul_nonrv_rv_tuple(a)) + elif is_random(a): + outval.append((S.One, a)) + + return outval + elif isinstance(expr, Mul): + return [cls._get_mul_nonrv_rv_tuple(expr)] + elif is_random(expr): + return [(S.One, expr)] + + @classmethod + def _get_mul_nonrv_rv_tuple(cls, m): + rv = [] + nonrv = [] + for a in m.args: + if is_random(a): + rv.append(a) + else: + nonrv.append(a) + return (Mul.fromiter(nonrv), Mul.fromiter(rv)) + + def _eval_rewrite_as_Expectation(self, arg1, arg2, condition=None, **kwargs): + e1 = Expectation(arg1*arg2, condition) + e2 = Expectation(arg1, condition)*Expectation(arg2, condition) + return e1 - e2 + + def _eval_rewrite_as_Probability(self, arg1, arg2, condition=None, **kwargs): + return self.rewrite(Expectation).rewrite(Probability) + + def _eval_rewrite_as_Integral(self, arg1, arg2, condition=None, **kwargs): + return covariance(self.args[0], self.args[1], self._condition, evaluate=False) + + _eval_rewrite_as_Sum = _eval_rewrite_as_Integral + + def evaluate_integral(self): + return self.rewrite(Integral).doit() + + +class Moment(Expr): + """ + Symbolic class for Moment + + Examples + ======== + + >>> from sympy import Symbol, Integral + >>> from sympy.stats import Normal, Expectation, Probability, Moment + >>> mu = Symbol('mu', real=True) + >>> sigma = Symbol('sigma', positive=True) + >>> X = Normal('X', mu, sigma) + >>> M = Moment(X, 3, 1) + + To evaluate the result of Moment use `doit`: + + >>> M.doit() + mu**3 - 3*mu**2 + 3*mu*sigma**2 + 3*mu - 3*sigma**2 - 1 + + Rewrite the Moment expression in terms of Expectation: + + >>> M.rewrite(Expectation) + Expectation((X - 1)**3) + + Rewrite the Moment expression in terms of Probability: + + >>> M.rewrite(Probability) + Integral((x - 1)**3*Probability(Eq(X, x)), (x, -oo, oo)) + + Rewrite the Moment expression in terms of Integral: + + >>> M.rewrite(Integral) + Integral(sqrt(2)*(X - 1)**3*exp(-(X - mu)**2/(2*sigma**2))/(2*sqrt(pi)*sigma), (X, -oo, oo)) + + """ + def __new__(cls, X, n, c=0, condition=None, **kwargs): + X = _sympify(X) + n = _sympify(n) + c = _sympify(c) + if condition is not None: + condition = _sympify(condition) + return super().__new__(cls, X, n, c, condition) + else: + return super().__new__(cls, X, n, c) + + def doit(self, **hints): + return self.rewrite(Expectation).doit(**hints) + + def _eval_rewrite_as_Expectation(self, X, n, c=0, condition=None, **kwargs): + return Expectation((X - c)**n, condition) + + def _eval_rewrite_as_Probability(self, X, n, c=0, condition=None, **kwargs): + return self.rewrite(Expectation).rewrite(Probability) + + def _eval_rewrite_as_Integral(self, X, n, c=0, condition=None, **kwargs): + return self.rewrite(Expectation).rewrite(Integral) + + +class CentralMoment(Expr): + """ + Symbolic class Central Moment + + Examples + ======== + + >>> from sympy import Symbol, Integral + >>> from sympy.stats import Normal, Expectation, Probability, CentralMoment + >>> mu = Symbol('mu', real=True) + >>> sigma = Symbol('sigma', positive=True) + >>> X = Normal('X', mu, sigma) + >>> CM = CentralMoment(X, 4) + + To evaluate the result of CentralMoment use `doit`: + + >>> CM.doit().simplify() + 3*sigma**4 + + Rewrite the CentralMoment expression in terms of Expectation: + + >>> CM.rewrite(Expectation) + Expectation((-Expectation(X) + X)**4) + + Rewrite the CentralMoment expression in terms of Probability: + + >>> CM.rewrite(Probability) + Integral((x - Integral(x*Probability(True), (x, -oo, oo)))**4*Probability(Eq(X, x)), (x, -oo, oo)) + + Rewrite the CentralMoment expression in terms of Integral: + + >>> CM.rewrite(Integral) + Integral(sqrt(2)*(X - Integral(sqrt(2)*X*exp(-(X - mu)**2/(2*sigma**2))/(2*sqrt(pi)*sigma), (X, -oo, oo)))**4*exp(-(X - mu)**2/(2*sigma**2))/(2*sqrt(pi)*sigma), (X, -oo, oo)) + + """ + def __new__(cls, X, n, condition=None, **kwargs): + X = _sympify(X) + n = _sympify(n) + if condition is not None: + condition = _sympify(condition) + return super().__new__(cls, X, n, condition) + else: + return super().__new__(cls, X, n) + + def doit(self, **hints): + return self.rewrite(Expectation).doit(**hints) + + def _eval_rewrite_as_Expectation(self, X, n, condition=None, **kwargs): + mu = Expectation(X, condition, **kwargs) + return Moment(X, n, mu, condition, **kwargs).rewrite(Expectation) + + def _eval_rewrite_as_Probability(self, X, n, condition=None, **kwargs): + return self.rewrite(Expectation).rewrite(Probability) + + def _eval_rewrite_as_Integral(self, X, n, condition=None, **kwargs): + return self.rewrite(Expectation).rewrite(Integral) diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/stats/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_compound_rv.py b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_compound_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..573ba364b686738e56bb1c4615acd2a9bc8bf3ae --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_compound_rv.py @@ -0,0 +1,159 @@ +from sympy.concrete.summations import Sum +from sympy.core.numbers import (oo, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.beta_functions import beta +from sympy.functions.special.error_functions import erf +from sympy.functions.special.gamma_functions import gamma +from sympy.integrals.integrals import Integral +from sympy.sets.sets import Interval +from sympy.stats import (Normal, P, E, density, Gamma, Poisson, Rayleigh, + variance, Bernoulli, Beta, Uniform, cdf) +from sympy.stats.compound_rv import CompoundDistribution, CompoundPSpace +from sympy.stats.crv_types import NormalDistribution +from sympy.stats.drv_types import PoissonDistribution +from sympy.stats.frv_types import BernoulliDistribution +from sympy.testing.pytest import raises, ignore_warnings +from sympy.stats.joint_rv_types import MultivariateNormalDistribution + +from sympy.abc import x + + +# helpers for testing troublesome unevaluated expressions +flat = lambda s: ''.join(str(s).split()) +streq = lambda *a: len(set(map(flat, a))) == 1 +assert streq(x, x) +assert streq(x, 'x') +assert not streq(x, x + 1) + + +def test_normal_CompoundDist(): + X = Normal('X', 1, 2) + Y = Normal('X', X, 4) + assert density(Y)(x).simplify() == sqrt(10)*exp(-x**2/40 + x/20 - S(1)/40)/(20*sqrt(pi)) + assert E(Y) == 1 # it is always equal to mean of X + assert P(Y > 1) == S(1)/2 # as 1 is the mean + assert P(Y > 5).simplify() == S(1)/2 - erf(sqrt(10)/5)/2 + assert variance(Y) == variance(X) + 4**2 # 2**2 + 4**2 + # https://math.stackexchange.com/questions/1484451/ + # (Contains proof of E and variance computation) + + +def test_poisson_CompoundDist(): + k, t, y = symbols('k t y', positive=True, real=True) + G = Gamma('G', k, t) + D = Poisson('P', G) + assert density(D)(y).simplify() == t**y*(t + 1)**(-k - y)*gamma(k + y)/(gamma(k)*gamma(y + 1)) + # https://en.wikipedia.org/wiki/Negative_binomial_distribution#Gamma%E2%80%93Poisson_mixture + assert E(D).simplify() == k*t # mean of NegativeBinomialDistribution + + +def test_bernoulli_CompoundDist(): + X = Beta('X', 1, 2) + Y = Bernoulli('Y', X) + assert density(Y).dict == {0: S(2)/3, 1: S(1)/3} + assert E(Y) == P(Eq(Y, 1)) == S(1)/3 + assert variance(Y) == S(2)/9 + assert cdf(Y) == {0: S(2)/3, 1: 1} + + # test issue 8128 + a = Bernoulli('a', S(1)/2) + b = Bernoulli('b', a) + assert density(b).dict == {0: S(1)/2, 1: S(1)/2} + assert P(b > 0.5) == S(1)/2 + + X = Uniform('X', 0, 1) + Y = Bernoulli('Y', X) + assert E(Y) == S(1)/2 + assert P(Eq(Y, 1)) == E(Y) + + +def test_unevaluated_CompoundDist(): + # these tests need to be removed once they work with evaluation as they are currently not + # evaluated completely in sympy. + R = Rayleigh('R', 4) + X = Normal('X', 3, R) + ans = ''' + Piecewise(((-sqrt(pi)*sinh(x/4 - 3/4) + sqrt(pi)*cosh(x/4 - 3/4))/( + 8*sqrt(pi)), Abs(arg(x - 3)) <= pi/4), (Integral(sqrt(2)*exp(-(x - 3) + **2/(2*R**2))*exp(-R**2/32)/(32*sqrt(pi)), (R, 0, oo)), True))''' + assert streq(density(X)(x), ans) + + expre = ''' + Integral(X*Integral(sqrt(2)*exp(-(X-3)**2/(2*R**2))*exp(-R**2/32)/(32* + sqrt(pi)),(R,0,oo)),(X,-oo,oo))''' + with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed + assert streq(E(X, evaluate=False).rewrite(Integral), expre) + + X = Poisson('X', 1) + Y = Poisson('Y', X) + Z = Poisson('Z', Y) + exprd = Sum(exp(-Y)*Y**x*Sum(exp(-1)*exp(-X)*X**Y/(factorial(X)*factorial(Y) + ), (X, 0, oo))/factorial(x), (Y, 0, oo)) + assert density(Z)(x) == exprd + + N = Normal('N', 1, 2) + M = Normal('M', 3, 4) + D = Normal('D', M, N) + exprd = ''' + Integral(sqrt(2)*exp(-(N-1)**2/8)*Integral(exp(-(x-M)**2/(2*N**2))*exp + (-(M-3)**2/32)/(8*pi*N),(M,-oo,oo))/(4*sqrt(pi)),(N,-oo,oo))''' + assert streq(density(D, evaluate=False)(x), exprd) + + +def test_Compound_Distribution(): + X = Normal('X', 2, 4) + N = NormalDistribution(X, 4) + C = CompoundDistribution(N) + assert C.is_Continuous + assert C.set == Interval(-oo, oo) + assert C.pdf(x, evaluate=True).simplify() == exp(-x**2/64 + x/16 - S(1)/16)/(8*sqrt(pi)) + + assert not isinstance(CompoundDistribution(NormalDistribution(2, 3)), + CompoundDistribution) + M = MultivariateNormalDistribution([1, 2], [[2, 1], [1, 2]]) + raises(NotImplementedError, lambda: CompoundDistribution(M)) + + X = Beta('X', 2, 4) + B = BernoulliDistribution(X, 1, 0) + C = CompoundDistribution(B) + assert C.is_Finite + assert C.set == {0, 1} + y = symbols('y', negative=False, integer=True) + assert C.pdf(y, evaluate=True) == Piecewise((S(1)/(30*beta(2, 4)), Eq(y, 0)), + (S(1)/(60*beta(2, 4)), Eq(y, 1)), (0, True)) + + k, t, z = symbols('k t z', positive=True, real=True) + G = Gamma('G', k, t) + X = PoissonDistribution(G) + C = CompoundDistribution(X) + assert C.is_Discrete + assert C.set == S.Naturals0 + assert C.pdf(z, evaluate=True).simplify() == t**z*(t + 1)**(-k - z)*gamma(k \ + + z)/(gamma(k)*gamma(z + 1)) + + +def test_compound_pspace(): + X = Normal('X', 2, 4) + Y = Normal('Y', 3, 6) + assert not isinstance(Y.pspace, CompoundPSpace) + N = NormalDistribution(1, 2) + D = PoissonDistribution(3) + B = BernoulliDistribution(0.2, 1, 0) + pspace1 = CompoundPSpace('N', N) + pspace2 = CompoundPSpace('D', D) + pspace3 = CompoundPSpace('B', B) + assert not isinstance(pspace1, CompoundPSpace) + assert not isinstance(pspace2, CompoundPSpace) + assert not isinstance(pspace3, CompoundPSpace) + M = MultivariateNormalDistribution([1, 2], [[2, 1], [1, 2]]) + raises(ValueError, lambda: CompoundPSpace('M', M)) + Y = Normal('Y', X, 6) + assert isinstance(Y.pspace, CompoundPSpace) + assert Y.pspace.distribution == CompoundDistribution(NormalDistribution(X, 6)) + assert Y.pspace.domain.set == Interval(-oo, oo) diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_continuous_rv.py b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_continuous_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..b2c4206b5c29ffd3194d1ae05e57c51c9c1b6d78 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_continuous_rv.py @@ -0,0 +1,1583 @@ +from sympy.concrete.summations import Sum +from sympy.core.function import (Lambda, diff, expand_func) +from sympy.core.mul import Mul +from sympy.core import EulerGamma +from sympy.core.numbers import (E as e, I, Rational, pi) +from sympy.core.relational import (Eq, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.combinatorial.factorials import (binomial, factorial) +from sympy.functions.elementary.complexes import (Abs, im, re, sign) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import (cosh, sinh) +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (asin, atan, cos, sin, tan) +from sympy.functions.special.bessel import (besseli, besselj, besselk) +from sympy.functions.special.beta_functions import beta +from sympy.functions.special.error_functions import (erf, erfc, erfi, expint) +from sympy.functions.special.gamma_functions import (gamma, lowergamma, uppergamma) +from sympy.functions.special.zeta_functions import zeta +from sympy.functions.special.hyper import hyper +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import (And, Or) +from sympy.sets.sets import Interval +from sympy.simplify.simplify import simplify +from sympy.utilities.lambdify import lambdify +from sympy.functions.special.error_functions import erfinv +from sympy.functions.special.hyper import meijerg +from sympy.sets.sets import FiniteSet, Complement, Intersection +from sympy.stats import (P, E, where, density, variance, covariance, skewness, kurtosis, median, + given, pspace, cdf, characteristic_function, moment_generating_function, + ContinuousRV, Arcsin, Benini, Beta, BetaNoncentral, BetaPrime, + Cauchy, Chi, ChiSquared, ChiNoncentral, Dagum, Davis, Erlang, ExGaussian, + Exponential, ExponentialPower, FDistribution, FisherZ, Frechet, Gamma, + GammaInverse, Gompertz, Gumbel, Kumaraswamy, Laplace, Levy, Logistic, LogCauchy, + LogLogistic, LogitNormal, LogNormal, Maxwell, Moyal, Nakagami, Normal, GaussianInverse, + Pareto, PowerFunction, QuadraticU, RaisedCosine, Rayleigh, Reciprocal, ShiftedGompertz, StudentT, + Trapezoidal, Triangular, Uniform, UniformSum, VonMises, Weibull, coskewness, + WignerSemicircle, Wald, correlation, moment, cmoment, smoment, quantile, + Lomax, BoundedPareto) + +from sympy.stats.crv_types import NormalDistribution, ExponentialDistribution, ContinuousDistributionHandmade +from sympy.stats.joint_rv_types import MultivariateLaplaceDistribution, MultivariateNormalDistribution +from sympy.stats.crv import SingleContinuousPSpace, SingleContinuousDomain +from sympy.stats.compound_rv import CompoundPSpace +from sympy.stats.symbolic_probability import Probability +from sympy.testing.pytest import raises, XFAIL, slow, ignore_warnings +from sympy.core.random import verify_numerically as tn + +oo = S.Infinity + +x, y, z = map(Symbol, 'xyz') + +def test_single_normal(): + mu = Symbol('mu', real=True) + sigma = Symbol('sigma', positive=True) + X = Normal('x', 0, 1) + Y = X*sigma + mu + + assert E(Y) == mu + assert variance(Y) == sigma**2 + pdf = density(Y) + x = Symbol('x', real=True) + assert (pdf(x) == + 2**S.Half*exp(-(x - mu)**2/(2*sigma**2))/(2*pi**S.Half*sigma)) + + assert P(X**2 < 1) == erf(2**S.Half/2) + ans = quantile(Y)(x) + assert ans == Complement(Intersection(FiniteSet( + sqrt(2)*sigma*(sqrt(2)*mu/(2*sigma)+ erfinv(2*x - 1))), + Interval(-oo, oo)), FiniteSet(mu)) + assert E(X, Eq(X, mu)) == mu + + assert median(X) == FiniteSet(0) + # issue 8248 + assert X.pspace.compute_expectation(1).doit() == 1 + + +def test_conditional_1d(): + X = Normal('x', 0, 1) + Y = given(X, X >= 0) + z = Symbol('z') + + assert density(Y)(z) == 2 * density(X)(z) + + assert Y.pspace.domain.set == Interval(0, oo) + assert E(Y) == sqrt(2) / sqrt(pi) + + assert E(X**2) == E(Y**2) + + +def test_ContinuousDomain(): + X = Normal('x', 0, 1) + assert where(X**2 <= 1).set == Interval(-1, 1) + assert where(X**2 <= 1).symbol == X.symbol + assert where(And(X**2 <= 1, X >= 0)).set == Interval(0, 1) + raises(ValueError, lambda: where(sin(X) > 1)) + + Y = given(X, X >= 0) + + assert Y.pspace.domain.set == Interval(0, oo) + + +def test_multiple_normal(): + X, Y = Normal('x', 0, 1), Normal('y', 0, 1) + p = Symbol("p", positive=True) + + assert E(X + Y) == 0 + assert variance(X + Y) == 2 + assert variance(X + X) == 4 + assert covariance(X, Y) == 0 + assert covariance(2*X + Y, -X) == -2*variance(X) + assert skewness(X) == 0 + assert skewness(X + Y) == 0 + assert kurtosis(X) == 3 + assert kurtosis(X+Y) == 3 + assert correlation(X, Y) == 0 + assert correlation(X, X + Y) == correlation(X, X - Y) + assert moment(X, 2) == 1 + assert cmoment(X, 3) == 0 + assert moment(X + Y, 4) == 12 + assert cmoment(X, 2) == variance(X) + assert smoment(X*X, 2) == 1 + assert smoment(X + Y, 3) == skewness(X + Y) + assert smoment(X + Y, 4) == kurtosis(X + Y) + assert E(X, Eq(X + Y, 0)) == 0 + assert variance(X, Eq(X + Y, 0)) == S.Half + assert quantile(X)(p) == sqrt(2)*erfinv(2*p - S.One) + + +def test_symbolic(): + mu1, mu2 = symbols('mu1 mu2', real=True) + s1, s2 = symbols('sigma1 sigma2', positive=True) + rate = Symbol('lambda', positive=True) + X = Normal('x', mu1, s1) + Y = Normal('y', mu2, s2) + Z = Exponential('z', rate) + a, b, c = symbols('a b c', real=True) + + assert E(X) == mu1 + assert E(X + Y) == mu1 + mu2 + assert E(a*X + b) == a*E(X) + b + assert variance(X) == s1**2 + assert variance(X + a*Y + b) == variance(X) + a**2*variance(Y) + + assert E(Z) == 1/rate + assert E(a*Z + b) == a*E(Z) + b + assert E(X + a*Z + b) == mu1 + a/rate + b + assert median(X) == FiniteSet(mu1) + + +def test_cdf(): + X = Normal('x', 0, 1) + + d = cdf(X) + assert P(X < 1) == d(1).rewrite(erfc) + assert d(0) == S.Half + + d = cdf(X, X > 0) # given X>0 + assert d(0) == 0 + + Y = Exponential('y', 10) + d = cdf(Y) + assert d(-5) == 0 + assert P(Y > 3) == 1 - d(3) + + raises(ValueError, lambda: cdf(X + Y)) + + Z = Exponential('z', 1) + f = cdf(Z) + assert f(z) == Piecewise((1 - exp(-z), z >= 0), (0, True)) + + +def test_characteristic_function(): + X = Uniform('x', 0, 1) + + cf = characteristic_function(X) + assert cf(1) == -I*(-1 + exp(I)) + + Y = Normal('y', 1, 1) + cf = characteristic_function(Y) + assert cf(0) == 1 + assert cf(1) == exp(I - S.Half) + + Z = Exponential('z', 5) + cf = characteristic_function(Z) + assert cf(0) == 1 + assert cf(1).expand() == Rational(25, 26) + I*5/26 + + X = GaussianInverse('x', 1, 1) + cf = characteristic_function(X) + assert cf(0) == 1 + assert cf(1) == exp(1 - sqrt(1 - 2*I)) + + X = ExGaussian('x', 0, 1, 1) + cf = characteristic_function(X) + assert cf(0) == 1 + assert cf(1) == (1 + I)*exp(Rational(-1, 2))/2 + + L = Levy('x', 0, 1) + cf = characteristic_function(L) + assert cf(0) == 1 + assert cf(1) == exp(-sqrt(2)*sqrt(-I)) + + +def test_moment_generating_function(): + t = symbols('t', positive=True) + + # Symbolic tests + a, b, c = symbols('a b c') + + mgf = moment_generating_function(Beta('x', a, b))(t) + assert mgf == hyper((a,), (a + b,), t) + + mgf = moment_generating_function(Chi('x', a))(t) + assert mgf == sqrt(2)*t*gamma(a/2 + S.Half)*\ + hyper((a/2 + S.Half,), (Rational(3, 2),), t**2/2)/gamma(a/2) +\ + hyper((a/2,), (S.Half,), t**2/2) + + mgf = moment_generating_function(ChiSquared('x', a))(t) + assert mgf == (1 - 2*t)**(-a/2) + + mgf = moment_generating_function(Erlang('x', a, b))(t) + assert mgf == (1 - t/b)**(-a) + + mgf = moment_generating_function(ExGaussian("x", a, b, c))(t) + assert mgf == exp(a*t + b**2*t**2/2)/(1 - t/c) + + mgf = moment_generating_function(Exponential('x', a))(t) + assert mgf == a/(a - t) + + mgf = moment_generating_function(Gamma('x', a, b))(t) + assert mgf == (-b*t + 1)**(-a) + + mgf = moment_generating_function(Gumbel('x', a, b))(t) + assert mgf == exp(b*t)*gamma(-a*t + 1) + + mgf = moment_generating_function(Gompertz('x', a, b))(t) + assert mgf == b*exp(b)*expint(t/a, b) + + mgf = moment_generating_function(Laplace('x', a, b))(t) + assert mgf == exp(a*t)/(-b**2*t**2 + 1) + + mgf = moment_generating_function(Logistic('x', a, b))(t) + assert mgf == exp(a*t)*beta(-b*t + 1, b*t + 1) + + mgf = moment_generating_function(Normal('x', a, b))(t) + assert mgf == exp(a*t + b**2*t**2/2) + + mgf = moment_generating_function(Pareto('x', a, b))(t) + assert mgf == b*(-a*t)**b*uppergamma(-b, -a*t) + + mgf = moment_generating_function(QuadraticU('x', a, b))(t) + assert str(mgf) == ("(3*(t*(-4*b + (a + b)**2) + 4)*exp(b*t) - " + "3*(t*(a**2 + 2*a*(b - 2) + b**2) + 4)*exp(a*t))/(t**2*(a - b)**3)") + + mgf = moment_generating_function(RaisedCosine('x', a, b))(t) + assert mgf == pi**2*exp(a*t)*sinh(b*t)/(b*t*(b**2*t**2 + pi**2)) + + mgf = moment_generating_function(Rayleigh('x', a))(t) + assert mgf == sqrt(2)*sqrt(pi)*a*t*(erf(sqrt(2)*a*t/2) + 1)\ + *exp(a**2*t**2/2)/2 + 1 + + mgf = moment_generating_function(Triangular('x', a, b, c))(t) + assert str(mgf) == ("(-2*(-a + b)*exp(c*t) + 2*(-a + c)*exp(b*t) + " + "2*(b - c)*exp(a*t))/(t**2*(-a + b)*(-a + c)*(b - c))") + + mgf = moment_generating_function(Uniform('x', a, b))(t) + assert mgf == (-exp(a*t) + exp(b*t))/(t*(-a + b)) + + mgf = moment_generating_function(UniformSum('x', a))(t) + assert mgf == ((exp(t) - 1)/t)**a + + mgf = moment_generating_function(WignerSemicircle('x', a))(t) + assert mgf == 2*besseli(1, a*t)/(a*t) + + # Numeric tests + + mgf = moment_generating_function(Beta('x', 1, 1))(t) + assert mgf.diff(t).subs(t, 1) == hyper((2,), (3,), 1)/2 + + mgf = moment_generating_function(Chi('x', 1))(t) + assert mgf.diff(t).subs(t, 1) == sqrt(2)*hyper((1,), (Rational(3, 2),), S.Half + )/sqrt(pi) + hyper((Rational(3, 2),), (Rational(3, 2),), S.Half) + 2*sqrt(2)*hyper((2,), + (Rational(5, 2),), S.Half)/(3*sqrt(pi)) + + mgf = moment_generating_function(ChiSquared('x', 1))(t) + assert mgf.diff(t).subs(t, 1) == I + + mgf = moment_generating_function(Erlang('x', 1, 1))(t) + assert mgf.diff(t).subs(t, 0) == 1 + + mgf = moment_generating_function(ExGaussian("x", 0, 1, 1))(t) + assert mgf.diff(t).subs(t, 2) == -exp(2) + + mgf = moment_generating_function(Exponential('x', 1))(t) + assert mgf.diff(t).subs(t, 0) == 1 + + mgf = moment_generating_function(Gamma('x', 1, 1))(t) + assert mgf.diff(t).subs(t, 0) == 1 + + mgf = moment_generating_function(Gumbel('x', 1, 1))(t) + assert mgf.diff(t).subs(t, 0) == EulerGamma + 1 + + mgf = moment_generating_function(Gompertz('x', 1, 1))(t) + assert mgf.diff(t).subs(t, 1) == -e*meijerg(((), (1, 1)), + ((0, 0, 0), ()), 1) + + mgf = moment_generating_function(Laplace('x', 1, 1))(t) + assert mgf.diff(t).subs(t, 0) == 1 + + mgf = moment_generating_function(Logistic('x', 1, 1))(t) + assert mgf.diff(t).subs(t, 0) == beta(1, 1) + + mgf = moment_generating_function(Normal('x', 0, 1))(t) + assert mgf.diff(t).subs(t, 1) == exp(S.Half) + + mgf = moment_generating_function(Pareto('x', 1, 1))(t) + assert mgf.diff(t).subs(t, 0) == expint(1, 0) + + mgf = moment_generating_function(QuadraticU('x', 1, 2))(t) + assert mgf.diff(t).subs(t, 1) == -12*e - 3*exp(2) + + mgf = moment_generating_function(RaisedCosine('x', 1, 1))(t) + assert mgf.diff(t).subs(t, 1) == -2*e*pi**2*sinh(1)/\ + (1 + pi**2)**2 + e*pi**2*cosh(1)/(1 + pi**2) + + mgf = moment_generating_function(Rayleigh('x', 1))(t) + assert mgf.diff(t).subs(t, 0) == sqrt(2)*sqrt(pi)/2 + + mgf = moment_generating_function(Triangular('x', 1, 3, 2))(t) + assert mgf.diff(t).subs(t, 1) == -e + exp(3) + + mgf = moment_generating_function(Uniform('x', 0, 1))(t) + assert mgf.diff(t).subs(t, 1) == 1 + + mgf = moment_generating_function(UniformSum('x', 1))(t) + assert mgf.diff(t).subs(t, 1) == 1 + + mgf = moment_generating_function(WignerSemicircle('x', 1))(t) + assert mgf.diff(t).subs(t, 1) == -2*besseli(1, 1) + besseli(2, 1) +\ + besseli(0, 1) + + +def test_ContinuousRV(): + pdf = sqrt(2)*exp(-x**2/2)/(2*sqrt(pi)) # Normal distribution + # X and Y should be equivalent + X = ContinuousRV(x, pdf, check=True) + Y = Normal('y', 0, 1) + + assert variance(X) == variance(Y) + assert P(X > 0) == P(Y > 0) + Z = ContinuousRV(z, exp(-z), set=Interval(0, oo)) + assert Z.pspace.domain.set == Interval(0, oo) + assert E(Z) == 1 + assert P(Z > 5) == exp(-5) + raises(ValueError, lambda: ContinuousRV(z, exp(-z), set=Interval(0, 10), check=True)) + + # the correct pdf for Gamma(k, theta) but the integral in `check` + # integrates to something equivalent to 1 and not to 1 exactly + _x, k, theta = symbols("x k theta", positive=True) + pdf = 1/(gamma(k)*theta**k)*_x**(k-1)*exp(-_x/theta) + X = ContinuousRV(_x, pdf, set=Interval(0, oo)) + Y = Gamma('y', k, theta) + assert (E(X) - E(Y)).simplify() == 0 + assert (variance(X) - variance(Y)).simplify() == 0 + + +def test_arcsin(): + + a = Symbol("a", real=True) + b = Symbol("b", real=True) + + X = Arcsin('x', a, b) + assert density(X)(x) == 1/(pi*sqrt((-x + b)*(x - a))) + assert cdf(X)(x) == Piecewise((0, a > x), + (2*asin(sqrt((-a + x)/(-a + b)))/pi, b >= x), + (1, True)) + assert pspace(X).domain.set == Interval(a, b) + +def test_benini(): + alpha = Symbol("alpha", positive=True) + beta = Symbol("beta", positive=True) + sigma = Symbol("sigma", positive=True) + X = Benini('x', alpha, beta, sigma) + + assert density(X)(x) == ((alpha/x + 2*beta*log(x/sigma)/x) + *exp(-alpha*log(x/sigma) - beta*log(x/sigma)**2)) + + assert pspace(X).domain.set == Interval(sigma, oo) + raises(NotImplementedError, lambda: moment_generating_function(X)) + alpha = Symbol("alpha", nonpositive=True) + raises(ValueError, lambda: Benini('x', alpha, beta, sigma)) + + beta = Symbol("beta", nonpositive=True) + raises(ValueError, lambda: Benini('x', alpha, beta, sigma)) + + alpha = Symbol("alpha", positive=True) + raises(ValueError, lambda: Benini('x', alpha, beta, sigma)) + + beta = Symbol("beta", positive=True) + sigma = Symbol("sigma", nonpositive=True) + raises(ValueError, lambda: Benini('x', alpha, beta, sigma)) + +def test_beta(): + a, b = symbols('alpha beta', positive=True) + B = Beta('x', a, b) + + assert pspace(B).domain.set == Interval(0, 1) + assert characteristic_function(B)(x) == hyper((a,), (a + b,), I*x) + assert density(B)(x) == x**(a - 1)*(1 - x)**(b - 1)/beta(a, b) + + assert simplify(E(B)) == a / (a + b) + assert simplify(variance(B)) == a*b / (a**3 + 3*a**2*b + a**2 + 3*a*b**2 + 2*a*b + b**3 + b**2) + + # Full symbolic solution is too much, test with numeric version + a, b = 1, 2 + B = Beta('x', a, b) + assert expand_func(E(B)) == a / S(a + b) + assert expand_func(variance(B)) == (a*b) / S((a + b)**2 * (a + b + 1)) + assert median(B) == FiniteSet(1 - 1/sqrt(2)) + +def test_beta_noncentral(): + a, b = symbols('a b', positive=True) + c = Symbol('c', nonnegative=True) + _k = Dummy('k') + + X = BetaNoncentral('x', a, b, c) + + assert pspace(X).domain.set == Interval(0, 1) + + dens = density(X) + z = Symbol('z') + + res = Sum( z**(_k + a - 1)*(c/2)**_k*(1 - z)**(b - 1)*exp(-c/2)/ + (beta(_k + a, b)*factorial(_k)), (_k, 0, oo)) + assert dens(z).dummy_eq(res) + + # BetaCentral should not raise if the assumptions + # on the symbols can not be determined + a, b, c = symbols('a b c') + assert BetaNoncentral('x', a, b, c) + + a = Symbol('a', positive=False, real=True) + raises(ValueError, lambda: BetaNoncentral('x', a, b, c)) + + a = Symbol('a', positive=True) + b = Symbol('b', positive=False, real=True) + raises(ValueError, lambda: BetaNoncentral('x', a, b, c)) + + a = Symbol('a', positive=True) + b = Symbol('b', positive=True) + c = Symbol('c', nonnegative=False, real=True) + raises(ValueError, lambda: BetaNoncentral('x', a, b, c)) + +def test_betaprime(): + alpha = Symbol("alpha", positive=True) + + betap = Symbol("beta", positive=True) + + X = BetaPrime('x', alpha, betap) + assert density(X)(x) == x**(alpha - 1)*(x + 1)**(-alpha - betap)/beta(alpha, betap) + + alpha = Symbol("alpha", nonpositive=True) + raises(ValueError, lambda: BetaPrime('x', alpha, betap)) + + alpha = Symbol("alpha", positive=True) + betap = Symbol("beta", nonpositive=True) + raises(ValueError, lambda: BetaPrime('x', alpha, betap)) + X = BetaPrime('x', 1, 1) + assert median(X) == FiniteSet(1) + + +def test_BoundedPareto(): + L, H = symbols('L, H', negative=True) + raises(ValueError, lambda: BoundedPareto('X', 1, L, H)) + L, H = symbols('L, H', real=False) + raises(ValueError, lambda: BoundedPareto('X', 1, L, H)) + L, H = symbols('L, H', positive=True) + raises(ValueError, lambda: BoundedPareto('X', -1, L, H)) + + X = BoundedPareto('X', 2, L, H) + assert X.pspace.domain.set == Interval(L, H) + assert density(X)(x) == 2*L**2/(x**3*(1 - L**2/H**2)) + assert cdf(X)(x) == Piecewise((-H**2*L**2/(x**2*(H**2 - L**2)) \ + + H**2/(H**2 - L**2), L <= x), (0, True)) + assert E(X).simplify() == 2*H*L/(H + L) + X = BoundedPareto('X', 1, 2, 4) + assert E(X).simplify() == log(16) + assert median(X) == FiniteSet(Rational(8, 3)) + assert variance(X).simplify() == 8 - 16*log(2)**2 + + +def test_cauchy(): + x0 = Symbol("x0", real=True) + gamma = Symbol("gamma", positive=True) + p = Symbol("p", positive=True) + + X = Cauchy('x', x0, gamma) + # Tests the characteristic function + assert characteristic_function(X)(x) == exp(-gamma*Abs(x) + I*x*x0) + raises(NotImplementedError, lambda: moment_generating_function(X)) + assert density(X)(x) == 1/(pi*gamma*(1 + (x - x0)**2/gamma**2)) + assert diff(cdf(X)(x), x) == density(X)(x) + assert quantile(X)(p) == gamma*tan(pi*(p - S.Half)) + x0 + + x1 = Symbol("x1", real=False) + raises(ValueError, lambda: Cauchy('x', x1, gamma)) + gamma = Symbol("gamma", nonpositive=True) + raises(ValueError, lambda: Cauchy('x', x0, gamma)) + assert median(X) == FiniteSet(x0) + +def test_chi(): + from sympy.core.numbers import I + k = Symbol("k", integer=True) + + X = Chi('x', k) + assert density(X)(x) == 2**(-k/2 + 1)*x**(k - 1)*exp(-x**2/2)/gamma(k/2) + + # Tests the characteristic function + assert characteristic_function(X)(x) == sqrt(2)*I*x*gamma(k/2 + S(1)/2)*hyper((k/2 + S(1)/2,), + (S(3)/2,), -x**2/2)/gamma(k/2) + hyper((k/2,), (S(1)/2,), -x**2/2) + + # Tests the moment generating function + assert moment_generating_function(X)(x) == sqrt(2)*x*gamma(k/2 + S(1)/2)*hyper((k/2 + S(1)/2,), + (S(3)/2,), x**2/2)/gamma(k/2) + hyper((k/2,), (S(1)/2,), x**2/2) + + k = Symbol("k", integer=True, positive=False) + raises(ValueError, lambda: Chi('x', k)) + + k = Symbol("k", integer=False, positive=True) + raises(ValueError, lambda: Chi('x', k)) + +def test_chi_noncentral(): + k = Symbol("k", integer=True) + l = Symbol("l") + + X = ChiNoncentral("x", k, l) + assert density(X)(x) == (x**k*l*(x*l)**(-k/2)* + exp(-x**2/2 - l**2/2)*besseli(k/2 - 1, x*l)) + + k = Symbol("k", integer=True, positive=False) + raises(ValueError, lambda: ChiNoncentral('x', k, l)) + + k = Symbol("k", integer=True, positive=True) + l = Symbol("l", nonpositive=True) + raises(ValueError, lambda: ChiNoncentral('x', k, l)) + + k = Symbol("k", integer=False) + l = Symbol("l", positive=True) + raises(ValueError, lambda: ChiNoncentral('x', k, l)) + + +def test_chi_squared(): + k = Symbol("k", integer=True) + X = ChiSquared('x', k) + + # Tests the characteristic function + assert characteristic_function(X)(x) == ((-2*I*x + 1)**(-k/2)) + + assert density(X)(x) == 2**(-k/2)*x**(k/2 - 1)*exp(-x/2)/gamma(k/2) + assert cdf(X)(x) == Piecewise((lowergamma(k/2, x/2)/gamma(k/2), x >= 0), (0, True)) + assert E(X) == k + assert variance(X) == 2*k + + X = ChiSquared('x', 15) + assert cdf(X)(3) == -14873*sqrt(6)*exp(Rational(-3, 2))/(5005*sqrt(pi)) + erf(sqrt(6)/2) + + k = Symbol("k", integer=True, positive=False) + raises(ValueError, lambda: ChiSquared('x', k)) + + k = Symbol("k", integer=False, positive=True) + raises(ValueError, lambda: ChiSquared('x', k)) + + +def test_dagum(): + p = Symbol("p", positive=True) + b = Symbol("b", positive=True) + a = Symbol("a", positive=True) + + X = Dagum('x', p, a, b) + assert density(X)(x) == a*p*(x/b)**(a*p)*((x/b)**a + 1)**(-p - 1)/x + assert cdf(X)(x) == Piecewise(((1 + (x/b)**(-a))**(-p), x >= 0), + (0, True)) + + p = Symbol("p", nonpositive=True) + raises(ValueError, lambda: Dagum('x', p, a, b)) + + p = Symbol("p", positive=True) + b = Symbol("b", nonpositive=True) + raises(ValueError, lambda: Dagum('x', p, a, b)) + + b = Symbol("b", positive=True) + a = Symbol("a", nonpositive=True) + raises(ValueError, lambda: Dagum('x', p, a, b)) + X = Dagum('x', 1, 1, 1) + assert median(X) == FiniteSet(1) + +def test_davis(): + b = Symbol("b", positive=True) + n = Symbol("n", positive=True) + mu = Symbol("mu", positive=True) + + X = Davis('x', b, n, mu) + dividend = b**n*(x - mu)**(-1-n) + divisor = (exp(b/(x-mu))-1)*(gamma(n)*zeta(n)) + assert density(X)(x) == dividend/divisor + + +def test_erlang(): + k = Symbol("k", integer=True, positive=True) + l = Symbol("l", positive=True) + + X = Erlang("x", k, l) + assert density(X)(x) == x**(k - 1)*l**k*exp(-x*l)/gamma(k) + assert cdf(X)(x) == Piecewise((lowergamma(k, l*x)/gamma(k), x > 0), + (0, True)) + + +def test_exgaussian(): + m, z = symbols("m, z") + s, l = symbols("s, l", positive=True) + X = ExGaussian("x", m, s, l) + + assert density(X)(z) == l*exp(l*(l*s**2 + 2*m - 2*z)/2) *\ + erfc(sqrt(2)*(l*s**2 + m - z)/(2*s))/2 + + # Note: actual_output simplifies to expected_output. + # Ideally cdf(X)(z) would return expected_output + # expected_output = (erf(sqrt(2)*(l*s**2 + m - z)/(2*s)) - 1)*exp(l*(l*s**2 + 2*m - 2*z)/2)/2 - erf(sqrt(2)*(m - z)/(2*s))/2 + S.Half + u = l*(z - m) + v = l*s + GaussianCDF1 = cdf(Normal('x', 0, v))(u) + GaussianCDF2 = cdf(Normal('x', v**2, v))(u) + actual_output = GaussianCDF1 - exp(-u + (v**2/2) + log(GaussianCDF2)) + assert cdf(X)(z) == actual_output + # assert simplify(actual_output) == expected_output + + assert variance(X).expand() == s**2 + l**(-2) + + assert skewness(X).expand() == 2/(l**3*s**2*sqrt(s**2 + l**(-2)) + l * + sqrt(s**2 + l**(-2))) + + +@slow +def test_exponential(): + rate = Symbol('lambda', positive=True) + X = Exponential('x', rate) + p = Symbol("p", positive=True, real=True) + + assert E(X) == 1/rate + assert variance(X) == 1/rate**2 + assert skewness(X) == 2 + assert skewness(X) == smoment(X, 3) + assert kurtosis(X) == 9 + assert kurtosis(X) == smoment(X, 4) + assert smoment(2*X, 4) == smoment(X, 4) + assert moment(X, 3) == 3*2*1/rate**3 + assert P(X > 0) is S.One + assert P(X > 1) == exp(-rate) + assert P(X > 10) == exp(-10*rate) + assert quantile(X)(p) == -log(1-p)/rate + + assert where(X <= 1).set == Interval(0, 1) + Y = Exponential('y', 1) + assert median(Y) == FiniteSet(log(2)) + #Test issue 9970 + z = Dummy('z') + assert P(X > z) == exp(-z*rate) + assert P(X < z) == 0 + #Test issue 10076 (Distribution with interval(0,oo)) + x = Symbol('x') + _z = Dummy('_z') + b = SingleContinuousPSpace(x, ExponentialDistribution(2)) + + with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed + expected1 = Integral(2*exp(-2*_z), (_z, 3, oo)) + assert b.probability(x > 3, evaluate=False).rewrite(Integral).dummy_eq(expected1) + + expected2 = Integral(2*exp(-2*_z), (_z, 0, 4)) + assert b.probability(x < 4, evaluate=False).rewrite(Integral).dummy_eq(expected2) + Y = Exponential('y', 2*rate) + assert coskewness(X, X, X) == skewness(X) + assert coskewness(X, Y + rate*X, Y + 2*rate*X) == \ + 4/(sqrt(1 + 1/(4*rate**2))*sqrt(4 + 1/(4*rate**2))) + assert coskewness(X + 2*Y, Y + X, Y + 2*X, X > 3) == \ + sqrt(170)*Rational(9, 85) + +def test_exponential_power(): + mu = Symbol('mu') + z = Symbol('z') + alpha = Symbol('alpha', positive=True) + beta = Symbol('beta', positive=True) + + X = ExponentialPower('x', mu, alpha, beta) + + assert density(X)(z) == beta*exp(-(Abs(mu - z)/alpha) + ** beta)/(2*alpha*gamma(1/beta)) + assert cdf(X)(z) == S.Half + lowergamma(1/beta, + (Abs(mu - z)/alpha)**beta)*sign(-mu + z)/\ + (2*gamma(1/beta)) + + +def test_f_distribution(): + d1 = Symbol("d1", positive=True) + d2 = Symbol("d2", positive=True) + + X = FDistribution("x", d1, d2) + + assert density(X)(x) == (d2**(d2/2)*sqrt((d1*x)**d1*(d1*x + d2)**(-d1 - d2)) + /(x*beta(d1/2, d2/2))) + + raises(NotImplementedError, lambda: moment_generating_function(X)) + d1 = Symbol("d1", nonpositive=True) + raises(ValueError, lambda: FDistribution('x', d1, d1)) + + d1 = Symbol("d1", positive=True, integer=False) + raises(ValueError, lambda: FDistribution('x', d1, d1)) + + d1 = Symbol("d1", positive=True) + d2 = Symbol("d2", nonpositive=True) + raises(ValueError, lambda: FDistribution('x', d1, d2)) + + d2 = Symbol("d2", positive=True, integer=False) + raises(ValueError, lambda: FDistribution('x', d1, d2)) + + +def test_fisher_z(): + d1 = Symbol("d1", positive=True) + d2 = Symbol("d2", positive=True) + + X = FisherZ("x", d1, d2) + assert density(X)(x) == (2*d1**(d1/2)*d2**(d2/2)*(d1*exp(2*x) + d2) + **(-d1/2 - d2/2)*exp(d1*x)/beta(d1/2, d2/2)) + +def test_frechet(): + a = Symbol("a", positive=True) + s = Symbol("s", positive=True) + m = Symbol("m", real=True) + + X = Frechet("x", a, s=s, m=m) + assert density(X)(x) == a*((x - m)/s)**(-a - 1)*exp(-((x - m)/s)**(-a))/s + assert cdf(X)(x) == Piecewise((exp(-((-m + x)/s)**(-a)), m <= x), (0, True)) + +@slow +def test_gamma(): + k = Symbol("k", positive=True) + theta = Symbol("theta", positive=True) + + X = Gamma('x', k, theta) + + # Tests characteristic function + assert characteristic_function(X)(x) == ((-I*theta*x + 1)**(-k)) + + assert density(X)(x) == x**(k - 1)*theta**(-k)*exp(-x/theta)/gamma(k) + assert cdf(X, meijerg=True)(z) == Piecewise( + (-k*lowergamma(k, 0)/gamma(k + 1) + + k*lowergamma(k, z/theta)/gamma(k + 1), z >= 0), + (0, True)) + + # assert simplify(variance(X)) == k*theta**2 # handled numerically below + assert E(X) == moment(X, 1) + + k, theta = symbols('k theta', positive=True) + X = Gamma('x', k, theta) + assert E(X) == k*theta + assert variance(X) == k*theta**2 + assert skewness(X).expand() == 2/sqrt(k) + assert kurtosis(X).expand() == 3 + 6/k + + Y = Gamma('y', 2*k, 3*theta) + assert coskewness(X, theta*X + Y, k*X + Y).simplify() == \ + 2*531441**(-k)*sqrt(k)*theta*(3*3**(12*k) - 2*531441**k) \ + /(sqrt(k**2 + 18)*sqrt(theta**2 + 18)) + +def test_gamma_inverse(): + a = Symbol("a", positive=True) + b = Symbol("b", positive=True) + X = GammaInverse("x", a, b) + assert density(X)(x) == x**(-a - 1)*b**a*exp(-b/x)/gamma(a) + assert cdf(X)(x) == Piecewise((uppergamma(a, b/x)/gamma(a), x > 0), (0, True)) + assert characteristic_function(X)(x) == 2 * (-I*b*x)**(a/2) \ + * besselk(a, 2*sqrt(b)*sqrt(-I*x))/gamma(a) + raises(NotImplementedError, lambda: moment_generating_function(X)) + +def test_gompertz(): + b = Symbol("b", positive=True) + eta = Symbol("eta", positive=True) + + X = Gompertz("x", b, eta) + + assert density(X)(x) == b*eta*exp(eta)*exp(b*x)*exp(-eta*exp(b*x)) + assert cdf(X)(x) == 1 - exp(eta)*exp(-eta*exp(b*x)) + assert diff(cdf(X)(x), x) == density(X)(x) + + +def test_gumbel(): + beta = Symbol("beta", positive=True) + mu = Symbol("mu") + x = Symbol("x") + y = Symbol("y") + X = Gumbel("x", beta, mu) + Y = Gumbel("y", beta, mu, minimum=True) + assert density(X)(x).expand() == \ + exp(mu/beta)*exp(-x/beta)*exp(-exp(mu/beta)*exp(-x/beta))/beta + assert density(Y)(y).expand() == \ + exp(-mu/beta)*exp(y/beta)*exp(-exp(-mu/beta)*exp(y/beta))/beta + assert cdf(X)(x).expand() == \ + exp(-exp(mu/beta)*exp(-x/beta)) + assert characteristic_function(X)(x) == exp(I*mu*x)*gamma(-I*beta*x + 1) + +def test_kumaraswamy(): + a = Symbol("a", positive=True) + b = Symbol("b", positive=True) + + X = Kumaraswamy("x", a, b) + assert density(X)(x) == x**(a - 1)*a*b*(-x**a + 1)**(b - 1) + assert cdf(X)(x) == Piecewise((0, x < 0), + (-(-x**a + 1)**b + 1, x <= 1), + (1, True)) + + +def test_laplace(): + mu = Symbol("mu") + b = Symbol("b", positive=True) + + X = Laplace('x', mu, b) + + #Tests characteristic_function + assert characteristic_function(X)(x) == (exp(I*mu*x)/(b**2*x**2 + 1)) + + assert density(X)(x) == exp(-Abs(x - mu)/b)/(2*b) + assert cdf(X)(x) == Piecewise((exp((-mu + x)/b)/2, mu > x), + (-exp((mu - x)/b)/2 + 1, True)) + X = Laplace('x', [1, 2], [[1, 0], [0, 1]]) + assert isinstance(pspace(X).distribution, MultivariateLaplaceDistribution) + +def test_levy(): + mu = Symbol("mu", real=True) + c = Symbol("c", positive=True) + + X = Levy('x', mu, c) + assert X.pspace.domain.set == Interval(mu, oo) + assert density(X)(x) == sqrt(c/(2*pi))*exp(-c/(2*(x - mu)))/((x - mu)**(S.One + S.Half)) + assert cdf(X)(x) == erfc(sqrt(c/(2*(x - mu)))) + + raises(NotImplementedError, lambda: moment_generating_function(X)) + mu = Symbol("mu", real=False) + raises(ValueError, lambda: Levy('x',mu,c)) + + c = Symbol("c", nonpositive=True) + raises(ValueError, lambda: Levy('x',mu,c)) + + mu = Symbol("mu", real=True) + raises(ValueError, lambda: Levy('x',mu,c)) + +def test_logcauchy(): + mu = Symbol("mu", positive=True) + sigma = Symbol("sigma", positive=True) + + X = LogCauchy("x", mu, sigma) + + assert density(X)(x) == sigma/(x*pi*(sigma**2 + (-mu + log(x))**2)) + assert cdf(X)(x) == atan((log(x) - mu)/sigma)/pi + S.Half + + +def test_logistic(): + mu = Symbol("mu", real=True) + s = Symbol("s", positive=True) + p = Symbol("p", positive=True) + + X = Logistic('x', mu, s) + + #Tests characteristics_function + assert characteristic_function(X)(x) == \ + (Piecewise((pi*s*x*exp(I*mu*x)/sinh(pi*s*x), Ne(x, 0)), (1, True))) + + assert density(X)(x) == exp((-x + mu)/s)/(s*(exp((-x + mu)/s) + 1)**2) + assert cdf(X)(x) == 1/(exp((mu - x)/s) + 1) + assert quantile(X)(p) == mu - s*log(-S.One + 1/p) + +def test_loglogistic(): + a, b = symbols('a b') + assert LogLogistic('x', a, b) + + a = Symbol('a', negative=True) + b = Symbol('b', positive=True) + raises(ValueError, lambda: LogLogistic('x', a, b)) + + a = Symbol('a', positive=True) + b = Symbol('b', negative=True) + raises(ValueError, lambda: LogLogistic('x', a, b)) + + a, b, z, p = symbols('a b z p', positive=True) + X = LogLogistic('x', a, b) + assert density(X)(z) == b*(z/a)**(b - 1)/(a*((z/a)**b + 1)**2) + assert cdf(X)(z) == 1/(1 + (z/a)**(-b)) + assert quantile(X)(p) == a*(p/(1 - p))**(1/b) + + # Expectation + assert E(X) == Piecewise((S.NaN, b <= 1), (pi*a/(b*sin(pi/b)), True)) + b = symbols('b', prime=True) # b > 1 + X = LogLogistic('x', a, b) + assert E(X) == pi*a/(b*sin(pi/b)) + X = LogLogistic('x', 1, 2) + assert median(X) == FiniteSet(1) + +def test_logitnormal(): + mu = Symbol('mu', real=True) + s = Symbol('s', positive=True) + X = LogitNormal('x', mu, s) + x = Symbol('x') + + assert density(X)(x) == sqrt(2)*exp(-(-mu + log(x/(1 - x)))**2/(2*s**2))/(2*sqrt(pi)*s*x*(1 - x)) + assert cdf(X)(x) == erf(sqrt(2)*(-mu + log(x/(1 - x)))/(2*s))/2 + S(1)/2 + +def test_lognormal(): + mean = Symbol('mu', real=True) + std = Symbol('sigma', positive=True) + X = LogNormal('x', mean, std) + # The sympy integrator can't do this too well + #assert E(X) == exp(mean+std**2/2) + #assert variance(X) == (exp(std**2)-1) * exp(2*mean + std**2) + + # The sympy integrator can't do this too well + #assert E(X) == + raises(NotImplementedError, lambda: moment_generating_function(X)) + mu = Symbol("mu", real=True) + sigma = Symbol("sigma", positive=True) + + X = LogNormal('x', mu, sigma) + assert density(X)(x) == (sqrt(2)*exp(-(-mu + log(x))**2 + /(2*sigma**2))/(2*x*sqrt(pi)*sigma)) + # Tests cdf + assert cdf(X)(x) == Piecewise( + (erf(sqrt(2)*(-mu + log(x))/(2*sigma))/2 + + S(1)/2, x > 0), (0, True)) + + X = LogNormal('x', 0, 1) # Mean 0, standard deviation 1 + assert density(X)(x) == sqrt(2)*exp(-log(x)**2/2)/(2*x*sqrt(pi)) + + +def test_Lomax(): + a, l = symbols('a, l', negative=True) + raises(ValueError, lambda: Lomax('X', a, l)) + a, l = symbols('a, l', real=False) + raises(ValueError, lambda: Lomax('X', a, l)) + + a, l = symbols('a, l', positive=True) + X = Lomax('X', a, l) + assert X.pspace.domain.set == Interval(0, oo) + assert density(X)(x) == a*(1 + x/l)**(-a - 1)/l + assert cdf(X)(x) == Piecewise((1 - (1 + x/l)**(-a), x >= 0), (0, True)) + a = 3 + X = Lomax('X', a, l) + assert E(X) == l/2 + assert median(X) == FiniteSet(l*(-1 + 2**Rational(1, 3))) + assert variance(X) == 3*l**2/4 + + +def test_maxwell(): + a = Symbol("a", positive=True) + + X = Maxwell('x', a) + + assert density(X)(x) == (sqrt(2)*x**2*exp(-x**2/(2*a**2))/ + (sqrt(pi)*a**3)) + assert E(X) == 2*sqrt(2)*a/sqrt(pi) + assert variance(X) == -8*a**2/pi + 3*a**2 + assert cdf(X)(x) == erf(sqrt(2)*x/(2*a)) - sqrt(2)*x*exp(-x**2/(2*a**2))/(sqrt(pi)*a) + assert diff(cdf(X)(x), x) == density(X)(x) + + +@slow +def test_Moyal(): + mu = Symbol('mu',real=False) + sigma = Symbol('sigma', positive=True) + raises(ValueError, lambda: Moyal('M',mu, sigma)) + + mu = Symbol('mu', real=True) + sigma = Symbol('sigma', negative=True) + raises(ValueError, lambda: Moyal('M',mu, sigma)) + + sigma = Symbol('sigma', positive=True) + M = Moyal('M', mu, sigma) + assert density(M)(z) == sqrt(2)*exp(-exp((mu - z)/sigma)/2 + - (-mu + z)/(2*sigma))/(2*sqrt(pi)*sigma) + assert cdf(M)(z).simplify() == 1 - erf(sqrt(2)*exp((mu - z)/(2*sigma))/2) + assert characteristic_function(M)(z) == 2**(-I*sigma*z)*exp(I*mu*z) \ + *gamma(-I*sigma*z + Rational(1, 2))/sqrt(pi) + assert E(M) == mu + EulerGamma*sigma + sigma*log(2) + assert moment_generating_function(M)(z) == 2**(-sigma*z)*exp(mu*z) \ + *gamma(-sigma*z + Rational(1, 2))/sqrt(pi) + + +def test_nakagami(): + mu = Symbol("mu", positive=True) + omega = Symbol("omega", positive=True) + + X = Nakagami('x', mu, omega) + assert density(X)(x) == (2*x**(2*mu - 1)*mu**mu*omega**(-mu) + *exp(-x**2*mu/omega)/gamma(mu)) + assert simplify(E(X)) == (sqrt(mu)*sqrt(omega) + *gamma(mu + S.Half)/gamma(mu + 1)) + assert simplify(variance(X)) == ( + omega - omega*gamma(mu + S.Half)**2/(gamma(mu)*gamma(mu + 1))) + assert cdf(X)(x) == Piecewise( + (lowergamma(mu, mu*x**2/omega)/gamma(mu), x > 0), + (0, True)) + X = Nakagami('x', 1, 1) + assert median(X) == FiniteSet(sqrt(log(2))) + +def test_gaussian_inverse(): + # test for symbolic parameters + a, b = symbols('a b') + assert GaussianInverse('x', a, b) + + # Inverse Gaussian distribution is also known as Wald distribution + # `GaussianInverse` can also be referred by the name `Wald` + a, b, z = symbols('a b z') + X = Wald('x', a, b) + assert density(X)(z) == sqrt(2)*sqrt(b/z**3)*exp(-b*(-a + z)**2/(2*a**2*z))/(2*sqrt(pi)) + + a, b = symbols('a b', positive=True) + z = Symbol('z', positive=True) + + X = GaussianInverse('x', a, b) + assert density(X)(z) == sqrt(2)*sqrt(b)*sqrt(z**(-3))*exp(-b*(-a + z)**2/(2*a**2*z))/(2*sqrt(pi)) + assert E(X) == a + assert variance(X).expand() == a**3/b + assert cdf(X)(z) == (S.Half - erf(sqrt(2)*sqrt(b)*(1 + z/a)/(2*sqrt(z)))/2)*exp(2*b/a) +\ + erf(sqrt(2)*sqrt(b)*(-1 + z/a)/(2*sqrt(z)))/2 + S.Half + + a = symbols('a', nonpositive=True) + raises(ValueError, lambda: GaussianInverse('x', a, b)) + + a = symbols('a', positive=True) + b = symbols('b', nonpositive=True) + raises(ValueError, lambda: GaussianInverse('x', a, b)) + +def test_pareto(): + xm, beta = symbols('xm beta', positive=True) + alpha = beta + 5 + X = Pareto('x', xm, alpha) + + dens = density(X) + + #Tests cdf function + assert cdf(X)(x) == \ + Piecewise((-x**(-beta - 5)*xm**(beta + 5) + 1, x >= xm), (0, True)) + + #Tests characteristic_function + assert characteristic_function(X)(x) == \ + ((-I*x*xm)**(beta + 5)*(beta + 5)*uppergamma(-beta - 5, -I*x*xm)) + + assert dens(x) == x**(-(alpha + 1))*xm**(alpha)*(alpha) + + assert simplify(E(X)) == alpha*xm/(alpha-1) + + # computation of taylor series for MGF still too slow + #assert simplify(variance(X)) == xm**2*alpha / ((alpha-1)**2*(alpha-2)) + + +def test_pareto_numeric(): + xm, beta = 3, 2 + alpha = beta + 5 + X = Pareto('x', xm, alpha) + + assert E(X) == alpha*xm/S(alpha - 1) + assert variance(X) == xm**2*alpha / S((alpha - 1)**2*(alpha - 2)) + assert median(X) == FiniteSet(3*2**Rational(1, 7)) + # Skewness tests too slow. Try shortcutting function? + + +def test_PowerFunction(): + alpha = Symbol("alpha", nonpositive=True) + a, b = symbols('a, b', real=True) + raises (ValueError, lambda: PowerFunction('x', alpha, a, b)) + + a, b = symbols('a, b', real=False) + raises (ValueError, lambda: PowerFunction('x', alpha, a, b)) + + alpha = Symbol("alpha", positive=True) + a, b = symbols('a, b', real=True) + raises (ValueError, lambda: PowerFunction('x', alpha, 5, 2)) + + X = PowerFunction('X', 2, a, b) + assert density(X)(z) == (-2*a + 2*z)/(-a + b)**2 + assert cdf(X)(z) == Piecewise((a**2/(a**2 - 2*a*b + b**2) - + 2*a*z/(a**2 - 2*a*b + b**2) + z**2/(a**2 - 2*a*b + b**2), a <= z), (0, True)) + + X = PowerFunction('X', 2, 0, 1) + assert density(X)(z) == 2*z + assert cdf(X)(z) == Piecewise((z**2, z >= 0), (0,True)) + assert E(X) == Rational(2,3) + assert P(X < 0) == 0 + assert P(X < 1) == 1 + assert median(X) == FiniteSet(1/sqrt(2)) + +def test_raised_cosine(): + mu = Symbol("mu", real=True) + s = Symbol("s", positive=True) + + X = RaisedCosine("x", mu, s) + + assert pspace(X).domain.set == Interval(mu - s, mu + s) + #Tests characteristics_function + assert characteristic_function(X)(x) == \ + Piecewise((exp(-I*pi*mu/s)/2, Eq(x, -pi/s)), (exp(I*pi*mu/s)/2, Eq(x, pi/s)), (pi**2*exp(I*mu*x)*sin(s*x)/(s*x*(-s**2*x**2 + pi**2)), True)) + + assert density(X)(x) == (Piecewise(((cos(pi*(x - mu)/s) + 1)/(2*s), + And(x <= mu + s, mu - s <= x)), (0, True))) + + +def test_rayleigh(): + sigma = Symbol("sigma", positive=True) + + X = Rayleigh('x', sigma) + + #Tests characteristic_function + assert characteristic_function(X)(x) == (-sqrt(2)*sqrt(pi)*sigma*x*(erfi(sqrt(2)*sigma*x/2) - I)*exp(-sigma**2*x**2/2)/2 + 1) + + assert density(X)(x) == x*exp(-x**2/(2*sigma**2))/sigma**2 + assert E(X) == sqrt(2)*sqrt(pi)*sigma/2 + assert variance(X) == -pi*sigma**2/2 + 2*sigma**2 + assert cdf(X)(x) == 1 - exp(-x**2/(2*sigma**2)) + assert diff(cdf(X)(x), x) == density(X)(x) + +def test_reciprocal(): + a = Symbol("a", real=True) + b = Symbol("b", real=True) + + X = Reciprocal('x', a, b) + assert density(X)(x) == 1/(x*(-log(a) + log(b))) + assert cdf(X)(x) == Piecewise((log(a)/(log(a) - log(b)) - log(x)/(log(a) - log(b)), a <= x), (0, True)) + X = Reciprocal('x', 5, 30) + + assert E(X) == 25/(log(30) - log(5)) + assert P(X < 4) == S.Zero + assert P(X < 20) == log(20) / (log(30) - log(5)) - log(5) / (log(30) - log(5)) + assert cdf(X)(10) == log(10) / (log(30) - log(5)) - log(5) / (log(30) - log(5)) + + a = symbols('a', nonpositive=True) + raises(ValueError, lambda: Reciprocal('x', a, b)) + + a = symbols('a', positive=True) + b = symbols('b', positive=True) + raises(ValueError, lambda: Reciprocal('x', a + b, a)) + +def test_shiftedgompertz(): + b = Symbol("b", positive=True) + eta = Symbol("eta", positive=True) + X = ShiftedGompertz("x", b, eta) + assert density(X)(x) == b*(eta*(1 - exp(-b*x)) + 1)*exp(-b*x)*exp(-eta*exp(-b*x)) + + +def test_studentt(): + nu = Symbol("nu", positive=True) + + X = StudentT('x', nu) + assert density(X)(x) == (1 + x**2/nu)**(-nu/2 - S.Half)/(sqrt(nu)*beta(S.Half, nu/2)) + assert cdf(X)(x) == S.Half + x*gamma(nu/2 + S.Half)*hyper((S.Half, nu/2 + S.Half), + (Rational(3, 2),), -x**2/nu)/(sqrt(pi)*sqrt(nu)*gamma(nu/2)) + raises(NotImplementedError, lambda: moment_generating_function(X)) + +def test_trapezoidal(): + a = Symbol("a", real=True) + b = Symbol("b", real=True) + c = Symbol("c", real=True) + d = Symbol("d", real=True) + + X = Trapezoidal('x', a, b, c, d) + assert density(X)(x) == Piecewise(((-2*a + 2*x)/((-a + b)*(-a - b + c + d)), (a <= x) & (x < b)), + (2/(-a - b + c + d), (b <= x) & (x < c)), + ((2*d - 2*x)/((-c + d)*(-a - b + c + d)), (c <= x) & (x <= d)), + (0, True)) + + X = Trapezoidal('x', 0, 1, 2, 3) + assert E(X) == Rational(3, 2) + assert variance(X) == Rational(5, 12) + assert P(X < 2) == Rational(3, 4) + assert median(X) == FiniteSet(Rational(3, 2)) + +def test_triangular(): + a = Symbol("a") + b = Symbol("b") + c = Symbol("c") + + X = Triangular('x', a, b, c) + assert pspace(X).domain.set == Interval(a, b) + assert str(density(X)(x)) == ("Piecewise(((-2*a + 2*x)/((-a + b)*(-a + c)), (a <= x) & (c > x)), " + "(2/(-a + b), Eq(c, x)), ((2*b - 2*x)/((-a + b)*(b - c)), (b >= x) & (c < x)), (0, True))") + + #Tests moment_generating_function + assert moment_generating_function(X)(x).expand() == \ + ((-2*(-a + b)*exp(c*x) + 2*(-a + c)*exp(b*x) + 2*(b - c)*exp(a*x))/(x**2*(-a + b)*(-a + c)*(b - c))).expand() + assert str(characteristic_function(X)(x)) == \ + '(2*(-a + b)*exp(I*c*x) - 2*(-a + c)*exp(I*b*x) - 2*(b - c)*exp(I*a*x))/(x**2*(-a + b)*(-a + c)*(b - c))' + +def test_quadratic_u(): + a = Symbol("a", real=True) + b = Symbol("b", real=True) + + X = QuadraticU("x", a, b) + Y = QuadraticU("x", 1, 2) + + assert pspace(X).domain.set == Interval(a, b) + # Tests _moment_generating_function + assert moment_generating_function(Y)(1) == -15*exp(2) + 27*exp(1) + assert moment_generating_function(Y)(2) == -9*exp(4)/2 + 21*exp(2)/2 + + assert characteristic_function(Y)(1) == 3*I*(-1 + 4*I)*exp(I*exp(2*I)) + assert density(X)(x) == (Piecewise((12*(x - a/2 - b/2)**2/(-a + b)**3, + And(x <= b, a <= x)), (0, True))) + + +def test_uniform(): + l = Symbol('l', real=True) + w = Symbol('w', positive=True) + X = Uniform('x', l, l + w) + + assert E(X) == l + w/2 + assert variance(X).expand() == w**2/12 + + # With numbers all is well + X = Uniform('x', 3, 5) + assert P(X < 3) == 0 and P(X > 5) == 0 + assert P(X < 4) == P(X > 4) == S.Half + assert median(X) == FiniteSet(4) + + z = Symbol('z') + p = density(X)(z) + assert p.subs(z, 3.7) == S.Half + assert p.subs(z, -1) == 0 + assert p.subs(z, 6) == 0 + + c = cdf(X) + assert c(2) == 0 and c(3) == 0 + assert c(Rational(7, 2)) == Rational(1, 4) + assert c(5) == 1 and c(6) == 1 + + +@XFAIL +@slow +def test_uniform_P(): + """ This stopped working because SingleContinuousPSpace.compute_density no + longer calls integrate on a DiracDelta but rather just solves directly. + integrate used to call UniformDistribution.expectation which special-cased + subsed out the Min and Max terms that Uniform produces + + I decided to regress on this class for general cleanliness (and I suspect + speed) of the algorithm. + """ + l = Symbol('l', real=True) + w = Symbol('w', positive=True) + X = Uniform('x', l, l + w) + assert P(X < l) == 0 and P(X > l + w) == 0 + + +def test_uniformsum(): + n = Symbol("n", integer=True) + _k = Dummy("k") + x = Symbol("x") + + X = UniformSum('x', n) + res = Sum((-1)**_k*(-_k + x)**(n - 1)*binomial(n, _k), (_k, 0, floor(x)))/factorial(n - 1) + assert density(X)(x).dummy_eq(res) + + #Tests set functions + assert X.pspace.domain.set == Interval(0, n) + + #Tests the characteristic_function + assert characteristic_function(X)(x) == (-I*(exp(I*x) - 1)/x)**n + + #Tests the moment_generating_function + assert moment_generating_function(X)(x) == ((exp(x) - 1)/x)**n + + +def test_von_mises(): + mu = Symbol("mu") + k = Symbol("k", positive=True) + + X = VonMises("x", mu, k) + assert density(X)(x) == exp(k*cos(x - mu))/(2*pi*besseli(0, k)) + + +def test_weibull(): + a, b = symbols('a b', positive=True) + # FIXME: simplify(E(X)) seems to hang without extended_positive=True + # On a Linux machine this had a rapid memory leak... + # a, b = symbols('a b', positive=True) + X = Weibull('x', a, b) + + assert E(X).expand() == a * gamma(1 + 1/b) + assert variance(X).expand() == (a**2 * gamma(1 + 2/b) - E(X)**2).expand() + assert simplify(skewness(X)) == (2*gamma(1 + 1/b)**3 - 3*gamma(1 + 1/b)*gamma(1 + 2/b) + gamma(1 + 3/b))/(-gamma(1 + 1/b)**2 + gamma(1 + 2/b))**Rational(3, 2) + assert simplify(kurtosis(X)) == (-3*gamma(1 + 1/b)**4 +\ + 6*gamma(1 + 1/b)**2*gamma(1 + 2/b) - 4*gamma(1 + 1/b)*gamma(1 + 3/b) + gamma(1 + 4/b))/(gamma(1 + 1/b)**2 - gamma(1 + 2/b))**2 + +def test_weibull_numeric(): + # Test for integers and rationals + a = 1 + bvals = [S.Half, 1, Rational(3, 2), 5] + for b in bvals: + X = Weibull('x', a, b) + assert simplify(E(X)) == expand_func(a * gamma(1 + 1/S(b))) + assert simplify(variance(X)) == simplify( + a**2 * gamma(1 + 2/S(b)) - E(X)**2) + # Not testing Skew... it's slow with int/frac values > 3/2 + + +def test_wignersemicircle(): + R = Symbol("R", positive=True) + + X = WignerSemicircle('x', R) + assert pspace(X).domain.set == Interval(-R, R) + assert density(X)(x) == 2*sqrt(-x**2 + R**2)/(pi*R**2) + assert E(X) == 0 + + + #Tests ChiNoncentralDistribution + assert characteristic_function(X)(x) == \ + Piecewise((2*besselj(1, R*x)/(R*x), Ne(x, 0)), (1, True)) + + +def test_input_value_assertions(): + a, b = symbols('a b') + p, q = symbols('p q', positive=True) + m, n = symbols('m n', positive=False, real=True) + + raises(ValueError, lambda: Normal('x', 3, 0)) + raises(ValueError, lambda: Normal('x', m, n)) + Normal('X', a, p) # No error raised + raises(ValueError, lambda: Exponential('x', m)) + Exponential('Ex', p) # No error raised + for fn in [Pareto, Weibull, Beta, Gamma]: + raises(ValueError, lambda: fn('x', m, p)) + raises(ValueError, lambda: fn('x', p, n)) + fn('x', p, q) # No error raised + + +def test_unevaluated(): + X = Normal('x', 0, 1) + k = Dummy('k') + expr1 = Integral(sqrt(2)*k*exp(-k**2/2)/(2*sqrt(pi)), (k, -oo, oo)) + expr2 = Integral(sqrt(2)*exp(-k**2/2)/(2*sqrt(pi)), (k, 0, oo)) + with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed + assert E(X, evaluate=False).rewrite(Integral).dummy_eq(expr1) + assert E(X + 1, evaluate=False).rewrite(Integral).dummy_eq(expr1 + 1) + assert P(X > 0, evaluate=False).rewrite(Integral).dummy_eq(expr2) + + assert P(X > 0, X**2 < 1) == S.Half + + +def test_probability_unevaluated(): + T = Normal('T', 30, 3) + with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed + assert type(P(T > 33, evaluate=False)) == Probability + + +def test_density_unevaluated(): + X = Normal('X', 0, 1) + Y = Normal('Y', 0, 2) + assert isinstance(density(X+Y, evaluate=False)(z), Integral) + + +def test_NormalDistribution(): + nd = NormalDistribution(0, 1) + x = Symbol('x') + assert nd.cdf(x) == erf(sqrt(2)*x/2)/2 + S.Half + assert nd.expectation(1, x) == 1 + assert nd.expectation(x, x) == 0 + assert nd.expectation(x**2, x) == 1 + #Test issue 10076 + a = SingleContinuousPSpace(x, NormalDistribution(2, 4)) + _z = Dummy('_z') + + expected1 = Integral(sqrt(2)*exp(-(_z - 2)**2/32)/(8*sqrt(pi)),(_z, -oo, 1)) + assert a.probability(x < 1, evaluate=False).dummy_eq(expected1) is True + + expected2 = Integral(sqrt(2)*exp(-(_z - 2)**2/32)/(8*sqrt(pi)),(_z, 1, oo)) + assert a.probability(x > 1, evaluate=False).dummy_eq(expected2) is True + + b = SingleContinuousPSpace(x, NormalDistribution(1, 9)) + + expected3 = Integral(sqrt(2)*exp(-(_z - 1)**2/162)/(18*sqrt(pi)),(_z, 6, oo)) + assert b.probability(x > 6, evaluate=False).dummy_eq(expected3) is True + + expected4 = Integral(sqrt(2)*exp(-(_z - 1)**2/162)/(18*sqrt(pi)),(_z, -oo, 6)) + assert b.probability(x < 6, evaluate=False).dummy_eq(expected4) is True + + +def test_random_parameters(): + mu = Normal('mu', 2, 3) + meas = Normal('T', mu, 1) + assert density(meas, evaluate=False)(z) + assert isinstance(pspace(meas), CompoundPSpace) + X = Normal('x', [1, 2], [[1, 0], [0, 1]]) + assert isinstance(pspace(X).distribution, MultivariateNormalDistribution) + assert density(meas)(z).simplify() == sqrt(5)*exp(-z**2/20 + z/5 - S(1)/5)/(10*sqrt(pi)) + + +def test_random_parameters_given(): + mu = Normal('mu', 2, 3) + meas = Normal('T', mu, 1) + assert given(meas, Eq(mu, 5)) == Normal('T', 5, 1) + + +def test_conjugate_priors(): + mu = Normal('mu', 2, 3) + x = Normal('x', mu, 1) + assert isinstance(simplify(density(mu, Eq(x, y), evaluate=False)(z)), + Mul) + + +def test_difficult_univariate(): + """ Since using solve in place of deltaintegrate we're able to perform + substantially more complex density computations on single continuous random + variables """ + x = Normal('x', 0, 1) + assert density(x**3) + assert density(exp(x**2)) + assert density(log(x)) + + +def test_issue_10003(): + X = Exponential('x', 3) + G = Gamma('g', 1, 2) + assert P(X < -1) is S.Zero + assert P(G < -1) is S.Zero + + +def test_precomputed_cdf(): + x = symbols("x", real=True) + mu = symbols("mu", real=True) + sigma, xm, alpha = symbols("sigma xm alpha", positive=True) + n = symbols("n", integer=True, positive=True) + distribs = [ + Normal("X", mu, sigma), + Pareto("P", xm, alpha), + ChiSquared("C", n), + Exponential("E", sigma), + # LogNormal("L", mu, sigma), + ] + for X in distribs: + compdiff = cdf(X)(x) - simplify(X.pspace.density.compute_cdf()(x)) + compdiff = simplify(compdiff.rewrite(erfc)) + assert compdiff == 0 + + +@slow +def test_precomputed_characteristic_functions(): + import mpmath + + def test_cf(dist, support_lower_limit, support_upper_limit): + pdf = density(dist) + t = Symbol('t') + + # first function is the hardcoded CF of the distribution + cf1 = lambdify([t], characteristic_function(dist)(t), 'mpmath') + + # second function is the Fourier transform of the density function + f = lambdify([x, t], pdf(x)*exp(I*x*t), 'mpmath') + cf2 = lambda t: mpmath.quad(lambda x: f(x, t), [support_lower_limit, support_upper_limit], maxdegree=10) + + # compare the two functions at various points + for test_point in [2, 5, 8, 11]: + n1 = cf1(test_point) + n2 = cf2(test_point) + + assert abs(re(n1) - re(n2)) < 1e-12 + assert abs(im(n1) - im(n2)) < 1e-12 + + test_cf(Beta('b', 1, 2), 0, 1) + test_cf(Chi('c', 3), 0, mpmath.inf) + test_cf(ChiSquared('c', 2), 0, mpmath.inf) + test_cf(Exponential('e', 6), 0, mpmath.inf) + test_cf(Logistic('l', 1, 2), -mpmath.inf, mpmath.inf) + test_cf(Normal('n', -1, 5), -mpmath.inf, mpmath.inf) + test_cf(RaisedCosine('r', 3, 1), 2, 4) + test_cf(Rayleigh('r', 0.5), 0, mpmath.inf) + test_cf(Uniform('u', -1, 1), -1, 1) + test_cf(WignerSemicircle('w', 3), -3, 3) + + +def test_long_precomputed_cdf(): + x = symbols("x", real=True) + distribs = [ + Arcsin("A", -5, 9), + Dagum("D", 4, 10, 3), + Erlang("E", 14, 5), + Frechet("F", 2, 6, -3), + Gamma("G", 2, 7), + GammaInverse("GI", 3, 5), + Kumaraswamy("K", 6, 8), + Laplace("LA", -5, 4), + Logistic("L", -6, 7), + Nakagami("N", 2, 7), + StudentT("S", 4) + ] + for distr in distribs: + for _ in range(5): + assert tn(diff(cdf(distr)(x), x), density(distr)(x), x, a=0, b=0, c=1, d=0) + + US = UniformSum("US", 5) + pdf01 = density(US)(x).subs(floor(x), 0).doit() # pdf on (0, 1) + cdf01 = cdf(US, evaluate=False)(x).subs(floor(x), 0).doit() # cdf on (0, 1) + assert tn(diff(cdf01, x), pdf01, x, a=0, b=0, c=1, d=0) + + +def test_issue_13324(): + X = Uniform('X', 0, 1) + assert E(X, X > S.Half) == Rational(3, 4) + assert E(X, X > 0) == S.Half + +def test_issue_20756(): + X = Uniform('X', -1, +1) + Y = Uniform('Y', -1, +1) + assert E(X * Y) == S.Zero + assert E(X * ((Y + 1) - 1)) == S.Zero + assert E(Y * (X*(X + 1) - X*X)) == S.Zero + +def test_FiniteSet_prob(): + E = Exponential('E', 3) + N = Normal('N', 5, 7) + assert P(Eq(E, 1)) is S.Zero + assert P(Eq(N, 2)) is S.Zero + assert P(Eq(N, x)) is S.Zero + +def test_prob_neq(): + E = Exponential('E', 4) + X = ChiSquared('X', 4) + assert P(Ne(E, 2)) == 1 + assert P(Ne(X, 4)) == 1 + assert P(Ne(X, 4)) == 1 + assert P(Ne(X, 5)) == 1 + assert P(Ne(E, x)) == 1 + +def test_union(): + N = Normal('N', 3, 2) + assert simplify(P(N**2 - N > 2)) == \ + -erf(sqrt(2))/2 - erfc(sqrt(2)/4)/2 + Rational(3, 2) + assert simplify(P(N**2 - 4 > 0)) == \ + -erf(5*sqrt(2)/4)/2 - erfc(sqrt(2)/4)/2 + Rational(3, 2) + +def test_Or(): + N = Normal('N', 0, 1) + assert simplify(P(Or(N > 2, N < 1))) == \ + -erf(sqrt(2))/2 - erfc(sqrt(2)/2)/2 + Rational(3, 2) + assert P(Or(N < 0, N < 1)) == P(N < 1) + assert P(Or(N > 0, N < 0)) == 1 + + +def test_conditional_eq(): + E = Exponential('E', 1) + assert P(Eq(E, 1), Eq(E, 1)) == 1 + assert P(Eq(E, 1), Eq(E, 2)) == 0 + assert P(E > 1, Eq(E, 2)) == 1 + assert P(E < 1, Eq(E, 2)) == 0 + +def test_ContinuousDistributionHandmade(): + x = Symbol('x') + z = Dummy('z') + dens = Lambda(x, Piecewise((S.Half, (0<=x)&(x<1)), (0, (x>=1)&(x<2)), + (S.Half, (x>=2)&(x<3)), (0, True))) + dens = ContinuousDistributionHandmade(dens, set=Interval(0, 3)) + space = SingleContinuousPSpace(z, dens) + assert dens.pdf == Lambda(x, Piecewise((S(1)/2, (x >= 0) & (x < 1)), + (0, (x >= 1) & (x < 2)), (S(1)/2, (x >= 2) & (x < 3)), (0, True))) + assert median(space.value) == Interval(1, 2) + assert E(space.value) == Rational(3, 2) + assert variance(space.value) == Rational(13, 12) + + +def test_issue_16318(): + # test compute_expectation function of the SingleContinuousDomain + N = SingleContinuousDomain(x, Interval(0, 1)) + raises(ValueError, lambda: SingleContinuousDomain.compute_expectation(N, x+1, {x, y})) + +def test_compute_density(): + X = Normal('X', 0, Symbol("sigma")**2) + raises(ValueError, lambda: density(X**5 + X)) diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_discrete_rv.py b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_discrete_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..39650a1a08e42840aaf8b06c1eb6e60c92a57f23 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_discrete_rv.py @@ -0,0 +1,312 @@ +from sympy.concrete.summations import Sum +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.complexes import (im, re) +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.bessel import besseli +from sympy.functions.special.beta_functions import beta +from sympy.functions.special.zeta_functions import zeta +from sympy.sets.sets import FiniteSet +from sympy.simplify.simplify import simplify +from sympy.utilities.lambdify import lambdify +from sympy.core.relational import Eq, Ne +from sympy.functions.elementary.exponential import exp +from sympy.logic.boolalg import Or +from sympy.sets.fancysets import Range +from sympy.stats import (P, E, variance, density, characteristic_function, + where, moment_generating_function, skewness, cdf, + kurtosis, coskewness) +from sympy.stats.drv_types import (PoissonDistribution, GeometricDistribution, + FlorySchulz, Poisson, Geometric, Hermite, Logarithmic, + NegativeBinomial, Skellam, YuleSimon, Zeta, + DiscreteRV) +from sympy.testing.pytest import slow, nocache_fail, raises, skip +from sympy.stats.symbolic_probability import Expectation +from sympy.functions.combinatorial.factorials import FallingFactorial + +x = Symbol('x') + + +def test_PoissonDistribution(): + l = 3 + p = PoissonDistribution(l) + assert abs(p.cdf(10).evalf() - 1) < .001 + assert abs(p.cdf(10.4).evalf() - 1) < .001 + assert p.expectation(x, x) == l + assert p.expectation(x**2, x) - p.expectation(x, x)**2 == l + + +def test_Poisson(): + l = 3 + x = Poisson('x', l) + assert E(x) == l + assert E(2*x) == 2*l + assert variance(x) == l + assert density(x) == PoissonDistribution(l) + assert isinstance(E(x, evaluate=False), Expectation) + assert isinstance(E(2*x, evaluate=False), Expectation) + # issue 8248 + assert x.pspace.compute_expectation(1) == 1 + # issue 27344 + try: + import numpy as np + except ImportError: + skip("numpy not installed") + y = Poisson('y', np.float64(4.72544290380919e-11)) + assert E(y) == 4.72544290380919e-11 + y = Poisson('y', np.float64(4.725442903809197e-11)) + assert E(y) == 4.725442903809197e-11 + l2 = 5 + z = Poisson('z', l2) + assert E(z) == l2 + assert E(FallingFactorial(z, 3)) == l2**3 + assert E(z**2) == l2 + l2**2 + + +def test_FlorySchulz(): + a = Symbol("a") + z = Symbol("z") + x = FlorySchulz('x', a) + assert E(x) == (2 - a)/a + assert (variance(x) - 2*(1 - a)/a**2).simplify() == S(0) + assert density(x)(z) == a**2*z*(1 - a)**(z - 1) + + +@slow +def test_GeometricDistribution(): + p = S.One / 5 + d = GeometricDistribution(p) + assert d.expectation(x, x) == 1/p + assert d.expectation(x**2, x) - d.expectation(x, x)**2 == (1-p)/p**2 + assert abs(d.cdf(20000).evalf() - 1) < .001 + assert abs(d.cdf(20000.8).evalf() - 1) < .001 + G = Geometric('G', p=S(1)/4) + assert cdf(G)(S(7)/2) == P(G <= S(7)/2) + + X = Geometric('X', Rational(1, 5)) + Y = Geometric('Y', Rational(3, 10)) + assert coskewness(X, X + Y, X + 2*Y).simplify() == sqrt(230)*Rational(81, 1150) + + +def test_Hermite(): + a1 = Symbol("a1", positive=True) + a2 = Symbol("a2", negative=True) + raises(ValueError, lambda: Hermite("H", a1, a2)) + + a1 = Symbol("a1", negative=True) + a2 = Symbol("a2", positive=True) + raises(ValueError, lambda: Hermite("H", a1, a2)) + + a1 = Symbol("a1", positive=True) + x = Symbol("x") + H = Hermite("H", a1, a2) + assert moment_generating_function(H)(x) == exp(a1*(exp(x) - 1) + + a2*(exp(2*x) - 1)) + assert characteristic_function(H)(x) == exp(a1*(exp(I*x) - 1) + + a2*(exp(2*I*x) - 1)) + assert E(H) == a1 + 2*a2 + + H = Hermite("H", a1=5, a2=4) + assert density(H)(2) == 33*exp(-9)/2 + assert E(H) == 13 + assert variance(H) == 21 + assert kurtosis(H) == Rational(464,147) + assert skewness(H) == 37*sqrt(21)/441 + +def test_Logarithmic(): + p = S.Half + x = Logarithmic('x', p) + assert E(x) == -p / ((1 - p) * log(1 - p)) + assert variance(x) == -1/log(2)**2 + 2/log(2) + assert E(2*x**2 + 3*x + 4) == 4 + 7 / log(2) + assert isinstance(E(x, evaluate=False), Expectation) + + +@nocache_fail +def test_negative_binomial(): + r = 5 + p = S.One / 3 + x = NegativeBinomial('x', r, p) + assert E(x) == r * (1 - p) / p + # This hangs when run with the cache disabled: + assert variance(x) == r * (1 - p) / p**2 + assert E(x**5 + 2*x + 3) == E(x**5) + 2*E(x) + 3 == Rational(796473, 1) + assert isinstance(E(x, evaluate=False), Expectation) + + +def test_skellam(): + mu1 = Symbol('mu1') + mu2 = Symbol('mu2') + z = Symbol('z') + X = Skellam('x', mu1, mu2) + + assert density(X)(z) == (mu1/mu2)**(z/2) * \ + exp(-mu1 - mu2)*besseli(z, 2*sqrt(mu1*mu2)) + assert skewness(X).expand() == mu1/(mu1*sqrt(mu1 + mu2) + mu2 * + sqrt(mu1 + mu2)) - mu2/(mu1*sqrt(mu1 + mu2) + mu2*sqrt(mu1 + mu2)) + assert variance(X).expand() == mu1 + mu2 + assert E(X) == mu1 - mu2 + assert characteristic_function(X)(z) == exp( + mu1*exp(I*z) - mu1 - mu2 + mu2*exp(-I*z)) + assert moment_generating_function(X)(z) == exp( + mu1*exp(z) - mu1 - mu2 + mu2*exp(-z)) + + +def test_yule_simon(): + from sympy.core.singleton import S + rho = S(3) + x = YuleSimon('x', rho) + assert simplify(E(x)) == rho / (rho - 1) + assert simplify(variance(x)) == rho**2 / ((rho - 1)**2 * (rho - 2)) + assert isinstance(E(x, evaluate=False), Expectation) + # To test the cdf function + assert cdf(x)(x) == Piecewise((-beta(floor(x), 4)*floor(x) + 1, x >= 1), (0, True)) + + +def test_zeta(): + s = S(5) + x = Zeta('x', s) + assert E(x) == zeta(s-1) / zeta(s) + assert simplify(variance(x)) == ( + zeta(s) * zeta(s-2) - zeta(s-1)**2) / zeta(s)**2 + + +def test_discrete_probability(): + X = Geometric('X', Rational(1, 5)) + Y = Poisson('Y', 4) + G = Geometric('e', x) + assert P(Eq(X, 3)) == Rational(16, 125) + assert P(X < 3) == Rational(9, 25) + assert P(X > 3) == Rational(64, 125) + assert P(X >= 3) == Rational(16, 25) + assert P(X <= 3) == Rational(61, 125) + assert P(Ne(X, 3)) == Rational(109, 125) + assert P(Eq(Y, 3)) == 32*exp(-4)/3 + assert P(Y < 3) == 13*exp(-4) + assert P(Y > 3).equals(32*(Rational(-71, 32) + 3*exp(4)/32)*exp(-4)/3) + assert P(Y >= 3).equals(32*(Rational(-39, 32) + 3*exp(4)/32)*exp(-4)/3) + assert P(Y <= 3) == 71*exp(-4)/3 + assert P(Ne(Y, 3)).equals( + 13*exp(-4) + 32*(Rational(-71, 32) + 3*exp(4)/32)*exp(-4)/3) + assert P(X < S.Infinity) is S.One + assert P(X > S.Infinity) is S.Zero + assert P(G < 3) == x*(2-x) + assert P(Eq(G, 3)) == x*(-x + 1)**2 + + +def test_DiscreteRV(): + p = S(1)/2 + x = Symbol('x', integer=True, positive=True) + pdf = p*(1 - p)**(x - 1) # pdf of Geometric Distribution + D = DiscreteRV(x, pdf, set=S.Naturals, check=True) + assert E(D) == E(Geometric('G', S(1)/2)) == 2 + assert P(D > 3) == S(1)/8 + assert D.pspace.domain.set == S.Naturals + raises(ValueError, lambda: DiscreteRV(x, x, FiniteSet(*range(4)), check=True)) + + # purposeful invalid pmf but it should not raise since check=False + # see test_drv_types.test_ContinuousRV for explanation + X = DiscreteRV(x, 1/x, S.Naturals) + assert P(X < 2) == 1 + assert E(X) == oo + +def test_precomputed_characteristic_functions(): + import mpmath + + def test_cf(dist, support_lower_limit, support_upper_limit): + pdf = density(dist) + t = S('t') + x = S('x') + + # first function is the hardcoded CF of the distribution + cf1 = lambdify([t], characteristic_function(dist)(t), 'mpmath') + + # second function is the Fourier transform of the density function + f = lambdify([x, t], pdf(x)*exp(I*x*t), 'mpmath') + cf2 = lambda t: mpmath.nsum(lambda x: f(x, t), [ + support_lower_limit, support_upper_limit], maxdegree=10) + + # compare the two functions at various points + for test_point in [2, 5, 8, 11]: + n1 = cf1(test_point) + n2 = cf2(test_point) + + assert abs(re(n1) - re(n2)) < 1e-12 + assert abs(im(n1) - im(n2)) < 1e-12 + + test_cf(Geometric('g', Rational(1, 3)), 1, mpmath.inf) + test_cf(Logarithmic('l', Rational(1, 5)), 1, mpmath.inf) + test_cf(NegativeBinomial('n', 5, Rational(1, 7)), 0, mpmath.inf) + test_cf(Poisson('p', 5), 0, mpmath.inf) + test_cf(YuleSimon('y', 5), 1, mpmath.inf) + test_cf(Zeta('z', 5), 1, mpmath.inf) + + +def test_moment_generating_functions(): + t = S('t') + + geometric_mgf = moment_generating_function(Geometric('g', S.Half))(t) + assert geometric_mgf.diff(t).subs(t, 0) == 2 + + logarithmic_mgf = moment_generating_function(Logarithmic('l', S.Half))(t) + assert logarithmic_mgf.diff(t).subs(t, 0) == 1/log(2) + + negative_binomial_mgf = moment_generating_function( + NegativeBinomial('n', 5, Rational(1, 3)))(t) + assert negative_binomial_mgf.diff(t).subs(t, 0) == Rational(10, 1) + + poisson_mgf = moment_generating_function(Poisson('p', 5))(t) + assert poisson_mgf.diff(t).subs(t, 0) == 5 + + skellam_mgf = moment_generating_function(Skellam('s', 1, 1))(t) + assert skellam_mgf.diff(t).subs( + t, 2) == (-exp(-2) + exp(2))*exp(-2 + exp(-2) + exp(2)) + + yule_simon_mgf = moment_generating_function(YuleSimon('y', 3))(t) + assert simplify(yule_simon_mgf.diff(t).subs(t, 0)) == Rational(3, 2) + + zeta_mgf = moment_generating_function(Zeta('z', 5))(t) + assert zeta_mgf.diff(t).subs(t, 0) == pi**4/(90*zeta(5)) + + +def test_Or(): + X = Geometric('X', S.Half) + assert P(Or(X < 3, X > 4)) == Rational(13, 16) + assert P(Or(X > 2, X > 1)) == P(X > 1) + assert P(Or(X >= 3, X < 3)) == 1 + + +def test_where(): + X = Geometric('X', Rational(1, 5)) + Y = Poisson('Y', 4) + assert where(X**2 > 4).set == Range(3, S.Infinity, 1) + assert where(X**2 >= 4).set == Range(2, S.Infinity, 1) + assert where(Y**2 < 9).set == Range(0, 3, 1) + assert where(Y**2 <= 9).set == Range(0, 4, 1) + + +def test_conditional(): + X = Geometric('X', Rational(2, 3)) + Y = Poisson('Y', 3) + assert P(X > 2, X > 3) == 1 + assert P(X > 3, X > 2) == Rational(1, 3) + assert P(Y > 2, Y < 2) == 0 + assert P(Eq(Y, 3), Y >= 0) == 9*exp(-3)/2 + assert P(Eq(Y, 3), Eq(Y, 2)) == 0 + assert P(X < 2, Eq(X, 2)) == 0 + assert P(X > 2, Eq(X, 3)) == 1 + + +def test_product_spaces(): + X1 = Geometric('X1', S.Half) + X2 = Geometric('X2', Rational(1, 3)) + assert str(P(X1 + X2 < 3).rewrite(Sum)) == ( + "Sum(Piecewise((1/(4*2**n), n >= -1), (0, True)), (n, -oo, -1))/3") + assert str(P(X1 + X2 > 3).rewrite(Sum)) == ( + 'Sum(Piecewise((2**(X2 - n - 2)*(2/3)**(X2 - 1)/6, ' + 'X2 - n <= 2), (0, True)), (X2, 1, oo), (n, 1, oo))') + assert P(Eq(X1 + X2, 3)) == Rational(1, 12) diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_error_prop.py b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_error_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..483fb4c36e202d744faeb355606ff9803a516873 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_error_prop.py @@ -0,0 +1,60 @@ +from sympy.core.function import Function +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.stats.error_prop import variance_prop +from sympy.stats.symbolic_probability import (RandomSymbol, Variance, + Covariance) + + +def test_variance_prop(): + x, y, z = symbols('x y z') + phi, t = consts = symbols('phi t') + a = RandomSymbol(x) + var_x = Variance(a) + var_y = Variance(RandomSymbol(y)) + var_z = Variance(RandomSymbol(z)) + f = Function('f')(x) + cases = { + x + y: var_x + var_y, + a + y: var_x + var_y, + x + y + z: var_x + var_y + var_z, + 2*x: 4*var_x, + x*y: var_x*y**2 + var_y*x**2, + 1/x: var_x/x**4, + x/y: (var_x*y**2 + var_y*x**2)/y**4, + exp(x): var_x*exp(2*x), + exp(2*x): 4*var_x*exp(4*x), + exp(-x*t): t**2*var_x*exp(-2*t*x), + f: Variance(f), + } + for inp, out in cases.items(): + obs = variance_prop(inp, consts=consts) + assert out == obs + +def test_variance_prop_with_covar(): + x, y, z = symbols('x y z') + phi, t = consts = symbols('phi t') + a = RandomSymbol(x) + var_x = Variance(a) + b = RandomSymbol(y) + var_y = Variance(b) + c = RandomSymbol(z) + var_z = Variance(c) + covar_x_y = Covariance(a, b) + covar_x_z = Covariance(a, c) + covar_y_z = Covariance(b, c) + cases = { + x + y: var_x + var_y + 2*covar_x_y, + a + y: var_x + var_y + 2*covar_x_y, + x + y + z: var_x + var_y + var_z + \ + 2*covar_x_y + 2*covar_x_z + 2*covar_y_z, + 2*x: 4*var_x, + x*y: var_x*y**2 + var_y*x**2 + 2*covar_x_y/(x*y), + 1/x: var_x/x**4, + exp(x): var_x*exp(2*x), + exp(2*x): 4*var_x*exp(4*x), + exp(-x*t): t**2*var_x*exp(-2*t*x), + } + for inp, out in cases.items(): + obs = variance_prop(inp, consts=consts, include_covar=True) + assert out == obs diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_finite_rv.py b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_finite_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..93bf0211a26ecc32d7f18c7e2d8236859857e445 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_finite_rv.py @@ -0,0 +1,509 @@ +from sympy.concrete.summations import Sum +from sympy.core.containers import (Dict, Tuple) +from sympy.core.function import Function +from sympy.core.numbers import (I, Rational, nan) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import binomial +from sympy.functions.combinatorial.numbers import harmonic +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import cos +from sympy.functions.special.beta_functions import beta +from sympy.logic.boolalg import (And, Or) +from sympy.polys.polytools import cancel +from sympy.sets.sets import FiniteSet +from sympy.simplify.simplify import simplify +from sympy.matrices import Matrix +from sympy.stats import (DiscreteUniform, Die, Bernoulli, Coin, Binomial, BetaBinomial, + Hypergeometric, Rademacher, IdealSoliton, RobustSoliton, P, E, variance, + covariance, skewness, density, where, FiniteRV, pspace, cdf, + correlation, moment, cmoment, smoment, characteristic_function, + moment_generating_function, quantile, kurtosis, median, coskewness) +from sympy.stats.frv_types import DieDistribution, BinomialDistribution, \ + HypergeometricDistribution +from sympy.stats.rv import Density +from sympy.testing.pytest import raises + + +def BayesTest(A, B): + assert P(A, B) == P(And(A, B)) / P(B) + assert P(A, B) == P(B, A) * P(A) / P(B) + + +def test_discreteuniform(): + # Symbolic + a, b, c, t = symbols('a b c t') + X = DiscreteUniform('X', [a, b, c]) + + assert E(X) == (a + b + c)/3 + assert simplify(variance(X) + - ((a**2 + b**2 + c**2)/3 - (a/3 + b/3 + c/3)**2)) == 0 + assert P(Eq(X, a)) == P(Eq(X, b)) == P(Eq(X, c)) == S('1/3') + + Y = DiscreteUniform('Y', range(-5, 5)) + + # Numeric + assert E(Y) == S('-1/2') + assert variance(Y) == S('33/4') + assert median(Y) == FiniteSet(-1, 0) + + for x in range(-5, 5): + assert P(Eq(Y, x)) == S('1/10') + assert P(Y <= x) == S(x + 6)/10 + assert P(Y >= x) == S(5 - x)/10 + + assert dict(density(Die('D', 6)).items()) == \ + dict(density(DiscreteUniform('U', range(1, 7))).items()) + + assert characteristic_function(X)(t) == exp(I*a*t)/3 + exp(I*b*t)/3 + exp(I*c*t)/3 + assert moment_generating_function(X)(t) == exp(a*t)/3 + exp(b*t)/3 + exp(c*t)/3 + # issue 18611 + raises(ValueError, lambda: DiscreteUniform('Z', [a, a, a, b, b, c])) + +def test_dice(): + # TODO: Make iid method! + X, Y, Z = Die('X', 6), Die('Y', 6), Die('Z', 6) + a, b, t, p = symbols('a b t p') + + assert E(X) == 3 + S.Half + assert variance(X) == Rational(35, 12) + assert E(X + Y) == 7 + assert E(X + X) == 7 + assert E(a*X + b) == a*E(X) + b + assert variance(X + Y) == variance(X) + variance(Y) == cmoment(X + Y, 2) + assert variance(X + X) == 4 * variance(X) == cmoment(X + X, 2) + assert cmoment(X, 0) == 1 + assert cmoment(4*X, 3) == 64*cmoment(X, 3) + assert covariance(X, Y) is S.Zero + assert covariance(X, X + Y) == variance(X) + assert density(Eq(cos(X*S.Pi), 1))[True] == S.Half + assert correlation(X, Y) == 0 + assert correlation(X, Y) == correlation(Y, X) + assert smoment(X + Y, 3) == skewness(X + Y) + assert smoment(X + Y, 4) == kurtosis(X + Y) + assert smoment(X, 0) == 1 + assert P(X > 3) == S.Half + assert P(2*X > 6) == S.Half + assert P(X > Y) == Rational(5, 12) + assert P(Eq(X, Y)) == P(Eq(X, 1)) + + assert E(X, X > 3) == 5 == moment(X, 1, 0, X > 3) + assert E(X, Y > 3) == E(X) == moment(X, 1, 0, Y > 3) + assert E(X + Y, Eq(X, Y)) == E(2*X) + assert moment(X, 0) == 1 + assert moment(5*X, 2) == 25*moment(X, 2) + assert quantile(X)(p) == Piecewise((nan, (p > 1) | (p < 0)),\ + (S.One, p <= Rational(1, 6)), (S(2), p <= Rational(1, 3)), (S(3), p <= S.Half),\ + (S(4), p <= Rational(2, 3)), (S(5), p <= Rational(5, 6)), (S(6), p <= 1)) + + assert P(X > 3, X > 3) is S.One + assert P(X > Y, Eq(Y, 6)) is S.Zero + assert P(Eq(X + Y, 12)) == Rational(1, 36) + assert P(Eq(X + Y, 12), Eq(X, 6)) == Rational(1, 6) + + assert density(X + Y) == density(Y + Z) != density(X + X) + d = density(2*X + Y**Z) + assert d[S(22)] == Rational(1, 108) and d[S(4100)] == Rational(1, 216) and S(3130) not in d + + assert pspace(X).domain.as_boolean() == Or( + *[Eq(X.symbol, i) for i in [1, 2, 3, 4, 5, 6]]) + + assert where(X > 3).set == FiniteSet(4, 5, 6) + + assert characteristic_function(X)(t) == exp(6*I*t)/6 + exp(5*I*t)/6 + exp(4*I*t)/6 + exp(3*I*t)/6 + exp(2*I*t)/6 + exp(I*t)/6 + assert moment_generating_function(X)(t) == exp(6*t)/6 + exp(5*t)/6 + exp(4*t)/6 + exp(3*t)/6 + exp(2*t)/6 + exp(t)/6 + assert median(X) == FiniteSet(3, 4) + D = Die('D', 7) + assert median(D) == FiniteSet(4) + # Bayes test for die + BayesTest(X > 3, X + Y < 5) + BayesTest(Eq(X - Y, Z), Z > Y) + BayesTest(X > 3, X > 2) + + # arg test for die + raises(ValueError, lambda: Die('X', -1)) # issue 8105: negative sides. + raises(ValueError, lambda: Die('X', 0)) + raises(ValueError, lambda: Die('X', 1.5)) # issue 8103: non integer sides. + + # symbolic test for die + n, k = symbols('n, k', positive=True) + D = Die('D', n) + dens = density(D).dict + assert dens == Density(DieDistribution(n)) + assert set(dens.subs(n, 4).doit().keys()) == {1, 2, 3, 4} + assert set(dens.subs(n, 4).doit().values()) == {Rational(1, 4)} + k = Dummy('k', integer=True) + assert E(D).dummy_eq( + Sum(Piecewise((k/n, k <= n), (0, True)), (k, 1, n))) + assert variance(D).subs(n, 6).doit() == Rational(35, 12) + + ki = Dummy('ki') + cumuf = cdf(D)(k) + assert cumuf.dummy_eq( + Sum(Piecewise((1/n, (ki >= 1) & (ki <= n)), (0, True)), (ki, 1, k))) + assert cumuf.subs({n: 6, k: 2}).doit() == Rational(1, 3) + + t = Dummy('t') + cf = characteristic_function(D)(t) + assert cf.dummy_eq( + Sum(Piecewise((exp(ki*I*t)/n, (ki >= 1) & (ki <= n)), (0, True)), (ki, 1, n))) + assert cf.subs(n, 3).doit() == exp(3*I*t)/3 + exp(2*I*t)/3 + exp(I*t)/3 + mgf = moment_generating_function(D)(t) + assert mgf.dummy_eq( + Sum(Piecewise((exp(ki*t)/n, (ki >= 1) & (ki <= n)), (0, True)), (ki, 1, n))) + assert mgf.subs(n, 3).doit() == exp(3*t)/3 + exp(2*t)/3 + exp(t)/3 + +def test_given(): + X = Die('X', 6) + assert density(X, X > 5) == {S(6): S.One} + assert where(X > 2, X > 5).as_boolean() == Eq(X.symbol, 6) + + +def test_domains(): + X, Y = Die('x', 6), Die('y', 6) + x, y = X.symbol, Y.symbol + # Domains + d = where(X > Y) + assert d.condition == (x > y) + d = where(And(X > Y, Y > 3)) + assert d.as_boolean() == Or(And(Eq(x, 5), Eq(y, 4)), And(Eq(x, 6), + Eq(y, 5)), And(Eq(x, 6), Eq(y, 4))) + assert len(d.elements) == 3 + + assert len(pspace(X + Y).domain.elements) == 36 + + Z = Die('x', 4) + + raises(ValueError, lambda: P(X > Z)) # Two domains with same internal symbol + + assert pspace(X + Y).domain.set == FiniteSet(1, 2, 3, 4, 5, 6)**2 + + assert where(X > 3).set == FiniteSet(4, 5, 6) + assert X.pspace.domain.dict == FiniteSet( + *[Dict({X.symbol: i}) for i in range(1, 7)]) + + assert where(X > Y).dict == FiniteSet(*[Dict({X.symbol: i, Y.symbol: j}) + for i in range(1, 7) for j in range(1, 7) if i > j]) + +def test_bernoulli(): + p, a, b, t = symbols('p a b t') + X = Bernoulli('B', p, a, b) + + assert E(X) == a*p + b*(-p + 1) + assert density(X)[a] == p + assert density(X)[b] == 1 - p + assert characteristic_function(X)(t) == p * exp(I * a * t) + (-p + 1) * exp(I * b * t) + assert moment_generating_function(X)(t) == p * exp(a * t) + (-p + 1) * exp(b * t) + + X = Bernoulli('B', p, 1, 0) + z = Symbol("z") + + assert E(X) == p + assert simplify(variance(X)) == p*(1 - p) + assert E(a*X + b) == a*E(X) + b + assert simplify(variance(a*X + b)) == simplify(a**2 * variance(X)) + assert quantile(X)(z) == Piecewise((nan, (z > 1) | (z < 0)), (0, z <= 1 - p), (1, z <= 1)) + Y = Bernoulli('Y', Rational(1, 2)) + assert median(Y) == FiniteSet(0, 1) + Z = Bernoulli('Z', Rational(2, 3)) + assert median(Z) == FiniteSet(1) + raises(ValueError, lambda: Bernoulli('B', 1.5)) + raises(ValueError, lambda: Bernoulli('B', -0.5)) + + #issue 8248 + assert X.pspace.compute_expectation(1) == 1 + + p = Rational(1, 5) + X = Binomial('X', 5, p) + Y = Binomial('Y', 7, 2*p) + Z = Binomial('Z', 9, 3*p) + assert coskewness(Y + Z, X + Y, X + Z).simplify() == 0 + assert coskewness(Y + 2*X + Z, X + 2*Y + Z, X + 2*Z + Y).simplify() == \ + sqrt(1529)*Rational(12, 16819) + assert coskewness(Y + 2*X + Z, X + 2*Y + Z, X + 2*Z + Y, X < 2).simplify() \ + == -sqrt(357451121)*Rational(2812, 4646864573) + +def test_cdf(): + D = Die('D', 6) + o = S.One + + assert cdf( + D) == sympify({1: o/6, 2: o/3, 3: o/2, 4: 2*o/3, 5: 5*o/6, 6: o}) + + +def test_coins(): + C, D = Coin('C'), Coin('D') + H, T = symbols('H, T') + assert P(Eq(C, D)) == S.Half + assert density(Tuple(C, D)) == {(H, H): Rational(1, 4), (H, T): Rational(1, 4), + (T, H): Rational(1, 4), (T, T): Rational(1, 4)} + assert dict(density(C).items()) == {H: S.Half, T: S.Half} + + F = Coin('F', Rational(1, 10)) + assert P(Eq(F, H)) == Rational(1, 10) + + d = pspace(C).domain + + assert d.as_boolean() == Or(Eq(C.symbol, H), Eq(C.symbol, T)) + + raises(ValueError, lambda: P(C > D)) # Can't intelligently compare H to T + +def test_binomial_verify_parameters(): + raises(ValueError, lambda: Binomial('b', .2, .5)) + raises(ValueError, lambda: Binomial('b', 3, 1.5)) + +def test_binomial_numeric(): + nvals = range(5) + pvals = [0, Rational(1, 4), S.Half, Rational(3, 4), 1] + + for n in nvals: + for p in pvals: + X = Binomial('X', n, p) + assert E(X) == n*p + assert variance(X) == n*p*(1 - p) + if n > 0 and 0 < p < 1: + assert skewness(X) == (1 - 2*p)/sqrt(n*p*(1 - p)) + assert kurtosis(X) == 3 + (1 - 6*p*(1 - p))/(n*p*(1 - p)) + for k in range(n + 1): + assert P(Eq(X, k)) == binomial(n, k)*p**k*(1 - p)**(n - k) + +def test_binomial_quantile(): + X = Binomial('X', 50, S.Half) + assert quantile(X)(0.95) == S(31) + assert median(X) == FiniteSet(25) + + X = Binomial('X', 5, S.Half) + p = Symbol("p", positive=True) + assert quantile(X)(p) == Piecewise((nan, p > S.One), (S.Zero, p <= Rational(1, 32)),\ + (S.One, p <= Rational(3, 16)), (S(2), p <= S.Half), (S(3), p <= Rational(13, 16)),\ + (S(4), p <= Rational(31, 32)), (S(5), p <= S.One)) + assert median(X) == FiniteSet(2, 3) + + +def test_binomial_symbolic(): + n = 2 + p = symbols('p', positive=True) + X = Binomial('X', n, p) + t = Symbol('t') + + assert simplify(E(X)) == n*p == simplify(moment(X, 1)) + assert simplify(variance(X)) == n*p*(1 - p) == simplify(cmoment(X, 2)) + assert cancel(skewness(X) - (1 - 2*p)/sqrt(n*p*(1 - p))) == 0 + assert cancel((kurtosis(X)) - (3 + (1 - 6*p*(1 - p))/(n*p*(1 - p)))) == 0 + assert characteristic_function(X)(t) == p ** 2 * exp(2 * I * t) + 2 * p * (-p + 1) * exp(I * t) + (-p + 1) ** 2 + assert moment_generating_function(X)(t) == p ** 2 * exp(2 * t) + 2 * p * (-p + 1) * exp(t) + (-p + 1) ** 2 + + # Test ability to change success/failure winnings + H, T = symbols('H T') + Y = Binomial('Y', n, p, succ=H, fail=T) + assert simplify(E(Y) - (n*(H*p + T*(1 - p)))) == 0 + + # test symbolic dimensions + n = symbols('n') + B = Binomial('B', n, p) + raises(NotImplementedError, lambda: P(B > 2)) + assert density(B).dict == Density(BinomialDistribution(n, p, 1, 0)) + assert set(density(B).dict.subs(n, 4).doit().keys()) == \ + {S.Zero, S.One, S(2), S(3), S(4)} + assert set(density(B).dict.subs(n, 4).doit().values()) == \ + {(1 - p)**4, 4*p*(1 - p)**3, 6*p**2*(1 - p)**2, 4*p**3*(1 - p), p**4} + k = Dummy('k', integer=True) + assert E(B > 2).dummy_eq( + Sum(Piecewise((k*p**k*(1 - p)**(-k + n)*binomial(n, k), (k >= 0) + & (k <= n) & (k > 2)), (0, True)), (k, 0, n))) + +def test_beta_binomial(): + # verify parameters + raises(ValueError, lambda: BetaBinomial('b', .2, 1, 2)) + raises(ValueError, lambda: BetaBinomial('b', 2, -1, 2)) + raises(ValueError, lambda: BetaBinomial('b', 2, 1, -2)) + assert BetaBinomial('b', 2, 1, 1) + + # test numeric values + nvals = range(1,5) + alphavals = [Rational(1, 4), S.Half, Rational(3, 4), 1, 10] + betavals = [Rational(1, 4), S.Half, Rational(3, 4), 1, 10] + + for n in nvals: + for a in alphavals: + for b in betavals: + X = BetaBinomial('X', n, a, b) + assert E(X) == moment(X, 1) + assert variance(X) == cmoment(X, 2) + + # test symbolic + n, a, b = symbols('a b n') + assert BetaBinomial('x', n, a, b) + n = 2 # Because we're using for loops, can't do symbolic n + a, b = symbols('a b', positive=True) + X = BetaBinomial('X', n, a, b) + t = Symbol('t') + + assert E(X).expand() == moment(X, 1).expand() + assert variance(X).expand() == cmoment(X, 2).expand() + assert skewness(X) == smoment(X, 3) + assert characteristic_function(X)(t) == exp(2*I*t)*beta(a + 2, b)/beta(a, b) +\ + 2*exp(I*t)*beta(a + 1, b + 1)/beta(a, b) + beta(a, b + 2)/beta(a, b) + assert moment_generating_function(X)(t) == exp(2*t)*beta(a + 2, b)/beta(a, b) +\ + 2*exp(t)*beta(a + 1, b + 1)/beta(a, b) + beta(a, b + 2)/beta(a, b) + +def test_hypergeometric_numeric(): + for N in range(1, 5): + for m in range(0, N + 1): + for n in range(1, N + 1): + X = Hypergeometric('X', N, m, n) + N, m, n = map(sympify, (N, m, n)) + assert sum(density(X).values()) == 1 + assert E(X) == n * m / N + if N > 1: + assert variance(X) == n*(m/N)*(N - m)/N*(N - n)/(N - 1) + # Only test for skewness when defined + if N > 2 and 0 < m < N and n < N: + assert skewness(X) == simplify((N - 2*m)*sqrt(N - 1)*(N - 2*n) + / (sqrt(n*m*(N - m)*(N - n))*(N - 2))) + +def test_hypergeometric_symbolic(): + N, m, n = symbols('N, m, n') + H = Hypergeometric('H', N, m, n) + dens = density(H).dict + expec = E(H > 2) + assert dens == Density(HypergeometricDistribution(N, m, n)) + assert dens.subs(N, 5).doit() == Density(HypergeometricDistribution(5, m, n)) + assert set(dens.subs({N: 3, m: 2, n: 1}).doit().keys()) == {S.Zero, S.One} + assert set(dens.subs({N: 3, m: 2, n: 1}).doit().values()) == {Rational(1, 3), Rational(2, 3)} + k = Dummy('k', integer=True) + assert expec.dummy_eq( + Sum(Piecewise((k*binomial(m, k)*binomial(N - m, -k + n) + /binomial(N, n), k > 2), (0, True)), (k, 0, n))) + +def test_rademacher(): + X = Rademacher('X') + t = Symbol('t') + + assert E(X) == 0 + assert variance(X) == 1 + assert density(X)[-1] == S.Half + assert density(X)[1] == S.Half + assert characteristic_function(X)(t) == exp(I*t)/2 + exp(-I*t)/2 + assert moment_generating_function(X)(t) == exp(t) / 2 + exp(-t) / 2 + +def test_ideal_soliton(): + raises(ValueError, lambda : IdealSoliton('sol', -12)) + raises(ValueError, lambda : IdealSoliton('sol', 13.2)) + raises(ValueError, lambda : IdealSoliton('sol', 0)) + f = Function('f') + raises(ValueError, lambda : density(IdealSoliton('sol', 10)).pmf(f)) + + k = Symbol('k', integer=True, positive=True) + x = Symbol('x', integer=True, positive=True) + t = Symbol('t') + sol = IdealSoliton('sol', k) + assert density(sol).low == S.One + assert density(sol).high == k + assert density(sol).dict == Density(density(sol)) + assert density(sol).pmf(x) == Piecewise((1/k, Eq(x, 1)), (1/(x*(x - 1)), k >= x), (0, True)) + + k_vals = [5, 20, 50, 100, 1000] + for i in k_vals: + assert E(sol.subs(k, i)) == harmonic(i) == moment(sol.subs(k, i), 1) + assert variance(sol.subs(k, i)) == (i - 1) + harmonic(i) - harmonic(i)**2 == cmoment(sol.subs(k, i),2) + assert skewness(sol.subs(k, i)) == smoment(sol.subs(k, i), 3) + assert kurtosis(sol.subs(k, i)) == smoment(sol.subs(k, i), 4) + + assert exp(I*t)/10 + Sum(exp(I*t*x)/(x*x - x), (x, 2, k)).subs(k, 10).doit() == characteristic_function(sol.subs(k, 10))(t) + assert exp(t)/10 + Sum(exp(t*x)/(x*x - x), (x, 2, k)).subs(k, 10).doit() == moment_generating_function(sol.subs(k, 10))(t) + +def test_robust_soliton(): + raises(ValueError, lambda : RobustSoliton('robSol', -12, 0.1, 0.02)) + raises(ValueError, lambda : RobustSoliton('robSol', 13, 1.89, 0.1)) + raises(ValueError, lambda : RobustSoliton('robSol', 15, 0.6, -2.31)) + f = Function('f') + raises(ValueError, lambda : density(RobustSoliton('robSol', 15, 0.6, 0.1)).pmf(f)) + + k = Symbol('k', integer=True, positive=True) + delta = Symbol('delta', positive=True) + c = Symbol('c', positive=True) + robSol = RobustSoliton('robSol', k, delta, c) + assert density(robSol).low == 1 + assert density(robSol).high == k + + k_vals = [10, 20, 50] + delta_vals = [0.2, 0.4, 0.6] + c_vals = [0.01, 0.03, 0.05] + for x in k_vals: + for y in delta_vals: + for z in c_vals: + assert E(robSol.subs({k: x, delta: y, c: z})) == moment(robSol.subs({k: x, delta: y, c: z}), 1) + assert variance(robSol.subs({k: x, delta: y, c: z})) == cmoment(robSol.subs({k: x, delta: y, c: z}), 2) + assert skewness(robSol.subs({k: x, delta: y, c: z})) == smoment(robSol.subs({k: x, delta: y, c: z}), 3) + assert kurtosis(robSol.subs({k: x, delta: y, c: z})) == smoment(robSol.subs({k: x, delta: y, c: z}), 4) + +def test_FiniteRV(): + F = FiniteRV('F', {1: S.Half, 2: Rational(1, 4), 3: Rational(1, 4)}, check=True) + p = Symbol("p", positive=True) + + assert dict(density(F).items()) == {S.One: S.Half, S(2): Rational(1, 4), S(3): Rational(1, 4)} + assert P(F >= 2) == S.Half + assert quantile(F)(p) == Piecewise((nan, p > S.One), (S.One, p <= S.Half),\ + (S(2), p <= Rational(3, 4)),(S(3), True)) + + assert pspace(F).domain.as_boolean() == Or( + *[Eq(F.symbol, i) for i in [1, 2, 3]]) + + assert F.pspace.domain.set == FiniteSet(1, 2, 3) + raises(ValueError, lambda: FiniteRV('F', {1: S.Half, 2: S.Half, 3: S.Half}, check=True)) + raises(ValueError, lambda: FiniteRV('F', {1: S.Half, 2: Rational(-1, 2), 3: S.One}, check=True)) + raises(ValueError, lambda: FiniteRV('F', {1: S.One, 2: Rational(3, 2), 3: S.Zero,\ + 4: Rational(-1, 2), 5: Rational(-3, 4), 6: Rational(-1, 4)}, check=True)) + + # purposeful invalid pmf but it should not raise since check=False + # see test_drv_types.test_ContinuousRV for explanation + X = FiniteRV('X', {1: 1, 2: 2}) + assert E(X) == 5 + assert P(X <= 2) + P(X > 2) != 1 + +def test_density_call(): + from sympy.abc import p + x = Bernoulli('x', p) + d = density(x) + assert d(0) == 1 - p + assert d(S.Zero) == 1 - p + assert d(5) == 0 + + assert 0 in d + assert 5 not in d + assert d(S.Zero) == d[S.Zero] + + +def test_DieDistribution(): + from sympy.abc import x + X = DieDistribution(6) + assert X.pmf(S.Half) is S.Zero + assert X.pmf(x).subs({x: 1}).doit() == Rational(1, 6) + assert X.pmf(x).subs({x: 7}).doit() == 0 + assert X.pmf(x).subs({x: -1}).doit() == 0 + assert X.pmf(x).subs({x: Rational(1, 3)}).doit() == 0 + raises(ValueError, lambda: X.pmf(Matrix([0, 0]))) + raises(ValueError, lambda: X.pmf(x**2 - 1)) + +def test_FinitePSpace(): + X = Die('X', 6) + space = pspace(X) + assert space.density == DieDistribution(6) + +def test_symbolic_conditions(): + B = Bernoulli('B', Rational(1, 4)) + D = Die('D', 4) + b, n = symbols('b, n') + Y = P(Eq(B, b)) + Z = E(D > n) + assert Y == \ + Piecewise((Rational(1, 4), Eq(b, 1)), (0, True)) + \ + Piecewise((Rational(3, 4), Eq(b, 0)), (0, True)) + assert Z == \ + Piecewise((Rational(1, 4), n < 1), (0, True)) + Piecewise((S.Half, n < 2), (0, True)) + \ + Piecewise((Rational(3, 4), n < 3), (0, True)) + Piecewise((S.One, n < 4), (0, True)) diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_joint_rv.py b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_joint_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..057fc313dfbb31826b07fd1315205d22b86a7f96 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_joint_rv.py @@ -0,0 +1,436 @@ +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.numbers import (Rational, oo, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.combinatorial.factorials import (RisingFactorial, factorial) +from sympy.functions.elementary.complexes import polar_lift +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.bessel import besselk +from sympy.functions.special.gamma_functions import gamma +from sympy.matrices.dense import eye +from sympy.matrices.expressions.determinant import Determinant +from sympy.sets.fancysets import Range +from sympy.sets.sets import (Interval, ProductSet) +from sympy.simplify.simplify import simplify +from sympy.tensor.indexed import (Indexed, IndexedBase) +from sympy.core.numbers import comp +from sympy.integrals.integrals import integrate +from sympy.matrices import Matrix, MatrixSymbol +from sympy.matrices.expressions.matexpr import MatrixElement +from sympy.stats import density, median, marginal_distribution, Normal, Laplace, E, sample +from sympy.stats.joint_rv_types import (JointRV, MultivariateNormalDistribution, + JointDistributionHandmade, MultivariateT, NormalGamma, + GeneralizedMultivariateLogGammaOmega as GMVLGO, MultivariateBeta, + GeneralizedMultivariateLogGamma as GMVLG, MultivariateEwens, + Multinomial, NegativeMultinomial, MultivariateNormal, + MultivariateLaplace) +from sympy.testing.pytest import raises, XFAIL, skip, slow +from sympy.external import import_module + +from sympy.abc import x, y + + + +def test_Normal(): + m = Normal('A', [1, 2], [[1, 0], [0, 1]]) + A = MultivariateNormal('A', [1, 2], [[1, 0], [0, 1]]) + assert m == A + assert density(m)(1, 2) == 1/(2*pi) + assert m.pspace.distribution.set == ProductSet(S.Reals, S.Reals) + raises (ValueError, lambda:m[2]) + n = Normal('B', [1, 2, 3], [[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + p = Normal('C', Matrix([1, 2]), Matrix([[1, 0], [0, 1]])) + assert density(m)(x, y) == density(p)(x, y) + assert marginal_distribution(n, 0, 1)(1, 2) == 1/(2*pi) + raises(ValueError, lambda: marginal_distribution(m)) + assert integrate(density(m)(x, y), (x, -oo, oo), (y, -oo, oo)).evalf() == 1.0 + N = Normal('N', [1, 2], [[x, 0], [0, y]]) + assert density(N)(0, 0) == exp(-((4*x + y)/(2*x*y)))/(2*pi*sqrt(x*y)) + + raises (ValueError, lambda: Normal('M', [1, 2], [[1, 1], [1, -1]])) + # symbolic + n = symbols('n', integer=True, positive=True) + mu = MatrixSymbol('mu', n, 1) + sigma = MatrixSymbol('sigma', n, n) + X = Normal('X', mu, sigma) + assert density(X) == MultivariateNormalDistribution(mu, sigma) + raises (NotImplementedError, lambda: median(m)) + # Below tests should work after issue #17267 is resolved + # assert E(X) == mu + # assert variance(X) == sigma + + # test symbolic multivariate normal densities + n = 3 + + Sg = MatrixSymbol('Sg', n, n) + mu = MatrixSymbol('mu', n, 1) + obs = MatrixSymbol('obs', n, 1) + + X = MultivariateNormal('X', mu, Sg) + density_X = density(X) + + eval_a = density_X(obs).subs({Sg: eye(3), + mu: Matrix([0, 0, 0]), obs: Matrix([0, 0, 0])}).doit() + eval_b = density_X(0, 0, 0).subs({Sg: eye(3), mu: Matrix([0, 0, 0])}).doit() + + assert eval_a == sqrt(2)/(4*pi**Rational(3/2)) + assert eval_b == sqrt(2)/(4*pi**Rational(3/2)) + + n = symbols('n', integer=True, positive=True) + + Sg = MatrixSymbol('Sg', n, n) + mu = MatrixSymbol('mu', n, 1) + obs = MatrixSymbol('obs', n, 1) + + X = MultivariateNormal('X', mu, Sg) + density_X_at_obs = density(X)(obs) + + expected_density = MatrixElement( + exp((S(1)/2) * (mu.T - obs.T) * Sg**(-1) * (-mu + obs)) / \ + sqrt((2*pi)**n * Determinant(Sg)), 0, 0) + + assert density_X_at_obs == expected_density + + +def test_MultivariateTDist(): + t1 = MultivariateT('T', [0, 0], [[1, 0], [0, 1]], 2) + assert(density(t1))(1, 1) == 1/(8*pi) + assert t1.pspace.distribution.set == ProductSet(S.Reals, S.Reals) + assert integrate(density(t1)(x, y), (x, -oo, oo), \ + (y, -oo, oo)).evalf() == 1.0 + raises(ValueError, lambda: MultivariateT('T', [1, 2], [[1, 1], [1, -1]], 1)) + t2 = MultivariateT('t2', [1, 2], [[x, 0], [0, y]], 1) + assert density(t2)(1, 2) == 1/(2*pi*sqrt(x*y)) + + +def test_multivariate_laplace(): + raises(ValueError, lambda: Laplace('T', [1, 2], [[1, 2], [2, 1]])) + L = Laplace('L', [1, 0], [[1, 0], [0, 1]]) + L2 = MultivariateLaplace('L2', [1, 0], [[1, 0], [0, 1]]) + assert density(L)(2, 3) == exp(2)*besselk(0, sqrt(39))/pi + L1 = Laplace('L1', [1, 2], [[x, 0], [0, y]]) + assert density(L1)(0, 1) == \ + exp(2/y)*besselk(0, sqrt((2 + 4/y + 1/x)/y))/(pi*sqrt(x*y)) + assert L.pspace.distribution.set == ProductSet(S.Reals, S.Reals) + assert L.pspace.distribution == L2.pspace.distribution + + +def test_NormalGamma(): + ng = NormalGamma('G', 1, 2, 3, 4) + assert density(ng)(1, 1) == 32*exp(-4)/sqrt(pi) + assert ng.pspace.distribution.set == ProductSet(S.Reals, Interval(0, oo)) + raises(ValueError, lambda:NormalGamma('G', 1, 2, 3, -1)) + assert marginal_distribution(ng, 0)(1) == \ + 3*sqrt(10)*gamma(Rational(7, 4))/(10*sqrt(pi)*gamma(Rational(5, 4))) + assert marginal_distribution(ng, y)(1) == exp(Rational(-1, 4))/128 + assert marginal_distribution(ng,[0,1])(x) == x**2*exp(-x/4)/128 + + +def test_GeneralizedMultivariateLogGammaDistribution(): + h = S.Half + omega = Matrix([[1, h, h, h], + [h, 1, h, h], + [h, h, 1, h], + [h, h, h, 1]]) + v, l, mu = (4, [1, 2, 3, 4], [1, 2, 3, 4]) + y_1, y_2, y_3, y_4 = symbols('y_1:5', real=True) + delta = symbols('d', positive=True) + G = GMVLGO('G', omega, v, l, mu) + Gd = GMVLG('Gd', delta, v, l, mu) + dend = ("d**4*Sum(4*24**(-n - 4)*(1 - d)**n*exp((n + 4)*(y_1 + 2*y_2 + 3*y_3 " + "+ 4*y_4) - exp(y_1) - exp(2*y_2)/2 - exp(3*y_3)/3 - exp(4*y_4)/4)/" + "(gamma(n + 1)*gamma(n + 4)**3), (n, 0, oo))") + assert str(density(Gd)(y_1, y_2, y_3, y_4)) == dend + den = ("5*2**(2/3)*5**(1/3)*Sum(4*24**(-n - 4)*(-2**(2/3)*5**(1/3)/4 + 1)**n*" + "exp((n + 4)*(y_1 + 2*y_2 + 3*y_3 + 4*y_4) - exp(y_1) - exp(2*y_2)/2 - " + "exp(3*y_3)/3 - exp(4*y_4)/4)/(gamma(n + 1)*gamma(n + 4)**3), (n, 0, oo))/64") + assert str(density(G)(y_1, y_2, y_3, y_4)) == den + marg = ("5*2**(2/3)*5**(1/3)*exp(4*y_1)*exp(-exp(y_1))*Integral(exp(-exp(4*G[3])" + "/4)*exp(16*G[3])*Integral(exp(-exp(3*G[2])/3)*exp(12*G[2])*Integral(exp(" + "-exp(2*G[1])/2)*exp(8*G[1])*Sum((-1/4)**n*(-4 + 2**(2/3)*5**(1/3" + "))**n*exp(n*y_1)*exp(2*n*G[1])*exp(3*n*G[2])*exp(4*n*G[3])/(24**n*gamma(n + 1)" + "*gamma(n + 4)**3), (n, 0, oo)), (G[1], -oo, oo)), (G[2], -oo, oo)), (G[3]" + ", -oo, oo))/5308416") + assert str(marginal_distribution(G, G[0])(y_1)) == marg + omega_f1 = Matrix([[1, h, h]]) + omega_f2 = Matrix([[1, h, h, h], + [h, 1, 2, h], + [h, h, 1, h], + [h, h, h, 1]]) + omega_f3 = Matrix([[6, h, h, h], + [h, 1, 2, h], + [h, h, 1, h], + [h, h, h, 1]]) + v_f = symbols("v_f", positive=False, real=True) + l_f = [1, 2, v_f, 4] + m_f = [v_f, 2, 3, 4] + omega_f4 = Matrix([[1, h, h, h, h], + [h, 1, h, h, h], + [h, h, 1, h, h], + [h, h, h, 1, h], + [h, h, h, h, 1]]) + l_f1 = [1, 2, 3, 4, 5] + omega_f5 = Matrix([[1]]) + mu_f5 = l_f5 = [1] + + raises(ValueError, lambda: GMVLGO('G', omega_f1, v, l, mu)) + raises(ValueError, lambda: GMVLGO('G', omega_f2, v, l, mu)) + raises(ValueError, lambda: GMVLGO('G', omega_f3, v, l, mu)) + raises(ValueError, lambda: GMVLGO('G', omega, v_f, l, mu)) + raises(ValueError, lambda: GMVLGO('G', omega, v, l_f, mu)) + raises(ValueError, lambda: GMVLGO('G', omega, v, l, m_f)) + raises(ValueError, lambda: GMVLGO('G', omega_f4, v, l, mu)) + raises(ValueError, lambda: GMVLGO('G', omega, v, l_f1, mu)) + raises(ValueError, lambda: GMVLGO('G', omega_f5, v, l_f5, mu_f5)) + raises(ValueError, lambda: GMVLG('G', Rational(3, 2), v, l, mu)) + + +def test_MultivariateBeta(): + a1, a2 = symbols('a1, a2', positive=True) + a1_f, a2_f = symbols('a1, a2', positive=False, real=True) + mb = MultivariateBeta('B', [a1, a2]) + mb_c = MultivariateBeta('C', a1, a2) + assert density(mb)(1, 2) == S(2)**(a2 - 1)*gamma(a1 + a2)/\ + (gamma(a1)*gamma(a2)) + assert marginal_distribution(mb_c, 0)(3) == S(3)**(a1 - 1)*gamma(a1 + a2)/\ + (a2*gamma(a1)*gamma(a2)) + raises(ValueError, lambda: MultivariateBeta('b1', [a1_f, a2])) + raises(ValueError, lambda: MultivariateBeta('b2', [a1, a2_f])) + raises(ValueError, lambda: MultivariateBeta('b3', [0, 0])) + raises(ValueError, lambda: MultivariateBeta('b4', [a1_f, a2_f])) + assert mb.pspace.distribution.set == ProductSet(Interval(0, 1), Interval(0, 1)) + + +def test_MultivariateEwens(): + n, theta, i = symbols('n theta i', positive=True) + + # tests for integer dimensions + theta_f = symbols('t_f', negative=True) + a = symbols('a_1:4', positive = True, integer = True) + ed = MultivariateEwens('E', 3, theta) + assert density(ed)(a[0], a[1], a[2]) == Piecewise((6*2**(-a[1])*3**(-a[2])* + theta**a[0]*theta**a[1]*theta**a[2]/ + (theta*(theta + 1)*(theta + 2)* + factorial(a[0])*factorial(a[1])* + factorial(a[2])), Eq(a[0] + 2*a[1] + + 3*a[2], 3)), (0, True)) + assert marginal_distribution(ed, ed[1])(a[1]) == Piecewise((6*2**(-a[1])* + theta**a[1]/((theta + 1)* + (theta + 2)*factorial(a[1])), + Eq(2*a[1] + 1, 3)), (0, True)) + raises(ValueError, lambda: MultivariateEwens('e1', 5, theta_f)) + assert ed.pspace.distribution.set == ProductSet(Range(0, 4, 1), + Range(0, 2, 1), Range(0, 2, 1)) + + # tests for symbolic dimensions + eds = MultivariateEwens('E', n, theta) + a = IndexedBase('a') + j, k = symbols('j, k') + den = Piecewise((factorial(n)*Product(theta**a[j]*(j + 1)**(-a[j])/ + factorial(a[j]), (j, 0, n - 1))/RisingFactorial(theta, n), + Eq(n, Sum((k + 1)*a[k], (k, 0, n - 1)))), (0, True)) + assert density(eds)(a).dummy_eq(den) + + +def test_Multinomial(): + n, x1, x2, x3, x4 = symbols('n, x1, x2, x3, x4', nonnegative=True, integer=True) + p1, p2, p3, p4 = symbols('p1, p2, p3, p4', positive=True) + p1_f, n_f = symbols('p1_f, n_f', negative=True) + M = Multinomial('M', n, [p1, p2, p3, p4]) + C = Multinomial('C', 3, p1, p2, p3) + f = factorial + assert density(M)(x1, x2, x3, x4) == Piecewise((p1**x1*p2**x2*p3**x3*p4**x4* + f(n)/(f(x1)*f(x2)*f(x3)*f(x4)), + Eq(n, x1 + x2 + x3 + x4)), (0, True)) + assert marginal_distribution(C, C[0])(x1).subs(x1, 1) ==\ + 3*p1*p2**2 +\ + 6*p1*p2*p3 +\ + 3*p1*p3**2 + raises(ValueError, lambda: Multinomial('b1', 5, [p1, p2, p3, p1_f])) + raises(ValueError, lambda: Multinomial('b2', n_f, [p1, p2, p3, p4])) + raises(ValueError, lambda: Multinomial('b3', n, 0.5, 0.4, 0.3, 0.1)) + + +def test_NegativeMultinomial(): + k0, x1, x2, x3, x4 = symbols('k0, x1, x2, x3, x4', nonnegative=True, integer=True) + p1, p2, p3, p4 = symbols('p1, p2, p3, p4', positive=True) + p1_f = symbols('p1_f', negative=True) + N = NegativeMultinomial('N', 4, [p1, p2, p3, p4]) + C = NegativeMultinomial('C', 4, 0.1, 0.2, 0.3) + g = gamma + f = factorial + assert simplify(density(N)(x1, x2, x3, x4) - + p1**x1*p2**x2*p3**x3*p4**x4*(-p1 - p2 - p3 - p4 + 1)**4*g(x1 + x2 + + x3 + x4 + 4)/(6*f(x1)*f(x2)*f(x3)*f(x4))) is S.Zero + assert comp(marginal_distribution(C, C[0])(1).evalf(), 0.33, .01) + raises(ValueError, lambda: NegativeMultinomial('b1', 5, [p1, p2, p3, p1_f])) + raises(ValueError, lambda: NegativeMultinomial('b2', k0, 0.5, 0.4, 0.3, 0.4)) + assert N.pspace.distribution.set == ProductSet(Range(0, oo, 1), + Range(0, oo, 1), Range(0, oo, 1), Range(0, oo, 1)) + + +@slow +def test_JointPSpace_marginal_distribution(): + T = MultivariateT('T', [0, 0], [[1, 0], [0, 1]], 2) + got = marginal_distribution(T, T[1])(x) + ans = sqrt(2)*(x**2/2 + 1)/(4*polar_lift(x**2/2 + 1)**(S(5)/2)) + assert got == ans, got + assert integrate(marginal_distribution(T, 1)(x), (x, -oo, oo)) == 1 + + t = MultivariateT('T', [0, 0, 0], [[1, 0, 0], [0, 1, 0], [0, 0, 1]], 3) + assert comp(marginal_distribution(t, 0)(1).evalf(), 0.2, .01) + + +def test_JointRV(): + x1, x2 = (Indexed('x', i) for i in (1, 2)) + pdf = exp(-x1**2/2 + x1 - x2**2/2 - S.Half)/(2*pi) + X = JointRV('x', pdf) + assert density(X)(1, 2) == exp(-2)/(2*pi) + assert isinstance(X.pspace.distribution, JointDistributionHandmade) + assert marginal_distribution(X, 0)(2) == sqrt(2)*exp(Rational(-1, 2))/(2*sqrt(pi)) + + +def test_expectation(): + m = Normal('A', [x, y], [[1, 0], [0, 1]]) + assert simplify(E(m[1])) == y + + +@XFAIL +def test_joint_vector_expectation(): + m = Normal('A', [x, y], [[1, 0], [0, 1]]) + assert E(m) == (x, y) + + +def test_sample_numpy(): + distribs_numpy = [ + MultivariateNormal("M", [3, 4], [[2, 1], [1, 2]]), + MultivariateBeta("B", [0.4, 5, 15, 50, 203]), + Multinomial("N", 50, [0.3, 0.2, 0.1, 0.25, 0.15]) + ] + size = 3 + numpy = import_module('numpy') + if not numpy: + skip('Numpy is not installed. Abort tests for _sample_numpy.') + else: + for X in distribs_numpy: + samps = sample(X, size=size, library='numpy') + for sam in samps: + assert tuple(sam) in X.pspace.distribution.set + N_c = NegativeMultinomial('N', 3, 0.1, 0.1, 0.1) + raises(NotImplementedError, lambda: sample(N_c, library='numpy')) + + +def test_sample_scipy(): + distribs_scipy = [ + MultivariateNormal("M", [0, 0], [[0.1, 0.025], [0.025, 0.1]]), + MultivariateBeta("B", [0.4, 5, 15]), + Multinomial("N", 8, [0.3, 0.2, 0.1, 0.4]) + ] + + size = 3 + scipy = import_module('scipy') + if not scipy: + skip('Scipy not installed. Abort tests for _sample_scipy.') + else: + for X in distribs_scipy: + samps = sample(X, size=size) + samps2 = sample(X, size=(2, 2)) + for sam in samps: + assert tuple(sam) in X.pspace.distribution.set + for i in range(2): + for j in range(2): + assert tuple(samps2[i][j]) in X.pspace.distribution.set + N_c = NegativeMultinomial('N', 3, 0.1, 0.1, 0.1) + raises(NotImplementedError, lambda: sample(N_c)) + + +def test_sample_pymc(): + distribs_pymc = [ + MultivariateNormal("M", [5, 2], [[1, 0], [0, 1]]), + MultivariateBeta("B", [0.4, 5, 15]), + Multinomial("N", 4, [0.3, 0.2, 0.1, 0.4]) + ] + size = 3 + pymc = import_module('pymc') + if not pymc: + skip('PyMC is not installed. Abort tests for _sample_pymc.') + else: + for X in distribs_pymc: + samps = sample(X, size=size, library='pymc') + for sam in samps: + assert tuple(sam.flatten()) in X.pspace.distribution.set + N_c = NegativeMultinomial('N', 3, 0.1, 0.1, 0.1) + raises(NotImplementedError, lambda: sample(N_c, library='pymc')) + + +def test_sample_seed(): + x1, x2 = (Indexed('x', i) for i in (1, 2)) + pdf = exp(-x1**2/2 + x1 - x2**2/2 - S.Half)/(2*pi) + X = JointRV('x', pdf) + + libraries = ['scipy', 'numpy', 'pymc'] + for lib in libraries: + try: + imported_lib = import_module(lib) + if imported_lib: + s0, s1, s2 = [], [], [] + s0 = sample(X, size=10, library=lib, seed=0) + s1 = sample(X, size=10, library=lib, seed=0) + s2 = sample(X, size=10, library=lib, seed=1) + assert all(s0 == s1) + assert all(s1 != s2) + except NotImplementedError: + continue + +# +# XXX: This fails for pymc. Previously the test appeared to pass but that is +# just because the library argument was not passed so the test always used +# scipy. +# +def test_issue_21057(): + m = Normal("x", [0, 0], [[0, 0], [0, 0]]) + n = MultivariateNormal("x", [0, 0], [[0, 0], [0, 0]]) + p = Normal("x", [0, 0], [[0, 0], [0, 1]]) + assert m == n + libraries = ('scipy', 'numpy') # , 'pymc') # <-- pymc fails + for library in libraries: + try: + imported_lib = import_module(library) + if imported_lib: + s1 = sample(m, size=8, library=library) + s2 = sample(n, size=8, library=library) + s3 = sample(p, size=8, library=library) + assert tuple(s1.flatten()) == tuple(s2.flatten()) + for s in s3: + assert tuple(s.flatten()) in p.pspace.distribution.set + except NotImplementedError: + continue + + +# +# When this passes the pymc part can be uncommented in test_issue_21057 above +# and this can be deleted. +# +@XFAIL +def test_issue_21057_pymc(): + m = Normal("x", [0, 0], [[0, 0], [0, 0]]) + n = MultivariateNormal("x", [0, 0], [[0, 0], [0, 0]]) + p = Normal("x", [0, 0], [[0, 0], [0, 1]]) + assert m == n + libraries = ('pymc',) + for library in libraries: + try: + imported_lib = import_module(library) + if imported_lib: + s1 = sample(m, size=8, library=library) + s2 = sample(n, size=8, library=library) + s3 = sample(p, size=8, library=library) + assert tuple(s1.flatten()) == tuple(s2.flatten()) + for s in s3: + assert tuple(s.flatten()) in p.pspace.distribution.set + except NotImplementedError: + continue diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_matrix_distributions.py b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_matrix_distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..a2a2dcdd853793d9f77e1a88adf63158ed68e3ba --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_matrix_distributions.py @@ -0,0 +1,186 @@ +from sympy.concrete.products import Product +from sympy.core.numbers import pi +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.special.gamma_functions import gamma +from sympy.matrices import Determinant, Matrix, Trace, MatrixSymbol, MatrixSet +from sympy.stats import density, sample +from sympy.stats.matrix_distributions import (MatrixGammaDistribution, + MatrixGamma, MatrixPSpace, Wishart, MatrixNormal, MatrixStudentT) +from sympy.testing.pytest import raises, skip +from sympy.external import import_module + + +def test_MatrixPSpace(): + M = MatrixGammaDistribution(1, 2, [[2, 1], [1, 2]]) + MP = MatrixPSpace('M', M, 2, 2) + assert MP.distribution == M + raises(ValueError, lambda: MatrixPSpace('M', M, 1.2, 2)) + +def test_MatrixGamma(): + M = MatrixGamma('M', 1, 2, [[1, 0], [0, 1]]) + assert M.pspace.distribution.set == MatrixSet(2, 2, S.Reals) + assert isinstance(density(M), MatrixGammaDistribution) + X = MatrixSymbol('X', 2, 2) + num = exp(Trace(Matrix([[-S(1)/2, 0], [0, -S(1)/2]])*X)) + assert density(M)(X).doit() == num/(4*pi*sqrt(Determinant(X))) + assert density(M)([[2, 1], [1, 2]]).doit() == sqrt(3)*exp(-2)/(12*pi) + X = MatrixSymbol('X', 1, 2) + Y = MatrixSymbol('Y', 1, 2) + assert density(M)([X, Y]).doit() == exp(-X[0, 0]/2 - Y[0, 1]/2)/(4*pi*sqrt( + X[0, 0]*Y[0, 1] - X[0, 1]*Y[0, 0])) + # symbolic + a, b = symbols('a b', positive=True) + d = symbols('d', positive=True, integer=True) + Y = MatrixSymbol('Y', d, d) + Z = MatrixSymbol('Z', 2, 2) + SM = MatrixSymbol('SM', d, d) + M2 = MatrixGamma('M2', a, b, SM) + M3 = MatrixGamma('M3', 2, 3, [[2, 1], [1, 2]]) + k = Dummy('k') + exprd = pi**(-d*(d - 1)/4)*b**(-a*d)*exp(Trace((-1/b)*SM**(-1)*Y) + )*Determinant(SM)**(-a)*Determinant(Y)**(a - d/2 - S(1)/2)/Product( + gamma(-k/2 + a + S(1)/2), (k, 1, d)) + assert density(M2)(Y).dummy_eq(exprd) + raises(NotImplementedError, lambda: density(M3 + M)(Z)) + raises(ValueError, lambda: density(M)(1)) + raises(ValueError, lambda: MatrixGamma('M', -1, 2, [[1, 0], [0, 1]])) + raises(ValueError, lambda: MatrixGamma('M', -1, -2, [[1, 0], [0, 1]])) + raises(ValueError, lambda: MatrixGamma('M', -1, 2, [[1, 0], [2, 1]])) + raises(ValueError, lambda: MatrixGamma('M', -1, 2, [[1, 0], [0]])) + +def test_Wishart(): + W = Wishart('W', 5, [[1, 0], [0, 1]]) + assert W.pspace.distribution.set == MatrixSet(2, 2, S.Reals) + X = MatrixSymbol('X', 2, 2) + term1 = exp(Trace(Matrix([[-S(1)/2, 0], [0, -S(1)/2]])*X)) + assert density(W)(X).doit() == term1 * Determinant(X)/(24*pi) + assert density(W)([[2, 1], [1, 2]]).doit() == exp(-2)/(8*pi) + n = symbols('n', positive=True) + d = symbols('d', positive=True, integer=True) + Y = MatrixSymbol('Y', d, d) + SM = MatrixSymbol('SM', d, d) + W = Wishart('W', n, SM) + k = Dummy('k') + exprd = 2**(-d*n/2)*pi**(-d*(d - 1)/4)*exp(Trace(-(S(1)/2)*SM**(-1)*Y) + )*Determinant(SM)**(-n/2)*Determinant(Y)**( + -d/2 + n/2 - S(1)/2)/Product(gamma(-k/2 + n/2 + S(1)/2), (k, 1, d)) + assert density(W)(Y).dummy_eq(exprd) + raises(ValueError, lambda: density(W)(1)) + raises(ValueError, lambda: Wishart('W', -1, [[1, 0], [0, 1]])) + raises(ValueError, lambda: Wishart('W', -1, [[1, 0], [2, 1]])) + raises(ValueError, lambda: Wishart('W', 2, [[1, 0], [0]])) + +def test_MatrixNormal(): + M = MatrixNormal('M', [[5, 6]], [4], [[2, 1], [1, 2]]) + assert M.pspace.distribution.set == MatrixSet(1, 2, S.Reals) + X = MatrixSymbol('X', 1, 2) + term1 = exp(-Trace(Matrix([[ S(2)/3, -S(1)/3], [-S(1)/3, S(2)/3]])*( + Matrix([[-5], [-6]]) + X.T)*Matrix([[S(1)/4]])*(Matrix([[-5, -6]]) + X))/2) + assert density(M)(X).doit() == (sqrt(3)) * term1/(24*pi) + assert density(M)([[7, 8]]).doit() == sqrt(3)*exp(-S(1)/3)/(24*pi) + d, n = symbols('d n', positive=True, integer=True) + SM2 = MatrixSymbol('SM2', d, d) + SM1 = MatrixSymbol('SM1', n, n) + LM = MatrixSymbol('LM', n, d) + Y = MatrixSymbol('Y', n, d) + M = MatrixNormal('M', LM, SM1, SM2) + exprd = (2*pi)**(-d*n/2)*exp(-Trace(SM2**(-1)*(-LM.T + Y.T)*SM1**(-1)*(-LM + Y) + )/2)*Determinant(SM1)**(-d/2)*Determinant(SM2)**(-n/2) + assert density(M)(Y).doit() == exprd + raises(ValueError, lambda: density(M)(1)) + raises(ValueError, lambda: MatrixNormal('M', [1, 2], [[1, 0], [0, 1]], [[1, 0], [2, 1]])) + raises(ValueError, lambda: MatrixNormal('M', [1, 2], [[1, 0], [2, 1]], [[1, 0], [0, 1]])) + raises(ValueError, lambda: MatrixNormal('M', [1, 2], [[1, 0], [0, 1]], [[1, 0], [0, 1]])) + raises(ValueError, lambda: MatrixNormal('M', [1, 2], [[1, 0], [2]], [[1, 0], [0, 1]])) + raises(ValueError, lambda: MatrixNormal('M', [1, 2], [[1, 0], [2, 1]], [[1, 0], [0]])) + raises(ValueError, lambda: MatrixNormal('M', [[1, 2]], [[1, 0], [0, 1]], [[1, 0]])) + raises(ValueError, lambda: MatrixNormal('M', [[1, 2]], [1], [[1, 0]])) + +def test_MatrixStudentT(): + M = MatrixStudentT('M', 2, [[5, 6]], [[2, 1], [1, 2]], [4]) + assert M.pspace.distribution.set == MatrixSet(1, 2, S.Reals) + X = MatrixSymbol('X', 1, 2) + D = pi ** (-1.0) * Determinant(Matrix([[4]])) ** (-1.0) * Determinant(Matrix([[2, 1], [1, 2]])) \ + ** (-0.5) / Determinant(Matrix([[S(1) / 4]]) * (Matrix([[-5, -6]]) + X) + * Matrix([[S(2) / 3, -S(1) / 3], [-S(1) / 3, S(2) / 3]]) * ( + Matrix([[-5], [-6]]) + X.T) + Matrix([[1]])) ** 2 + assert density(M)(X) == D + + v = symbols('v', positive=True) + n, p = 1, 2 + Omega = MatrixSymbol('Omega', p, p) + Sigma = MatrixSymbol('Sigma', n, n) + Location = MatrixSymbol('Location', n, p) + Y = MatrixSymbol('Y', n, p) + M = MatrixStudentT('M', v, Location, Omega, Sigma) + + exprd = gamma(v/2 + 1)*Determinant(Matrix([[1]]) + Sigma**(-1)*(-Location + Y)*Omega**(-1)*(-Location.T + Y.T))**(-v/2 - 1) / \ + (pi*gamma(v/2)*sqrt(Determinant(Omega))*Determinant(Sigma)) + + assert density(M)(Y) == exprd + raises(ValueError, lambda: density(M)(1)) + raises(ValueError, lambda: MatrixStudentT('M', 1, [1, 2], [[1, 0], [0, 1]], [[1, 0], [2, 1]])) + raises(ValueError, lambda: MatrixStudentT('M', 1, [1, 2], [[1, 0], [2, 1]], [[1, 0], [0, 1]])) + raises(ValueError, lambda: MatrixStudentT('M', 1, [1, 2], [[1, 0], [0, 1]], [[1, 0], [0, 1]])) + raises(ValueError, lambda: MatrixStudentT('M', 1, [1, 2], [[1, 0], [2]], [[1, 0], [0, 1]])) + raises(ValueError, lambda: MatrixStudentT('M', 1, [1, 2], [[1, 0], [2, 1]], [[1], [2]])) + raises(ValueError, lambda: MatrixStudentT('M', 1, [[1, 2]], [[1, 0], [0, 1]], [[1, 0]])) + raises(ValueError, lambda: MatrixStudentT('M', 1, [[1, 2]], [1], [[1, 0]])) + raises(ValueError, lambda: MatrixStudentT('M', -1, [1, 2], [[1, 0], [0, 1]], [4])) + +def test_sample_scipy(): + distribs_scipy = [ + MatrixNormal('M', [[5, 6]], [4], [[2, 1], [1, 2]]), + Wishart('W', 5, [[1, 0], [0, 1]]) + ] + + size = 5 + scipy = import_module('scipy') + if not scipy: + skip('Scipy not installed. Abort tests for _sample_scipy.') + else: + for X in distribs_scipy: + samps = sample(X, size=size) + for sam in samps: + assert Matrix(sam) in X.pspace.distribution.set + M = MatrixGamma('M', 1, 2, [[1, 0], [0, 1]]) + raises(NotImplementedError, lambda: sample(M, size=3)) + +def test_sample_pymc(): + distribs_pymc = [ + MatrixNormal('M', [[5, 6], [3, 4]], [[1, 0], [0, 1]], [[2, 1], [1, 2]]), + Wishart('W', 7, [[2, 1], [1, 2]]) + ] + size = 3 + pymc = import_module('pymc') + if not pymc: + skip('PyMC is not installed. Abort tests for _sample_pymc.') + else: + for X in distribs_pymc: + samps = sample(X, size=size, library='pymc') + for sam in samps: + assert Matrix(sam) in X.pspace.distribution.set + M = MatrixGamma('M', 1, 2, [[1, 0], [0, 1]]) + raises(NotImplementedError, lambda: sample(M, size=3)) + +def test_sample_seed(): + X = MatrixNormal('M', [[5, 6], [3, 4]], [[1, 0], [0, 1]], [[2, 1], [1, 2]]) + + libraries = ['scipy', 'numpy', 'pymc'] + for lib in libraries: + try: + imported_lib = import_module(lib) + if imported_lib: + s0, s1, s2 = [], [], [] + s0 = sample(X, size=10, library=lib, seed=0) + s1 = sample(X, size=10, library=lib, seed=0) + s2 = sample(X, size=10, library=lib, seed=1) + for i in range(10): + assert (s0[i] == s1[i]).all() + assert (s1[i] != s2[i]).all() + + except NotImplementedError: + continue diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_mix.py b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_mix.py new file mode 100644 index 0000000000000000000000000000000000000000..4334d9b144a5ddaad938f195f0276e0e8993aa35 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_mix.py @@ -0,0 +1,82 @@ +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.numbers import (Integer, oo, pi) +from sympy.core.power import Pow +from sympy.core.relational import (Eq, Ne) +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.delta_functions import DiracDelta +from sympy.functions.special.gamma_functions import gamma +from sympy.integrals.integrals import Integral +from sympy.simplify.simplify import simplify +from sympy.tensor.indexed import (Indexed, IndexedBase) +from sympy.functions.elementary.piecewise import ExprCondPair +from sympy.stats import (Poisson, Beta, Exponential, P, + Multinomial, MultivariateBeta) +from sympy.stats.crv_types import Normal +from sympy.stats.drv_types import PoissonDistribution +from sympy.stats.compound_rv import CompoundPSpace, CompoundDistribution +from sympy.stats.joint_rv import MarginalDistribution +from sympy.stats.rv import pspace, density +from sympy.testing.pytest import ignore_warnings + +def test_density(): + x = Symbol('x') + l = Symbol('l', positive=True) + rate = Beta(l, 2, 3) + X = Poisson(x, rate) + assert isinstance(pspace(X), CompoundPSpace) + assert density(X, Eq(rate, rate.symbol)) == PoissonDistribution(l) + N1 = Normal('N1', 0, 1) + N2 = Normal('N2', N1, 2) + assert density(N2)(0).doit() == sqrt(10)/(10*sqrt(pi)) + assert simplify(density(N2, Eq(N1, 1))(x)) == \ + sqrt(2)*exp(-(x - 1)**2/8)/(4*sqrt(pi)) + assert simplify(density(N2)(x)) == sqrt(10)*exp(-x**2/10)/(10*sqrt(pi)) + +def test_MarginalDistribution(): + a1, p1, p2 = symbols('a1 p1 p2', positive=True) + C = Multinomial('C', 2, p1, p2) + B = MultivariateBeta('B', a1, C[0]) + MGR = MarginalDistribution(B, (C[0],)) + mgrc = Mul(Symbol('B'), Piecewise(ExprCondPair(Mul(Integer(2), + Pow(Symbol('p1', positive=True), Indexed(IndexedBase(Symbol('C')), + Integer(0))), Pow(Symbol('p2', positive=True), + Indexed(IndexedBase(Symbol('C')), Integer(1))), + Pow(factorial(Indexed(IndexedBase(Symbol('C')), Integer(0))), Integer(-1)), + Pow(factorial(Indexed(IndexedBase(Symbol('C')), Integer(1))), Integer(-1))), + Eq(Add(Indexed(IndexedBase(Symbol('C')), Integer(0)), + Indexed(IndexedBase(Symbol('C')), Integer(1))), Integer(2))), + ExprCondPair(Integer(0), True)), Pow(gamma(Symbol('a1', positive=True)), + Integer(-1)), gamma(Add(Symbol('a1', positive=True), + Indexed(IndexedBase(Symbol('C')), Integer(0)))), + Pow(gamma(Indexed(IndexedBase(Symbol('C')), Integer(0))), Integer(-1)), + Pow(Indexed(IndexedBase(Symbol('B')), Integer(0)), + Add(Symbol('a1', positive=True), Integer(-1))), + Pow(Indexed(IndexedBase(Symbol('B')), Integer(1)), + Add(Indexed(IndexedBase(Symbol('C')), Integer(0)), Integer(-1)))) + assert MGR(C) == mgrc + +def test_compound_distribution(): + Y = Poisson('Y', 1) + Z = Poisson('Z', Y) + assert isinstance(pspace(Z), CompoundPSpace) + assert isinstance(pspace(Z).distribution, CompoundDistribution) + assert Z.pspace.distribution.pdf(1).doit() == exp(-2)*exp(exp(-1)) + +def test_mix_expression(): + Y, E = Poisson('Y', 1), Exponential('E', 1) + k = Dummy('k') + expr1 = Integral(Sum(exp(-1)*Integral(exp(-k)*DiracDelta(k - 2), (k, 0, oo) + )/factorial(k), (k, 0, oo)), (k, -oo, 0)) + expr2 = Integral(Sum(exp(-1)*Integral(exp(-k)*DiracDelta(k - 2), (k, 0, oo) + )/factorial(k), (k, 0, oo)), (k, 0, oo)) + assert P(Eq(Y + E, 1)) == 0 + assert P(Ne(Y + E, 2)) == 1 + with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed + assert P(E + Y < 2, evaluate=False).rewrite(Integral).dummy_eq(expr1) + assert P(E + Y > 2, evaluate=False).rewrite(Integral).dummy_eq(expr2) diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_random_matrix.py b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_random_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..ba570a16bc42620d53bce19be71e7d125965ede1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_random_matrix.py @@ -0,0 +1,135 @@ +from sympy.concrete.products import Product +from sympy.core.function import Lambda +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import Dummy +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.integrals.integrals import Integral +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.trace import Trace +from sympy.tensor.indexed import IndexedBase +from sympy.stats import (GaussianUnitaryEnsemble as GUE, density, + GaussianOrthogonalEnsemble as GOE, + GaussianSymplecticEnsemble as GSE, + joint_eigen_distribution, + CircularUnitaryEnsemble as CUE, + CircularOrthogonalEnsemble as COE, + CircularSymplecticEnsemble as CSE, + JointEigenDistribution, + level_spacing_distribution, + Normal, Beta) +from sympy.stats.joint_rv_types import JointDistributionHandmade +from sympy.stats.rv import RandomMatrixSymbol +from sympy.stats.random_matrix_models import GaussianEnsemble, RandomMatrixPSpace +from sympy.testing.pytest import raises + +def test_GaussianEnsemble(): + G = GaussianEnsemble('G', 3) + assert density(G) == G.pspace.model + raises(ValueError, lambda: GaussianEnsemble('G', 3.5)) + +def test_GaussianUnitaryEnsemble(): + H = RandomMatrixSymbol('H', 3, 3) + G = GUE('U', 3) + assert density(G)(H) == sqrt(2)*exp(-3*Trace(H**2)/2)/(4*pi**Rational(9, 2)) + i, j = (Dummy('i', integer=True, positive=True), + Dummy('j', integer=True, positive=True)) + l = IndexedBase('l') + assert joint_eigen_distribution(G).dummy_eq( + Lambda((l[1], l[2], l[3]), + 27*sqrt(6)*exp(-3*(l[1]**2)/2 - 3*(l[2]**2)/2 - 3*(l[3]**2)/2)* + Product(Abs(l[i] - l[j])**2, (j, i + 1, 3), (i, 1, 2))/(16*pi**Rational(3, 2)))) + s = Dummy('s') + assert level_spacing_distribution(G).dummy_eq(Lambda(s, 32*s**2*exp(-4*s**2/pi)/pi**2)) + + +def test_GaussianOrthogonalEnsemble(): + H = RandomMatrixSymbol('H', 3, 3) + _H = MatrixSymbol('_H', 3, 3) + G = GOE('O', 3) + assert density(G)(H) == exp(-3*Trace(H**2)/4)/Integral(exp(-3*Trace(_H**2)/4), _H) + i, j = (Dummy('i', integer=True, positive=True), + Dummy('j', integer=True, positive=True)) + l = IndexedBase('l') + assert joint_eigen_distribution(G).dummy_eq( + Lambda((l[1], l[2], l[3]), + 9*sqrt(2)*exp(-3*l[1]**2/2 - 3*l[2]**2/2 - 3*l[3]**2/2)* + Product(Abs(l[i] - l[j]), (j, i + 1, 3), (i, 1, 2))/(32*pi))) + s = Dummy('s') + assert level_spacing_distribution(G).dummy_eq(Lambda(s, s*pi*exp(-s**2*pi/4)/2)) + +def test_GaussianSymplecticEnsemble(): + H = RandomMatrixSymbol('H', 3, 3) + _H = MatrixSymbol('_H', 3, 3) + G = GSE('O', 3) + assert density(G)(H) == exp(-3*Trace(H**2))/Integral(exp(-3*Trace(_H**2)), _H) + i, j = (Dummy('i', integer=True, positive=True), + Dummy('j', integer=True, positive=True)) + l = IndexedBase('l') + assert joint_eigen_distribution(G).dummy_eq( + Lambda((l[1], l[2], l[3]), + 162*sqrt(3)*exp(-3*l[1]**2/2 - 3*l[2]**2/2 - 3*l[3]**2/2)* + Product(Abs(l[i] - l[j])**4, (j, i + 1, 3), (i, 1, 2))/(5*pi**Rational(3, 2)))) + s = Dummy('s') + assert level_spacing_distribution(G).dummy_eq(Lambda(s, S(262144)*s**4*exp(-64*s**2/(9*pi))/(729*pi**3))) + +def test_CircularUnitaryEnsemble(): + CU = CUE('U', 3) + j, k = (Dummy('j', integer=True, positive=True), + Dummy('k', integer=True, positive=True)) + t = IndexedBase('t') + assert joint_eigen_distribution(CU).dummy_eq( + Lambda((t[1], t[2], t[3]), + Product(Abs(exp(I*t[j]) - exp(I*t[k]))**2, + (j, k + 1, 3), (k, 1, 2))/(48*pi**3)) + ) + +def test_CircularOrthogonalEnsemble(): + CO = COE('U', 3) + j, k = (Dummy('j', integer=True, positive=True), + Dummy('k', integer=True, positive=True)) + t = IndexedBase('t') + assert joint_eigen_distribution(CO).dummy_eq( + Lambda((t[1], t[2], t[3]), + Product(Abs(exp(I*t[j]) - exp(I*t[k])), + (j, k + 1, 3), (k, 1, 2))/(48*pi**2)) + ) + +def test_CircularSymplecticEnsemble(): + CS = CSE('U', 3) + j, k = (Dummy('j', integer=True, positive=True), + Dummy('k', integer=True, positive=True)) + t = IndexedBase('t') + assert joint_eigen_distribution(CS).dummy_eq( + Lambda((t[1], t[2], t[3]), + Product(Abs(exp(I*t[j]) - exp(I*t[k]))**4, + (j, k + 1, 3), (k, 1, 2))/(720*pi**3)) + ) + +def test_JointEigenDistribution(): + A = Matrix([[Normal('A00', 0, 1), Normal('A01', 1, 1)], + [Beta('A10', 1, 1), Beta('A11', 1, 1)]]) + assert JointEigenDistribution(A) == \ + JointDistributionHandmade(-sqrt(A[0, 0]**2 - 2*A[0, 0]*A[1, 1] + 4*A[0, 1]*A[1, 0] + A[1, 1]**2)/2 + + A[0, 0]/2 + A[1, 1]/2, sqrt(A[0, 0]**2 - 2*A[0, 0]*A[1, 1] + 4*A[0, 1]*A[1, 0] + A[1, 1]**2)/2 + A[0, 0]/2 + A[1, 1]/2) + raises(ValueError, lambda: JointEigenDistribution(Matrix([[1, 0], [2, 1]]))) + +def test_issue_19841(): + G1 = GUE('U', 2) + G2 = G1.xreplace({2: 2}) + assert G1.args == G2.args + + X = MatrixSymbol('X', 2, 2) + G = GSE('U', 2) + h_pspace = RandomMatrixPSpace('P', model=density(G)) + H = RandomMatrixSymbol('H', 2, 2, pspace=h_pspace) + H2 = RandomMatrixSymbol('H', 2, 2, pspace=None) + assert H.doit() == H + + assert (2*H).xreplace({H: X}) == 2*X + assert (2*H).xreplace({H2: X}) == 2*H + assert (2*H2).xreplace({H: X}) == 2*H2 + assert (2*H2).xreplace({H2: X}) == 2*X diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_rv.py b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..185756300556f2fe70b76c402113ec2bb2501ef4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_rv.py @@ -0,0 +1,441 @@ +from sympy.concrete.summations import Sum +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.function import Lambda +from sympy.core.numbers import (Rational, nan, oo, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.combinatorial.factorials import (FallingFactorial, binomial) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.delta_functions import DiracDelta +from sympy.integrals.integrals import integrate +from sympy.logic.boolalg import (And, Or) +from sympy.matrices.dense import Matrix +from sympy.sets.sets import Interval +from sympy.tensor.indexed import Indexed +from sympy.stats import (Die, Normal, Exponential, FiniteRV, P, E, H, variance, + density, given, independent, dependent, where, pspace, GaussianUnitaryEnsemble, + random_symbols, sample, Geometric, factorial_moment, Binomial, Hypergeometric, + DiscreteUniform, Poisson, characteristic_function, moment_generating_function, + BernoulliProcess, Variance, Expectation, Probability, Covariance, covariance, cmoment, + moment, median) +from sympy.stats.rv import (IndependentProductPSpace, rs_swap, Density, NamedArgsMixin, + RandomSymbol, sample_iter, PSpace, is_random, RandomIndexedSymbol, RandomMatrixSymbol) +from sympy.testing.pytest import raises, skip, XFAIL, warns_deprecated_sympy +from sympy.external import import_module +from sympy.core.numbers import comp +from sympy.stats.frv_types import BernoulliDistribution +from sympy.core.symbol import Dummy +from sympy.functions.elementary.piecewise import Piecewise + +def test_where(): + X, Y = Die('X'), Die('Y') + Z = Normal('Z', 0, 1) + + assert where(Z**2 <= 1).set == Interval(-1, 1) + assert where(Z**2 <= 1).as_boolean() == Interval(-1, 1).as_relational(Z.symbol) + assert where(And(X > Y, Y > 4)).as_boolean() == And( + Eq(X.symbol, 6), Eq(Y.symbol, 5)) + + assert len(where(X < 3).set) == 2 + assert 1 in where(X < 3).set + + X, Y = Normal('X', 0, 1), Normal('Y', 0, 1) + assert where(And(X**2 <= 1, X >= 0)).set == Interval(0, 1) + XX = given(X, And(X**2 <= 1, X >= 0)) + assert XX.pspace.domain.set == Interval(0, 1) + assert XX.pspace.domain.as_boolean() == \ + And(0 <= X.symbol, X.symbol**2 <= 1, -oo < X.symbol, X.symbol < oo) + + with raises(TypeError): + XX = given(X, X + 3) + + +def test_random_symbols(): + X, Y = Normal('X', 0, 1), Normal('Y', 0, 1) + + assert set(random_symbols(2*X + 1)) == {X} + assert set(random_symbols(2*X + Y)) == {X, Y} + assert set(random_symbols(2*X + Y.symbol)) == {X} + assert set(random_symbols(2)) == set() + + +def test_characteristic_function(): + # Imports I from sympy + from sympy.core.numbers import I + X = Normal('X',0,1) + Y = DiscreteUniform('Y', [1,2,7]) + Z = Poisson('Z', 2) + t = symbols('_t') + P = Lambda(t, exp(-t**2/2)) + Q = Lambda(t, exp(7*t*I)/3 + exp(2*t*I)/3 + exp(t*I)/3) + R = Lambda(t, exp(2 * exp(t*I) - 2)) + + + assert characteristic_function(X).dummy_eq(P) + assert characteristic_function(Y).dummy_eq(Q) + assert characteristic_function(Z).dummy_eq(R) + + +def test_moment_generating_function(): + + X = Normal('X',0,1) + Y = DiscreteUniform('Y', [1,2,7]) + Z = Poisson('Z', 2) + t = symbols('_t') + P = Lambda(t, exp(t**2/2)) + Q = Lambda(t, (exp(7*t)/3 + exp(2*t)/3 + exp(t)/3)) + R = Lambda(t, exp(2 * exp(t) - 2)) + + + assert moment_generating_function(X).dummy_eq(P) + assert moment_generating_function(Y).dummy_eq(Q) + assert moment_generating_function(Z).dummy_eq(R) + +def test_sample_iter(): + + X = Normal('X',0,1) + Y = DiscreteUniform('Y', [1, 2, 7]) + Z = Poisson('Z', 2) + + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests') + expr = X**2 + 3 + iterator = sample_iter(expr) + + expr2 = Y**2 + 5*Y + 4 + iterator2 = sample_iter(expr2) + + expr3 = Z**3 + 4 + iterator3 = sample_iter(expr3) + + def is_iterator(obj): + if ( + hasattr(obj, '__iter__') and + (hasattr(obj, 'next') or + hasattr(obj, '__next__')) and + callable(obj.__iter__) and + obj.__iter__() is obj + ): + return True + else: + return False + assert is_iterator(iterator) + assert is_iterator(iterator2) + assert is_iterator(iterator3) + +def test_pspace(): + X, Y = Normal('X', 0, 1), Normal('Y', 0, 1) + x = Symbol('x') + + raises(ValueError, lambda: pspace(5 + 3)) + raises(ValueError, lambda: pspace(x < 1)) + assert pspace(X) == X.pspace + assert pspace(2*X + 1) == X.pspace + assert pspace(2*X + Y) == IndependentProductPSpace(Y.pspace, X.pspace) + +def test_rs_swap(): + X = Normal('x', 0, 1) + Y = Exponential('y', 1) + + XX = Normal('x', 0, 2) + YY = Normal('y', 0, 3) + + expr = 2*X + Y + assert expr.subs(rs_swap((X, Y), (YY, XX))) == 2*XX + YY + + +def test_RandomSymbol(): + + X = Normal('x', 0, 1) + Y = Normal('x', 0, 2) + assert X.symbol == Y.symbol + assert X != Y + + assert X.name == X.symbol.name + + X = Normal('lambda', 0, 1) # make sure we can use protected terms + X = Normal('Lambda', 0, 1) # make sure we can use SymPy terms + + +def test_RandomSymbol_diff(): + X = Normal('x', 0, 1) + assert (2*X).diff(X) + + +def test_random_symbol_no_pspace(): + x = RandomSymbol(Symbol('x')) + assert x.pspace == PSpace() + +def test_overlap(): + X = Normal('x', 0, 1) + Y = Normal('x', 0, 2) + + raises(ValueError, lambda: P(X > Y)) + + +def test_IndependentProductPSpace(): + X = Normal('X', 0, 1) + Y = Normal('Y', 0, 1) + px = X.pspace + py = Y.pspace + assert pspace(X + Y) == IndependentProductPSpace(px, py) + assert pspace(X + Y) == IndependentProductPSpace(py, px) + + +def test_E(): + assert E(5) == 5 + + +def test_H(): + X = Normal('X', 0, 1) + D = Die('D', sides = 4) + G = Geometric('G', 0.5) + assert H(X, X > 0) == -log(2)/2 + S.Half + log(pi)/2 + assert H(D, D > 2) == log(2) + assert comp(H(G).evalf().round(2), 1.39) + + +def test_Sample(): + X = Die('X', 6) + Y = Normal('Y', 0, 1) + z = Symbol('z', integer=True) + + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests') + assert sample(X) in [1, 2, 3, 4, 5, 6] + assert isinstance(sample(X + Y), float) + + assert P(X + Y > 0, Y < 0, numsamples=10).is_number + assert E(X + Y, numsamples=10).is_number + assert E(X**2 + Y, numsamples=10).is_number + assert E((X + Y)**2, numsamples=10).is_number + assert variance(X + Y, numsamples=10).is_number + + raises(TypeError, lambda: P(Y > z, numsamples=5)) + + assert P(sin(Y) <= 1, numsamples=10) == 1.0 + assert P(sin(Y) <= 1, cos(Y) < 1, numsamples=10) == 1.0 + + assert all(i in range(1, 7) for i in density(X, numsamples=10)) + assert all(i in range(4, 7) for i in density(X, X>3, numsamples=10)) + + numpy = import_module('numpy') + if not numpy: + skip('Numpy is not installed. Abort tests') + #Test Issue #21563: Output of sample must be a float or array + assert isinstance(sample(X), (numpy.int32, numpy.int64)) + assert isinstance(sample(Y), numpy.float64) + assert isinstance(sample(X, size=2), numpy.ndarray) + + with warns_deprecated_sympy(): + sample(X, numsamples=2) + +@XFAIL +def test_samplingE(): + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests') + Y = Normal('Y', 0, 1) + z = Symbol('z', integer=True) + assert E(Sum(1/z**Y, (z, 1, oo)), Y > 2, numsamples=3).is_number + + +def test_given(): + X = Normal('X', 0, 1) + Y = Normal('Y', 0, 1) + A = given(X, True) + B = given(X, Y > 2) + + assert X == A == B + + +def test_factorial_moment(): + X = Poisson('X', 2) + Y = Binomial('Y', 2, S.Half) + Z = Hypergeometric('Z', 4, 2, 2) + assert factorial_moment(X, 2) == 4 + assert factorial_moment(Y, 2) == S.Half + assert factorial_moment(Z, 2) == Rational(1, 3) + + x, y, z, l = symbols('x y z l') + Y = Binomial('Y', 2, y) + Z = Hypergeometric('Z', 10, 2, 3) + assert factorial_moment(Y, l) == y**2*FallingFactorial( + 2, l) + 2*y*(1 - y)*FallingFactorial(1, l) + (1 - y)**2*\ + FallingFactorial(0, l) + assert factorial_moment(Z, l) == 7*FallingFactorial(0, l)/\ + 15 + 7*FallingFactorial(1, l)/15 + FallingFactorial(2, l)/15 + + +def test_dependence(): + X, Y = Die('X'), Die('Y') + assert independent(X, 2*Y) + assert not dependent(X, 2*Y) + + X, Y = Normal('X', 0, 1), Normal('Y', 0, 1) + assert independent(X, Y) + assert dependent(X, 2*X) + + # Create a dependency + XX, YY = given(Tuple(X, Y), Eq(X + Y, 3)) + assert dependent(XX, YY) + +def test_dependent_finite(): + X, Y = Die('X'), Die('Y') + # Dependence testing requires symbolic conditions which currently break + # finite random variables + assert dependent(X, Y + X) + + XX, YY = given(Tuple(X, Y), X + Y > 5) # Create a dependency + assert dependent(XX, YY) + + +def test_normality(): + X, Y = Normal('X', 0, 1), Normal('Y', 0, 1) + x = Symbol('x', real=True) + z = Symbol('z', real=True) + dens = density(X - Y, Eq(X + Y, z)) + + assert integrate(dens(x), (x, -oo, oo)) == 1 + + +def test_Density(): + X = Die('X', 6) + d = Density(X) + assert d.doit() == density(X) + +def test_NamedArgsMixin(): + class Foo(Basic, NamedArgsMixin): + _argnames = 'foo', 'bar' + + a = Foo(S(1), S(2)) + + assert a.foo == 1 + assert a.bar == 2 + + raises(AttributeError, lambda: a.baz) + + class Bar(Basic, NamedArgsMixin): + pass + + raises(AttributeError, lambda: Bar(S(1), S(2)).foo) + +def test_density_constant(): + assert density(3)(2) == 0 + assert density(3)(3) == DiracDelta(0) + +def test_cmoment_constant(): + assert variance(3) == 0 + assert cmoment(3, 3) == 0 + assert cmoment(3, 4) == 0 + x = Symbol('x') + assert variance(x) == 0 + assert cmoment(x, 15) == 0 + assert cmoment(x, 0) == 1 + +def test_moment_constant(): + assert moment(3, 0) == 1 + assert moment(3, 1) == 3 + assert moment(3, 2) == 9 + x = Symbol('x') + assert moment(x, 2) == x**2 + +def test_median_constant(): + assert median(3) == 3 + x = Symbol('x') + assert median(x) == x + +def test_real(): + x = Normal('x', 0, 1) + assert x.is_real + + +def test_issue_10052(): + X = Exponential('X', 3) + assert P(X < oo) == 1 + assert P(X > oo) == 0 + assert P(X < 2, X > oo) == 0 + assert P(X < oo, X > oo) == 0 + assert P(X < oo, X > 2) == 1 + assert P(X < 3, X == 2) == 0 + raises(ValueError, lambda: P(1)) + raises(ValueError, lambda: P(X < 1, 2)) + +def test_issue_11934(): + density = {0: .5, 1: .5} + X = FiniteRV('X', density) + assert E(X) == 0.5 + assert P( X>= 2) == 0 + +def test_issue_8129(): + X = Exponential('X', 4) + assert P(X >= X) == 1 + assert P(X > X) == 0 + assert P(X > X+1) == 0 + +def test_issue_12237(): + X = Normal('X', 0, 1) + Y = Normal('Y', 0, 1) + U = P(X > 0, X) + V = P(Y < 0, X) + W = P(X + Y > 0, X) + assert W == P(X + Y > 0, X) + assert U == BernoulliDistribution(S.Half, S.Zero, S.One) + assert V == S.Half + +def test_is_random(): + X = Normal('X', 0, 1) + Y = Normal('Y', 0, 1) + a, b = symbols('a, b') + G = GaussianUnitaryEnsemble('U', 2) + B = BernoulliProcess('B', 0.9) + assert not is_random(a) + assert not is_random(a + b) + assert not is_random(a * b) + assert not is_random(Matrix([a**2, b**2])) + assert is_random(X) + assert is_random(X**2 + Y) + assert is_random(Y + b**2) + assert is_random(Y > 5) + assert is_random(B[3] < 1) + assert is_random(G) + assert is_random(X * Y * B[1]) + assert is_random(Matrix([[X, B[2]], [G, Y]])) + assert is_random(Eq(X, 4)) + +def test_issue_12283(): + x = symbols('x') + X = RandomSymbol(x) + Y = RandomSymbol('Y') + Z = RandomMatrixSymbol('Z', 2, 1) + W = RandomMatrixSymbol('W', 2, 1) + RI = RandomIndexedSymbol(Indexed('RI', 3)) + assert pspace(Z) == PSpace() + assert pspace(RI) == PSpace() + assert pspace(X) == PSpace() + assert E(X) == Expectation(X) + assert P(Y > 3) == Probability(Y > 3) + assert variance(X) == Variance(X) + assert variance(RI) == Variance(RI) + assert covariance(X, Y) == Covariance(X, Y) + assert covariance(W, Z) == Covariance(W, Z) + +def test_issue_6810(): + X = Die('X', 6) + Y = Normal('Y', 0, 1) + assert P(Eq(X, 2)) == S(1)/6 + assert P(Eq(Y, 0)) == 0 + assert P(Or(X > 2, X < 3)) == 1 + assert P(And(X > 3, X > 2)) == S(1)/2 + +def test_issue_20286(): + n, p = symbols('n p') + B = Binomial('B', n, p) + k = Dummy('k', integer = True) + eq = Sum(Piecewise((-p**k*(1 - p)**(-k + n)*log(p**k*(1 - p)**(-k + n)*binomial(n, k))*binomial(n, k), (k >= 0) & (k <= n)), (nan, True)), (k, 0, n)) + assert eq.dummy_eq(H(B)) diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_stochastic_process.py b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_stochastic_process.py new file mode 100644 index 0000000000000000000000000000000000000000..3e42ffc8632240b1d85a774467a057e9857c567c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_stochastic_process.py @@ -0,0 +1,763 @@ +from sympy.concrete.summations import Sum +from sympy.core.containers import Tuple +from sympy.core.function import Lambda +from sympy.core.numbers import (Float, Rational, oo, pi) +from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.integers import ceiling +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.error_functions import erf +from sympy.functions.special.gamma_functions import (gamma, lowergamma) +from sympy.logic.boolalg import (And, Not) +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.immutable import ImmutableMatrix +from sympy.sets.contains import Contains +from sympy.sets.fancysets import Range +from sympy.sets.sets import (FiniteSet, Interval) +from sympy.stats import (DiscreteMarkovChain, P, TransitionMatrixOf, E, + StochasticStateSpaceOf, variance, ContinuousMarkovChain, + BernoulliProcess, PoissonProcess, WienerProcess, + GammaProcess, sample_stochastic_process) +from sympy.stats.joint_rv import JointDistribution +from sympy.stats.joint_rv_types import JointDistributionHandmade +from sympy.stats.rv import RandomIndexedSymbol +from sympy.stats.symbolic_probability import Probability, Expectation +from sympy.testing.pytest import (raises, skip, ignore_warnings, + warns_deprecated_sympy) +from sympy.external import import_module +from sympy.stats.frv_types import BernoulliDistribution +from sympy.stats.drv_types import PoissonDistribution +from sympy.stats.crv_types import NormalDistribution, GammaDistribution +from sympy.core.symbol import Str + + +def test_DiscreteMarkovChain(): + + # pass only the name + X = DiscreteMarkovChain("X") + assert isinstance(X.state_space, Range) + assert X.index_set == S.Naturals0 + assert isinstance(X.transition_probabilities, MatrixSymbol) + t = symbols('t', positive=True, integer=True) + assert isinstance(X[t], RandomIndexedSymbol) + assert E(X[0]) == Expectation(X[0]) + raises(TypeError, lambda: DiscreteMarkovChain(1)) + raises(NotImplementedError, lambda: X(t)) + raises(NotImplementedError, lambda: X.communication_classes()) + raises(NotImplementedError, lambda: X.canonical_form()) + raises(NotImplementedError, lambda: X.decompose()) + + nz = Symbol('n', integer=True) + TZ = MatrixSymbol('M', nz, nz) + SZ = Range(nz) + YZ = DiscreteMarkovChain('Y', SZ, TZ) + assert P(Eq(YZ[2], 1), Eq(YZ[1], 0)) == TZ[0, 1] + + raises(ValueError, lambda: sample_stochastic_process(t)) + raises(ValueError, lambda: next(sample_stochastic_process(X))) + # pass name and state_space + # any hashable object should be a valid state + # states should be valid as a tuple/set/list/Tuple/Range + sym, rainy, cloudy, sunny = symbols('a Rainy Cloudy Sunny', real=True) + state_spaces = [(1, 2, 3), [Str('Hello'), sym, DiscreteMarkovChain("Y", (1,2,3))], + Tuple(S(1), exp(sym), Str('World'), sympify=False), Range(-1, 5, 2), + [rainy, cloudy, sunny]] + chains = [DiscreteMarkovChain("Y", state_space) for state_space in state_spaces] + + for i, Y in enumerate(chains): + assert isinstance(Y.transition_probabilities, MatrixSymbol) + assert Y.state_space == state_spaces[i] or Y.state_space == FiniteSet(*state_spaces[i]) + assert Y.number_of_states == 3 + + with ignore_warnings(UserWarning): # TODO: Restore tests once warnings are removed + assert P(Eq(Y[2], 1), Eq(Y[0], 2), evaluate=False) == Probability(Eq(Y[2], 1), Eq(Y[0], 2)) + assert E(Y[0]) == Expectation(Y[0]) + + raises(ValueError, lambda: next(sample_stochastic_process(Y))) + + raises(TypeError, lambda: DiscreteMarkovChain("Y", {1: 1})) + Y = DiscreteMarkovChain("Y", Range(1, t, 2)) + assert Y.number_of_states == ceiling((t-1)/2) + + # pass name and transition_probabilities + chains = [DiscreteMarkovChain("Y", trans_probs=Matrix([])), + DiscreteMarkovChain("Y", trans_probs=Matrix([[0, 1], [1, 0]])), + DiscreteMarkovChain("Y", trans_probs=Matrix([[pi, 1-pi], [sym, 1-sym]]))] + for Z in chains: + assert Z.number_of_states == Z.transition_probabilities.shape[0] + assert isinstance(Z.transition_probabilities, ImmutableMatrix) + + # pass name, state_space and transition_probabilities + T = Matrix([[0.5, 0.2, 0.3],[0.2, 0.5, 0.3],[0.2, 0.3, 0.5]]) + TS = MatrixSymbol('T', 3, 3) + Y = DiscreteMarkovChain("Y", [0, 1, 2], T) + YS = DiscreteMarkovChain("Y", ['One', 'Two', 3], TS) + assert Y.joint_distribution(1, Y[2], 3) == JointDistribution(Y[1], Y[2], Y[3]) + raises(ValueError, lambda: Y.joint_distribution(Y[1].symbol, Y[2].symbol)) + assert P(Eq(Y[3], 2), Eq(Y[1], 1)).round(2) == Float(0.36, 2) + assert (P(Eq(YS[3], 2), Eq(YS[1], 1)) - + (TS[0, 2]*TS[1, 0] + TS[1, 1]*TS[1, 2] + TS[1, 2]*TS[2, 2])).simplify() == 0 + assert P(Eq(YS[1], 1), Eq(YS[2], 2)) == Probability(Eq(YS[1], 1)) + assert P(Eq(YS[3], 3), Eq(YS[1], 1)) == TS[0, 2]*TS[1, 0] + TS[1, 1]*TS[1, 2] + TS[1, 2]*TS[2, 2] + TO = Matrix([[0.25, 0.75, 0],[0, 0.25, 0.75],[0.75, 0, 0.25]]) + assert P(Eq(Y[3], 2), Eq(Y[1], 1) & TransitionMatrixOf(Y, TO)).round(3) == Float(0.375, 3) + with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed + assert E(Y[3], evaluate=False) == Expectation(Y[3]) + assert E(Y[3], Eq(Y[2], 1)).round(2) == Float(1.1, 3) + TSO = MatrixSymbol('T', 4, 4) + raises(ValueError, lambda: str(P(Eq(YS[3], 2), Eq(YS[1], 1) & TransitionMatrixOf(YS, TSO)))) + raises(TypeError, lambda: DiscreteMarkovChain("Z", [0, 1, 2], symbols('M'))) + raises(ValueError, lambda: DiscreteMarkovChain("Z", [0, 1, 2], MatrixSymbol('T', 3, 4))) + raises(ValueError, lambda: E(Y[3], Eq(Y[2], 6))) + raises(ValueError, lambda: E(Y[2], Eq(Y[3], 1))) + + + # extended tests for probability queries + TO1 = Matrix([[Rational(1, 4), Rational(3, 4), 0],[Rational(1, 3), Rational(1, 3), Rational(1, 3)],[0, Rational(1, 4), Rational(3, 4)]]) + assert P(And(Eq(Y[2], 1), Eq(Y[1], 1), Eq(Y[0], 0)), + Eq(Probability(Eq(Y[0], 0)), Rational(1, 4)) & TransitionMatrixOf(Y, TO1)) == Rational(1, 16) + assert P(And(Eq(Y[2], 1), Eq(Y[1], 1), Eq(Y[0], 0)), TransitionMatrixOf(Y, TO1)) == \ + Probability(Eq(Y[0], 0))/4 + assert P(Lt(X[1], 2) & Gt(X[1], 0), Eq(X[0], 2) & + StochasticStateSpaceOf(X, [0, 1, 2]) & TransitionMatrixOf(X, TO1)) == Rational(1, 4) + assert P(Lt(X[1], 2) & Gt(X[1], 0), Eq(X[0], 2) & + StochasticStateSpaceOf(X, [S(0), '0', 1]) & TransitionMatrixOf(X, TO1)) == Rational(1, 4) + assert P(Ne(X[1], 2) & Ne(X[1], 1), Eq(X[0], 2) & + StochasticStateSpaceOf(X, [0, 1, 2]) & TransitionMatrixOf(X, TO1)) is S.Zero + assert P(Ne(X[1], 2) & Ne(X[1], 1), Eq(X[0], 2) & + StochasticStateSpaceOf(X, [S(0), '0', 1]) & TransitionMatrixOf(X, TO1)) is S.Zero + assert P(And(Eq(Y[2], 1), Eq(Y[1], 1), Eq(Y[0], 0)), Eq(Y[1], 1)) == 0.1*Probability(Eq(Y[0], 0)) + + # testing properties of Markov chain + TO2 = Matrix([[S.One, 0, 0],[Rational(1, 3), Rational(1, 3), Rational(1, 3)],[0, Rational(1, 4), Rational(3, 4)]]) + TO3 = Matrix([[Rational(1, 4), Rational(3, 4), 0],[Rational(1, 3), Rational(1, 3), Rational(1, 3)], [0, Rational(1, 4), Rational(3, 4)]]) + Y2 = DiscreteMarkovChain('Y', trans_probs=TO2) + Y3 = DiscreteMarkovChain('Y', trans_probs=TO3) + assert Y3.fundamental_matrix() == ImmutableMatrix([[176, 81, -132], [36, 141, -52], [-44, -39, 208]])/125 + assert Y2.is_absorbing_chain() == True + assert Y3.is_absorbing_chain() == False + assert Y2.canonical_form() == ([0, 1, 2], TO2) + assert Y3.canonical_form() == ([0, 1, 2], TO3) + assert Y2.decompose() == ([0, 1, 2], TO2[0:1, 0:1], TO2[1:3, 0:1], TO2[1:3, 1:3]) + assert Y3.decompose() == ([0, 1, 2], TO3, Matrix(0, 3, []), Matrix(0, 0, [])) + TO4 = Matrix([[Rational(1, 5), Rational(2, 5), Rational(2, 5)], [Rational(1, 10), S.Half, Rational(2, 5)], [Rational(3, 5), Rational(3, 10), Rational(1, 10)]]) + Y4 = DiscreteMarkovChain('Y', trans_probs=TO4) + w = ImmutableMatrix([[Rational(11, 39), Rational(16, 39), Rational(4, 13)]]) + assert Y4.limiting_distribution == w + assert Y4.is_regular() == True + assert Y4.is_ergodic() == True + TS1 = MatrixSymbol('T', 3, 3) + Y5 = DiscreteMarkovChain('Y', trans_probs=TS1) + assert Y5.limiting_distribution(w, TO4).doit() == True + assert Y5.stationary_distribution(condition_set=True).subs(TS1, TO4).contains(w).doit() == S.true + TO6 = Matrix([[S.One, 0, 0, 0, 0],[S.Half, 0, S.Half, 0, 0],[0, S.Half, 0, S.Half, 0], [0, 0, S.Half, 0, S.Half], [0, 0, 0, 0, 1]]) + Y6 = DiscreteMarkovChain('Y', trans_probs=TO6) + assert Y6.fundamental_matrix() == ImmutableMatrix([[Rational(3, 2), S.One, S.Half], [S.One, S(2), S.One], [S.Half, S.One, Rational(3, 2)]]) + assert Y6.absorbing_probabilities() == ImmutableMatrix([[Rational(3, 4), Rational(1, 4)], [S.Half, S.Half], [Rational(1, 4), Rational(3, 4)]]) + with warns_deprecated_sympy(): + Y6.absorbing_probabilites() + TO7 = Matrix([[Rational(1, 2), Rational(1, 4), Rational(1, 4)], [Rational(1, 2), 0, Rational(1, 2)], [Rational(1, 4), Rational(1, 4), Rational(1, 2)]]) + Y7 = DiscreteMarkovChain('Y', trans_probs=TO7) + assert Y7.is_absorbing_chain() == False + assert Y7.fundamental_matrix() == ImmutableMatrix([[Rational(86, 75), Rational(1, 25), Rational(-14, 75)], + [Rational(2, 25), Rational(21, 25), Rational(2, 25)], + [Rational(-14, 75), Rational(1, 25), Rational(86, 75)]]) + + # test for zero-sized matrix functionality + X = DiscreteMarkovChain('X', trans_probs=Matrix([])) + assert X.number_of_states == 0 + assert X.stationary_distribution() == Matrix([[]]) + assert X.communication_classes() == [] + assert X.canonical_form() == ([], Matrix([])) + assert X.decompose() == ([], Matrix([]), Matrix([]), Matrix([])) + assert X.is_regular() == False + assert X.is_ergodic() == False + + # test communication_class + # see https://drive.google.com/drive/folders/1HbxLlwwn2b3U8Lj7eb_ASIUb5vYaNIjg?usp=sharing + # tutorial 2.pdf + TO7 = Matrix([[0, 5, 5, 0, 0], + [0, 0, 0, 10, 0], + [5, 0, 5, 0, 0], + [0, 10, 0, 0, 0], + [0, 3, 0, 3, 4]])/10 + Y7 = DiscreteMarkovChain('Y', trans_probs=TO7) + tuples = Y7.communication_classes() + classes, recurrence, periods = list(zip(*tuples)) + assert classes == ([1, 3], [0, 2], [4]) + assert recurrence == (True, False, False) + assert periods == (2, 1, 1) + + TO8 = Matrix([[0, 0, 0, 10, 0, 0], + [5, 0, 5, 0, 0, 0], + [0, 4, 0, 0, 0, 6], + [10, 0, 0, 0, 0, 0], + [0, 10, 0, 0, 0, 0], + [0, 0, 0, 5, 5, 0]])/10 + Y8 = DiscreteMarkovChain('Y', trans_probs=TO8) + tuples = Y8.communication_classes() + classes, recurrence, periods = list(zip(*tuples)) + assert classes == ([0, 3], [1, 2, 5, 4]) + assert recurrence == (True, False) + assert periods == (2, 2) + + TO9 = Matrix([[2, 0, 0, 3, 0, 0, 3, 2, 0, 0], + [0, 10, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 2, 2, 0, 0, 0, 0, 0, 3, 3], + [0, 0, 0, 3, 0, 0, 6, 1, 0, 0], + [0, 0, 0, 0, 5, 5, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 10, 0, 0, 0, 0], + [4, 0, 0, 5, 0, 0, 1, 0, 0, 0], + [2, 0, 0, 4, 0, 0, 2, 2, 0, 0], + [3, 0, 1, 0, 0, 0, 0, 0, 4, 2], + [0, 0, 4, 0, 0, 0, 0, 0, 3, 3]])/10 + Y9 = DiscreteMarkovChain('Y', trans_probs=TO9) + tuples = Y9.communication_classes() + classes, recurrence, periods = list(zip(*tuples)) + assert classes == ([0, 3, 6, 7], [1], [2, 8, 9], [5], [4]) + assert recurrence == (True, True, False, True, False) + assert periods == (1, 1, 1, 1, 1) + + # test canonical form + # see https://web.archive.org/web/20201230182007/https://www.dartmouth.edu/~chance/teaching_aids/books_articles/probability_book/Chapter11.pdf + # example 11.13 + T = Matrix([[1, 0, 0, 0, 0], + [S(1) / 2, 0, S(1) / 2, 0, 0], + [0, S(1) / 2, 0, S(1) / 2, 0], + [0, 0, S(1) / 2, 0, S(1) / 2], + [0, 0, 0, 0, S(1)]]) + DW = DiscreteMarkovChain('DW', [0, 1, 2, 3, 4], T) + states, A, B, C = DW.decompose() + assert states == [0, 4, 1, 2, 3] + assert A == Matrix([[1, 0], [0, 1]]) + assert B == Matrix([[S(1)/2, 0], [0, 0], [0, S(1)/2]]) + assert C == Matrix([[0, S(1)/2, 0], [S(1)/2, 0, S(1)/2], [0, S(1)/2, 0]]) + states, new_matrix = DW.canonical_form() + assert states == [0, 4, 1, 2, 3] + assert new_matrix == Matrix([[1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [S(1)/2, 0, 0, S(1)/2, 0], + [0, 0, S(1)/2, 0, S(1)/2], + [0, S(1)/2, 0, S(1)/2, 0]]) + + # test regular and ergodic + # https://web.archive.org/web/20201230182007/https://www.dartmouth.edu/~chance/teaching_aids/books_articles/probability_book/Chapter11.pdf + T = Matrix([[0, 4, 0, 0, 0], + [1, 0, 3, 0, 0], + [0, 2, 0, 2, 0], + [0, 0, 3, 0, 1], + [0, 0, 0, 4, 0]])/4 + X = DiscreteMarkovChain('X', trans_probs=T) + assert not X.is_regular() + assert X.is_ergodic() + T = Matrix([[0, 1], [1, 0]]) + X = DiscreteMarkovChain('X', trans_probs=T) + assert not X.is_regular() + assert X.is_ergodic() + # http://www.math.wisc.edu/~valko/courses/331/MC2.pdf + T = Matrix([[2, 1, 1], + [2, 0, 2], + [1, 1, 2]])/4 + X = DiscreteMarkovChain('X', trans_probs=T) + assert X.is_regular() + assert X.is_ergodic() + # https://docs.ufpr.br/~lucambio/CE222/1S2014/Kemeny-Snell1976.pdf + T = Matrix([[1, 1], [1, 1]])/2 + X = DiscreteMarkovChain('X', trans_probs=T) + assert X.is_regular() + assert X.is_ergodic() + + # test is_absorbing_chain + T = Matrix([[0, 1, 0], + [1, 0, 0], + [0, 0, 1]]) + X = DiscreteMarkovChain('X', trans_probs=T) + assert not X.is_absorbing_chain() + # https://en.wikipedia.org/wiki/Absorbing_Markov_chain + T = Matrix([[1, 1, 0, 0], + [0, 1, 1, 0], + [1, 0, 0, 1], + [0, 0, 0, 2]])/2 + X = DiscreteMarkovChain('X', trans_probs=T) + assert X.is_absorbing_chain() + T = Matrix([[2, 0, 0, 0, 0], + [1, 0, 1, 0, 0], + [0, 1, 0, 1, 0], + [0, 0, 1, 0, 1], + [0, 0, 0, 0, 2]])/2 + X = DiscreteMarkovChain('X', trans_probs=T) + assert X.is_absorbing_chain() + + # test custom state space + Y10 = DiscreteMarkovChain('Y', [1, 2, 3], TO2) + tuples = Y10.communication_classes() + classes, recurrence, periods = list(zip(*tuples)) + assert classes == ([1], [2, 3]) + assert recurrence == (True, False) + assert periods == (1, 1) + assert Y10.canonical_form() == ([1, 2, 3], TO2) + assert Y10.decompose() == ([1, 2, 3], TO2[0:1, 0:1], TO2[1:3, 0:1], TO2[1:3, 1:3]) + + # testing miscellaneous queries + T = Matrix([[S.Half, Rational(1, 4), Rational(1, 4)], + [Rational(1, 3), 0, Rational(2, 3)], + [S.Half, S.Half, 0]]) + X = DiscreteMarkovChain('X', [0, 1, 2], T) + assert P(Eq(X[1], 2) & Eq(X[2], 1) & Eq(X[3], 0), + Eq(P(Eq(X[1], 0)), Rational(1, 4)) & Eq(P(Eq(X[1], 1)), Rational(1, 4))) == Rational(1, 12) + assert P(Eq(X[2], 1) | Eq(X[2], 2), Eq(X[1], 1)) == Rational(2, 3) + assert P(Eq(X[2], 1) & Eq(X[2], 2), Eq(X[1], 1)) is S.Zero + assert P(Ne(X[2], 2), Eq(X[1], 1)) == Rational(1, 3) + assert E(X[1]**2, Eq(X[0], 1)) == Rational(8, 3) + assert variance(X[1], Eq(X[0], 1)) == Rational(8, 9) + raises(ValueError, lambda: E(X[1], Eq(X[2], 1))) + raises(ValueError, lambda: DiscreteMarkovChain('X', [0, 1], T)) + + # testing miscellaneous queries with different state space + X = DiscreteMarkovChain('X', ['A', 'B', 'C'], T) + assert P(Eq(X[1], 2) & Eq(X[2], 1) & Eq(X[3], 0), + Eq(P(Eq(X[1], 0)), Rational(1, 4)) & Eq(P(Eq(X[1], 1)), Rational(1, 4))) == Rational(1, 12) + assert P(Eq(X[2], 1) | Eq(X[2], 2), Eq(X[1], 1)) == Rational(2, 3) + assert P(Eq(X[2], 1) & Eq(X[2], 2), Eq(X[1], 1)) is S.Zero + assert P(Ne(X[2], 2), Eq(X[1], 1)) == Rational(1, 3) + a = X.state_space.args[0] + c = X.state_space.args[2] + assert (E(X[1] ** 2, Eq(X[0], 1)) - (a**2/3 + 2*c**2/3)).simplify() == 0 + assert (variance(X[1], Eq(X[0], 1)) - (2*(-a/3 + c/3)**2/3 + (2*a/3 - 2*c/3)**2/3)).simplify() == 0 + raises(ValueError, lambda: E(X[1], Eq(X[2], 1))) + + #testing queries with multiple RandomIndexedSymbols + T = Matrix([[Rational(5, 10), Rational(3, 10), Rational(2, 10)], [Rational(2, 10), Rational(7, 10), Rational(1, 10)], [Rational(3, 10), Rational(3, 10), Rational(4, 10)]]) + Y = DiscreteMarkovChain("Y", [0, 1, 2], T) + assert P(Eq(Y[7], Y[5]), Eq(Y[2], 0)).round(5) == Float(0.44428, 5) + assert P(Gt(Y[3], Y[1]), Eq(Y[0], 0)).round(2) == Float(0.36, 2) + assert P(Le(Y[5], Y[10]), Eq(Y[4], 2)).round(6) == Float(0.583120, 6) + assert Float(P(Eq(Y[10], Y[5]), Eq(Y[4], 1)), 14) == Float(1 - P(Ne(Y[10], Y[5]), Eq(Y[4], 1)), 14) + assert Float(P(Gt(Y[8], Y[9]), Eq(Y[3], 2)), 14) == Float(1 - P(Le(Y[8], Y[9]), Eq(Y[3], 2)), 14) + assert Float(P(Lt(Y[1], Y[4]), Eq(Y[0], 0)), 14) == Float(1 - P(Ge(Y[1], Y[4]), Eq(Y[0], 0)), 14) + assert P(Eq(Y[5], Y[10]), Eq(Y[2], 1)) == P(Eq(Y[10], Y[5]), Eq(Y[2], 1)) + assert P(Gt(Y[1], Y[2]), Eq(Y[0], 1)) == P(Lt(Y[2], Y[1]), Eq(Y[0], 1)) + assert P(Ge(Y[7], Y[6]), Eq(Y[4], 1)) == P(Le(Y[6], Y[7]), Eq(Y[4], 1)) + + #test symbolic queries + a, b, c, d = symbols('a b c d') + T = Matrix([[Rational(1, 10), Rational(4, 10), Rational(5, 10)], [Rational(3, 10), Rational(4, 10), Rational(3, 10)], [Rational(7, 10), Rational(2, 10), Rational(1, 10)]]) + Y = DiscreteMarkovChain("Y", [0, 1, 2], T) + query = P(Eq(Y[a], b), Eq(Y[c], d)) + assert query.subs({a:10, b:2, c:5, d:1}).evalf().round(4) == P(Eq(Y[10], 2), Eq(Y[5], 1)).round(4) + assert query.subs({a:15, b:0, c:10, d:1}).evalf().round(4) == P(Eq(Y[15], 0), Eq(Y[10], 1)).round(4) + query_gt = P(Gt(Y[a], b), Eq(Y[c], d)) + query_le = P(Le(Y[a], b), Eq(Y[c], d)) + assert query_gt.subs({a:5, b:2, c:1, d:0}).evalf() + query_le.subs({a:5, b:2, c:1, d:0}).evalf() == 1.0 + query_ge = P(Ge(Y[a], b), Eq(Y[c], d)) + query_lt = P(Lt(Y[a], b), Eq(Y[c], d)) + assert query_ge.subs({a:4, b:1, c:0, d:2}).evalf() + query_lt.subs({a:4, b:1, c:0, d:2}).evalf() == 1.0 + + #test issue 20078 + assert (2*Y[1] + 3*Y[1]).simplify() == 5*Y[1] + assert (2*Y[1] - 3*Y[1]).simplify() == -Y[1] + assert (2*(0.25*Y[1])).simplify() == 0.5*Y[1] + assert ((2*Y[1]) * (0.25*Y[1])).simplify() == 0.5*Y[1]**2 + assert (Y[1]**2 + Y[1]**3).simplify() == (Y[1] + 1)*Y[1]**2 + +def test_sample_stochastic_process(): + if not import_module('scipy'): + skip('SciPy Not installed. Skip sampling tests') + import random + random.seed(0) + numpy = import_module('numpy') + if numpy: + numpy.random.seed(0) # scipy uses numpy to sample so to set its seed + T = Matrix([[0.5, 0.2, 0.3],[0.2, 0.5, 0.3],[0.2, 0.3, 0.5]]) + Y = DiscreteMarkovChain("Y", [0, 1, 2], T) + for samps in range(10): + assert next(sample_stochastic_process(Y)) in Y.state_space + Z = DiscreteMarkovChain("Z", ['1', 1, 0], T) + for samps in range(10): + assert next(sample_stochastic_process(Z)) in Z.state_space + + T = Matrix([[S.Half, Rational(1, 4), Rational(1, 4)], + [Rational(1, 3), 0, Rational(2, 3)], + [S.Half, S.Half, 0]]) + X = DiscreteMarkovChain('X', [0, 1, 2], T) + for samps in range(10): + assert next(sample_stochastic_process(X)) in X.state_space + W = DiscreteMarkovChain('W', [1, pi, oo], T) + for samps in range(10): + assert next(sample_stochastic_process(W)) in W.state_space + + +def test_ContinuousMarkovChain(): + T1 = Matrix([[S(-2), S(2), S.Zero], + [S.Zero, S.NegativeOne, S.One], + [Rational(3, 2), Rational(3, 2), S(-3)]]) + C1 = ContinuousMarkovChain('C', [0, 1, 2], T1) + assert C1.limiting_distribution() == ImmutableMatrix([[Rational(3, 19), Rational(12, 19), Rational(4, 19)]]) + + T2 = Matrix([[-S.One, S.One, S.Zero], [S.One, -S.One, S.Zero], [S.Zero, S.One, -S.One]]) + C2 = ContinuousMarkovChain('C', [0, 1, 2], T2) + A, t = C2.generator_matrix, symbols('t', positive=True) + assert C2.transition_probabilities(A)(t) == Matrix([[S.Half + exp(-2*t)/2, S.Half - exp(-2*t)/2, 0], + [S.Half - exp(-2*t)/2, S.Half + exp(-2*t)/2, 0], + [S.Half - exp(-t) + exp(-2*t)/2, S.Half - exp(-2*t)/2, exp(-t)]]) + with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed + assert P(Eq(C2(1), 1), Eq(C2(0), 1), evaluate=False) == Probability(Eq(C2(1), 1), Eq(C2(0), 1)) + assert P(Eq(C2(1), 1), Eq(C2(0), 1)) == exp(-2)/2 + S.Half + assert P(Eq(C2(1), 0) & Eq(C2(2), 1) & Eq(C2(3), 1), + Eq(P(Eq(C2(1), 0)), S.Half)) == (Rational(1, 4) - exp(-2)/4)*(exp(-2)/2 + S.Half) + assert P(Not(Eq(C2(1), 0) & Eq(C2(2), 1) & Eq(C2(3), 2)) | + (Eq(C2(1), 0) & Eq(C2(2), 1) & Eq(C2(3), 2)), + Eq(P(Eq(C2(1), 0)), Rational(1, 4)) & Eq(P(Eq(C2(1), 1)), Rational(1, 4))) is S.One + assert E(C2(Rational(3, 2)), Eq(C2(0), 2)) == -exp(-3)/2 + 2*exp(Rational(-3, 2)) + S.Half + assert variance(C2(Rational(3, 2)), Eq(C2(0), 1)) == ((S.Half - exp(-3)/2)**2*(exp(-3)/2 + S.Half) + + (Rational(-1, 2) - exp(-3)/2)**2*(S.Half - exp(-3)/2)) + raises(KeyError, lambda: P(Eq(C2(1), 0), Eq(P(Eq(C2(1), 1)), S.Half))) + assert P(Eq(C2(1), 0), Eq(P(Eq(C2(5), 1)), S.Half)) == Probability(Eq(C2(1), 0)) + TS1 = MatrixSymbol('G', 3, 3) + CS1 = ContinuousMarkovChain('C', [0, 1, 2], TS1) + A = CS1.generator_matrix + assert CS1.transition_probabilities(A)(t) == exp(t*A) + + C3 = ContinuousMarkovChain('C', [Symbol('0'), Symbol('1'), Symbol('2')], T2) + assert P(Eq(C3(1), 1), Eq(C3(0), 1)) == exp(-2)/2 + S.Half + assert P(Eq(C3(1), Symbol('1')), Eq(C3(0), Symbol('1'))) == exp(-2)/2 + S.Half + + #test probability queries + G = Matrix([[-S(1), Rational(1, 10), Rational(9, 10)], [Rational(2, 5), -S(1), Rational(3, 5)], [Rational(1, 2), Rational(1, 2), -S(1)]]) + C = ContinuousMarkovChain('C', state_space=[0, 1, 2], gen_mat=G) + assert P(Eq(C(7.385), C(3.19)), Eq(C(0.862), 0)).round(5) == Float(0.35469, 5) + assert P(Gt(C(98.715), C(19.807)), Eq(C(11.314), 2)).round(5) == Float(0.32452, 5) + assert P(Le(C(5.9), C(10.112)), Eq(C(4), 1)).round(6) == Float(0.675214, 6) + assert Float(P(Eq(C(7.32), C(2.91)), Eq(C(2.63), 1)), 14) == Float(1 - P(Ne(C(7.32), C(2.91)), Eq(C(2.63), 1)), 14) + assert Float(P(Gt(C(3.36), C(1.101)), Eq(C(0.8), 2)), 14) == Float(1 - P(Le(C(3.36), C(1.101)), Eq(C(0.8), 2)), 14) + assert Float(P(Lt(C(4.9), C(2.79)), Eq(C(1.61), 0)), 14) == Float(1 - P(Ge(C(4.9), C(2.79)), Eq(C(1.61), 0)), 14) + assert P(Eq(C(5.243), C(10.912)), Eq(C(2.174), 1)) == P(Eq(C(10.912), C(5.243)), Eq(C(2.174), 1)) + assert P(Gt(C(2.344), C(9.9)), Eq(C(1.102), 1)) == P(Lt(C(9.9), C(2.344)), Eq(C(1.102), 1)) + assert P(Ge(C(7.87), C(1.008)), Eq(C(0.153), 1)) == P(Le(C(1.008), C(7.87)), Eq(C(0.153), 1)) + + #test symbolic queries + a, b, c, d = symbols('a b c d') + query = P(Eq(C(a), b), Eq(C(c), d)) + assert query.subs({a:3.65, b:2, c:1.78, d:1}).evalf().round(10) == P(Eq(C(3.65), 2), Eq(C(1.78), 1)).round(10) + query_gt = P(Gt(C(a), b), Eq(C(c), d)) + query_le = P(Le(C(a), b), Eq(C(c), d)) + assert query_gt.subs({a:13.2, b:0, c:3.29, d:2}).evalf() + query_le.subs({a:13.2, b:0, c:3.29, d:2}).evalf() == 1.0 + query_ge = P(Ge(C(a), b), Eq(C(c), d)) + query_lt = P(Lt(C(a), b), Eq(C(c), d)) + assert query_ge.subs({a:7.43, b:1, c:1.45, d:0}).evalf() + query_lt.subs({a:7.43, b:1, c:1.45, d:0}).evalf() == 1.0 + + #test issue 20078 + assert (2*C(1) + 3*C(1)).simplify() == 5*C(1) + assert (2*C(1) - 3*C(1)).simplify() == -C(1) + assert (2*(0.25*C(1))).simplify() == 0.5*C(1) + assert (2*C(1) * 0.25*C(1)).simplify() == 0.5*C(1)**2 + assert (C(1)**2 + C(1)**3).simplify() == (C(1) + 1)*C(1)**2 + +def test_BernoulliProcess(): + + B = BernoulliProcess("B", p=0.6, success=1, failure=0) + assert B.state_space == FiniteSet(0, 1) + assert B.index_set == S.Naturals0 + assert B.success == 1 + assert B.failure == 0 + + X = BernoulliProcess("X", p=Rational(1,3), success='H', failure='T') + assert X.state_space == FiniteSet('H', 'T') + H, T = symbols("H,T") + assert E(X[1]+X[2]*X[3]) == H**2/9 + 4*H*T/9 + H/3 + 4*T**2/9 + 2*T/3 + + t, x = symbols('t, x', positive=True, integer=True) + assert isinstance(B[t], RandomIndexedSymbol) + + raises(ValueError, lambda: BernoulliProcess("X", p=1.1, success=1, failure=0)) + raises(NotImplementedError, lambda: B(t)) + + raises(IndexError, lambda: B[-3]) + assert B.joint_distribution(B[3], B[9]) == JointDistributionHandmade(Lambda((B[3], B[9]), + Piecewise((0.6, Eq(B[3], 1)), (0.4, Eq(B[3], 0)), (0, True)) + *Piecewise((0.6, Eq(B[9], 1)), (0.4, Eq(B[9], 0)), (0, True)))) + + assert B.joint_distribution(2, B[4]) == JointDistributionHandmade(Lambda((B[2], B[4]), + Piecewise((0.6, Eq(B[2], 1)), (0.4, Eq(B[2], 0)), (0, True)) + *Piecewise((0.6, Eq(B[4], 1)), (0.4, Eq(B[4], 0)), (0, True)))) + + # Test for the sum distribution of Bernoulli Process RVs + Y = B[1] + B[2] + B[3] + assert P(Eq(Y, 0)).round(2) == Float(0.06, 1) + assert P(Eq(Y, 2)).round(2) == Float(0.43, 2) + assert P(Eq(Y, 4)).round(2) == 0 + assert P(Gt(Y, 1)).round(2) == Float(0.65, 2) + # Test for independency of each Random Indexed variable + assert P(Eq(B[1], 0) & Eq(B[2], 1) & Eq(B[3], 0) & Eq(B[4], 1)).round(2) == Float(0.06, 1) + + assert E(2 * B[1] + B[2]).round(2) == Float(1.80, 3) + assert E(2 * B[1] + B[2] + 5).round(2) == Float(6.80, 3) + assert E(B[2] * B[4] + B[10]).round(2) == Float(0.96, 2) + assert E(B[2] > 0, Eq(B[1],1) & Eq(B[2],1)).round(2) == Float(0.60,2) + assert E(B[1]) == 0.6 + assert P(B[1] > 0).round(2) == Float(0.60, 2) + assert P(B[1] < 1).round(2) == Float(0.40, 2) + assert P(B[1] > 0, B[2] <= 1).round(2) == Float(0.60, 2) + assert P(B[12] * B[5] > 0).round(2) == Float(0.36, 2) + assert P(B[12] * B[5] > 0, B[4] < 1).round(2) == Float(0.36, 2) + assert P(Eq(B[2], 1), B[2] > 0) == 1.0 + assert P(Eq(B[5], 3)) == 0 + assert P(Eq(B[1], 1), B[1] < 0) == 0 + assert P(B[2] > 0, Eq(B[2], 1)) == 1 + assert P(B[2] < 0, Eq(B[2], 1)) == 0 + assert P(B[2] > 0, B[2]==7) == 0 + assert P(B[5] > 0, B[5]) == BernoulliDistribution(0.6, 0, 1) + raises(ValueError, lambda: P(3)) + raises(ValueError, lambda: P(B[3] > 0, 3)) + + # test issue 19456 + expr = Sum(B[t], (t, 0, 4)) + expr2 = Sum(B[t], (t, 1, 3)) + expr3 = Sum(B[t]**2, (t, 1, 3)) + assert expr.doit() == B[0] + B[1] + B[2] + B[3] + B[4] + assert expr2.doit() == Y + assert expr3.doit() == B[1]**2 + B[2]**2 + B[3]**2 + assert B[2*t].free_symbols == {B[2*t], t} + assert B[4].free_symbols == {B[4]} + assert B[x*t].free_symbols == {B[x*t], x, t} + + #test issue 20078 + assert (2*B[t] + 3*B[t]).simplify() == 5*B[t] + assert (2*B[t] - 3*B[t]).simplify() == -B[t] + assert (2*(0.25*B[t])).simplify() == 0.5*B[t] + assert (2*B[t] * 0.25*B[t]).simplify() == 0.5*B[t]**2 + assert (B[t]**2 + B[t]**3).simplify() == (B[t] + 1)*B[t]**2 + +def test_PoissonProcess(): + X = PoissonProcess("X", 3) + assert X.state_space == S.Naturals0 + assert X.index_set == Interval(0, oo) + assert X.lamda == 3 + + t, d, x, y = symbols('t d x y', positive=True) + assert isinstance(X(t), RandomIndexedSymbol) + assert X.distribution(t) == PoissonDistribution(3*t) + with warns_deprecated_sympy(): + X.distribution(X(t)) + raises(ValueError, lambda: PoissonProcess("X", -1)) + raises(NotImplementedError, lambda: X[t]) + raises(IndexError, lambda: X(-5)) + + assert X.joint_distribution(X(2), X(3)) == JointDistributionHandmade(Lambda((X(2), X(3)), + 6**X(2)*9**X(3)*exp(-15)/(factorial(X(2))*factorial(X(3))))) + + assert X.joint_distribution(4, 6) == JointDistributionHandmade(Lambda((X(4), X(6)), + 12**X(4)*18**X(6)*exp(-30)/(factorial(X(4))*factorial(X(6))))) + + assert P(X(t) < 1) == exp(-3*t) + assert P(Eq(X(t), 0), Contains(t, Interval.Lopen(3, 5))) == exp(-6) # exp(-2*lamda) + res = P(Eq(X(t), 1), Contains(t, Interval.Lopen(3, 4))) + assert res == 3*exp(-3) + + # Equivalent to P(Eq(X(t), 1))**4 because of non-overlapping intervals + assert P(Eq(X(t), 1) & Eq(X(d), 1) & Eq(X(x), 1) & Eq(X(y), 1), Contains(t, Interval.Lopen(0, 1)) + & Contains(d, Interval.Lopen(1, 2)) & Contains(x, Interval.Lopen(2, 3)) + & Contains(y, Interval.Lopen(3, 4))) == res**4 + + # Return Probability because of overlapping intervals + assert P(Eq(X(t), 2) & Eq(X(d), 3), Contains(t, Interval.Lopen(0, 2)) + & Contains(d, Interval.Ropen(2, 4))) == \ + Probability(Eq(X(d), 3) & Eq(X(t), 2), Contains(t, Interval.Lopen(0, 2)) + & Contains(d, Interval.Ropen(2, 4))) + + raises(ValueError, lambda: P(Eq(X(t), 2) & Eq(X(d), 3), + Contains(t, Interval.Lopen(0, 4)) & Contains(d, Interval.Lopen(3, oo)))) # no bound on d + assert P(Eq(X(3), 2)) == 81*exp(-9)/2 + assert P(Eq(X(t), 2), Contains(t, Interval.Lopen(0, 5))) == 225*exp(-15)/2 + + # Check that probability works correctly by adding it to 1 + res1 = P(X(t) <= 3, Contains(t, Interval.Lopen(0, 5))) + res2 = P(X(t) > 3, Contains(t, Interval.Lopen(0, 5))) + assert res1 == 691*exp(-15) + assert (res1 + res2).simplify() == 1 + + # Check Not and Or + assert P(Not(Eq(X(t), 2) & (X(d) > 3)), Contains(t, Interval.Ropen(2, 4)) & \ + Contains(d, Interval.Lopen(7, 8))).simplify() == -18*exp(-6) + 234*exp(-9) + 1 + assert P(Eq(X(t), 2) | Ne(X(t), 4), Contains(t, Interval.Ropen(2, 4))) == 1 - 36*exp(-6) + raises(ValueError, lambda: P(X(t) > 2, X(t) + X(d))) + assert E(X(t)) == 3*t # property of the distribution at a given timestamp + assert E(X(t)**2 + X(d)*2 + X(y)**3, Contains(t, Interval.Lopen(0, 1)) + & Contains(d, Interval.Lopen(1, 2)) & Contains(y, Interval.Ropen(3, 4))) == 75 + assert E(X(t)**2, Contains(t, Interval.Lopen(0, 1))) == 12 + assert E(x*(X(t) + X(d))*(X(t)**2+X(d)**2), Contains(t, Interval.Lopen(0, 1)) + & Contains(d, Interval.Ropen(1, 2))) == \ + Expectation(x*(X(d) + X(t))*(X(d)**2 + X(t)**2), Contains(t, Interval.Lopen(0, 1)) + & Contains(d, Interval.Ropen(1, 2))) + + # Value Error because of infinite time bound + raises(ValueError, lambda: E(X(t)**3, Contains(t, Interval.Lopen(1, oo)))) + + # Equivalent to E(X(t)**2) - E(X(d)**2) == E(X(1)**2) - E(X(1)**2) == 0 + assert E((X(t) + X(d))*(X(t) - X(d)), Contains(t, Interval.Lopen(0, 1)) + & Contains(d, Interval.Lopen(1, 2))) == 0 + assert E(X(2) + x*E(X(5))) == 15*x + 6 + assert E(x*X(1) + y) == 3*x + y + assert P(Eq(X(1), 2) & Eq(X(t), 3), Contains(t, Interval.Lopen(1, 2))) == 81*exp(-6)/4 + Y = PoissonProcess("Y", 6) + Z = X + Y + assert Z.lamda == X.lamda + Y.lamda == 9 + raises(ValueError, lambda: X + 5) # should be added be only PoissonProcess instance + N, M = Z.split(4, 5) + assert N.lamda == 4 + assert M.lamda == 5 + raises(ValueError, lambda: Z.split(3, 2)) # 2+3 != 9 + + raises(ValueError, lambda :P(Eq(X(t), 0), Contains(t, Interval.Lopen(1, 3)) & Eq(X(1), 0))) + # check if it handles queries with two random variables in one args + res1 = P(Eq(N(3), N(5))) + assert res1 == P(Eq(N(t), 0), Contains(t, Interval(3, 5))) + res2 = P(N(3) > N(1)) + assert res2 == P((N(t) > 0), Contains(t, Interval(1, 3))) + assert P(N(3) < N(1)) == 0 # condition is not possible + res3 = P(N(3) <= N(1)) # holds only for Eq(N(3), N(1)) + assert res3 == P(Eq(N(t), 0), Contains(t, Interval(1, 3))) + + # tests from https://www.probabilitycourse.com/chapter11/11_1_2_basic_concepts_of_the_poisson_process.php + X = PoissonProcess('X', 10) # 11.1 + assert P(Eq(X(S(1)/3), 3) & Eq(X(1), 10)) == exp(-10)*Rational(8000000000, 11160261) + assert P(Eq(X(1), 1), Eq(X(S(1)/3), 3)) == 0 + assert P(Eq(X(1), 10), Eq(X(S(1)/3), 3)) == P(Eq(X(S(2)/3), 7)) + + X = PoissonProcess('X', 2) # 11.2 + assert P(X(S(1)/2) < 1) == exp(-1) + assert P(X(3) < 1, Eq(X(1), 0)) == exp(-4) + assert P(Eq(X(4), 3), Eq(X(2), 3)) == exp(-4) + + X = PoissonProcess('X', 3) + assert P(Eq(X(2), 5) & Eq(X(1), 2)) == Rational(81, 4)*exp(-6) + + # check few properties + assert P(X(2) <= 3, X(1)>=1) == 3*P(Eq(X(1), 0)) + 2*P(Eq(X(1), 1)) + P(Eq(X(1), 2)) + assert P(X(2) <= 3, X(1) > 1) == 2*P(Eq(X(1), 0)) + 1*P(Eq(X(1), 1)) + assert P(Eq(X(2), 5) & Eq(X(1), 2)) == P(Eq(X(1), 3))*P(Eq(X(1), 2)) + assert P(Eq(X(3), 4), Eq(X(1), 3)) == P(Eq(X(2), 1)) + + #test issue 20078 + assert (2*X(t) + 3*X(t)).simplify() == 5*X(t) + assert (2*X(t) - 3*X(t)).simplify() == -X(t) + assert (2*(0.25*X(t))).simplify() == 0.5*X(t) + assert (2*X(t) * 0.25*X(t)).simplify() == 0.5*X(t)**2 + assert (X(t)**2 + X(t)**3).simplify() == (X(t) + 1)*X(t)**2 + +def test_WienerProcess(): + X = WienerProcess("X") + assert X.state_space == S.Reals + assert X.index_set == Interval(0, oo) + + t, d, x, y = symbols('t d x y', positive=True) + assert isinstance(X(t), RandomIndexedSymbol) + assert X.distribution(t) == NormalDistribution(0, sqrt(t)) + with warns_deprecated_sympy(): + X.distribution(X(t)) + raises(ValueError, lambda: PoissonProcess("X", -1)) + raises(NotImplementedError, lambda: X[t]) + raises(IndexError, lambda: X(-2)) + + assert X.joint_distribution(X(2), X(3)) == JointDistributionHandmade( + Lambda((X(2), X(3)), sqrt(6)*exp(-X(2)**2/4)*exp(-X(3)**2/6)/(12*pi))) + assert X.joint_distribution(4, 6) == JointDistributionHandmade( + Lambda((X(4), X(6)), sqrt(6)*exp(-X(4)**2/8)*exp(-X(6)**2/12)/(24*pi))) + + assert P(X(t) < 3).simplify() == erf(3*sqrt(2)/(2*sqrt(t)))/2 + S(1)/2 + assert P(X(t) > 2, Contains(t, Interval.Lopen(3, 7))).simplify() == S(1)/2 -\ + erf(sqrt(2)/2)/2 + + # Equivalent to P(X(1)>1)**4 + assert P((X(t) > 4) & (X(d) > 3) & (X(x) > 2) & (X(y) > 1), + Contains(t, Interval.Lopen(0, 1)) & Contains(d, Interval.Lopen(1, 2)) + & Contains(x, Interval.Lopen(2, 3)) & Contains(y, Interval.Lopen(3, 4))).simplify() ==\ + (1 - erf(sqrt(2)/2))*(1 - erf(sqrt(2)))*(1 - erf(3*sqrt(2)/2))*(1 - erf(2*sqrt(2)))/16 + + # Contains an overlapping interval so, return Probability + assert P((X(t)< 2) & (X(d)> 3), Contains(t, Interval.Lopen(0, 2)) + & Contains(d, Interval.Ropen(2, 4))) == Probability((X(d) > 3) & (X(t) < 2), + Contains(d, Interval.Ropen(2, 4)) & Contains(t, Interval.Lopen(0, 2))) + + assert str(P(Not((X(t) < 5) & (X(d) > 3)), Contains(t, Interval.Ropen(2, 4)) & + Contains(d, Interval.Lopen(7, 8))).simplify()) == \ + '-(1 - erf(3*sqrt(2)/2))*(2 - erfc(5/2))/4 + 1' + # Distribution has mean 0 at each timestamp + assert E(X(t)) == 0 + assert E(x*(X(t) + X(d))*(X(t)**2+X(d)**2), Contains(t, Interval.Lopen(0, 1)) + & Contains(d, Interval.Ropen(1, 2))) == Expectation(x*(X(d) + X(t))*(X(d)**2 + X(t)**2), + Contains(d, Interval.Ropen(1, 2)) & Contains(t, Interval.Lopen(0, 1))) + assert E(X(t) + x*E(X(3))) == 0 + + #test issue 20078 + assert (2*X(t) + 3*X(t)).simplify() == 5*X(t) + assert (2*X(t) - 3*X(t)).simplify() == -X(t) + assert (2*(0.25*X(t))).simplify() == 0.5*X(t) + assert (2*X(t) * 0.25*X(t)).simplify() == 0.5*X(t)**2 + assert (X(t)**2 + X(t)**3).simplify() == (X(t) + 1)*X(t)**2 + + +def test_GammaProcess_symbolic(): + t, d, x, y, g, l = symbols('t d x y g l', positive=True) + X = GammaProcess("X", l, g) + + raises(NotImplementedError, lambda: X[t]) + raises(IndexError, lambda: X(-1)) + assert isinstance(X(t), RandomIndexedSymbol) + assert X.state_space == Interval(0, oo) + assert X.distribution(t) == GammaDistribution(g*t, 1/l) + with warns_deprecated_sympy(): + X.distribution(X(t)) + assert X.joint_distribution(5, X(3)) == JointDistributionHandmade(Lambda( + (X(5), X(3)), l**(8*g)*exp(-l*X(3))*exp(-l*X(5))*X(3)**(3*g - 1)*X(5)**(5*g + - 1)/(gamma(3*g)*gamma(5*g)))) + # property of the gamma process at any given timestamp + assert E(X(t)) == g*t/l + assert variance(X(t)).simplify() == g*t/l**2 + + # Equivalent to E(2*X(1)) + E(X(1)**2) + E(X(1)**3), where E(X(1)) == g/l + assert E(X(t)**2 + X(d)*2 + X(y)**3, Contains(t, Interval.Lopen(0, 1)) + & Contains(d, Interval.Lopen(1, 2)) & Contains(y, Interval.Ropen(3, 4))) == \ + 2*g/l + (g**2 + g)/l**2 + (g**3 + 3*g**2 + 2*g)/l**3 + + assert P(X(t) > 3, Contains(t, Interval.Lopen(3, 4))).simplify() == \ + 1 - lowergamma(g, 3*l)/gamma(g) # equivalent to P(X(1)>3) + + + #test issue 20078 + assert (2*X(t) + 3*X(t)).simplify() == 5*X(t) + assert (2*X(t) - 3*X(t)).simplify() == -X(t) + assert (2*(0.25*X(t))).simplify() == 0.5*X(t) + assert (2*X(t) * 0.25*X(t)).simplify() == 0.5*X(t)**2 + assert (X(t)**2 + X(t)**3).simplify() == (X(t) + 1)*X(t)**2 +def test_GammaProcess_numeric(): + t, d, x, y = symbols('t d x y', positive=True) + X = GammaProcess("X", 1, 2) + assert X.state_space == Interval(0, oo) + assert X.index_set == Interval(0, oo) + assert X.lamda == 1 + assert X.gamma == 2 + + raises(ValueError, lambda: GammaProcess("X", -1, 2)) + raises(ValueError, lambda: GammaProcess("X", 0, -2)) + raises(ValueError, lambda: GammaProcess("X", -1, -2)) + + # all are independent because of non-overlapping intervals + assert P((X(t) > 4) & (X(d) > 3) & (X(x) > 2) & (X(y) > 1), Contains(t, + Interval.Lopen(0, 1)) & Contains(d, Interval.Lopen(1, 2)) & Contains(x, + Interval.Lopen(2, 3)) & Contains(y, Interval.Lopen(3, 4))).simplify() == \ + 120*exp(-10) + + # Check working with Not and Or + assert P(Not((X(t) < 5) & (X(d) > 3)), Contains(t, Interval.Ropen(2, 4)) & + Contains(d, Interval.Lopen(7, 8))).simplify() == -4*exp(-3) + 472*exp(-8)/3 + 1 + assert P((X(t) > 2) | (X(t) < 4), Contains(t, Interval.Ropen(1, 4))).simplify() == \ + -643*exp(-4)/15 + 109*exp(-2)/15 + 1 + + assert E(X(t)) == 2*t # E(X(t)) == gamma*t/l + assert E(X(2) + x*E(X(5))) == 10*x + 4 diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_symbolic_multivariate.py b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_symbolic_multivariate.py new file mode 100644 index 0000000000000000000000000000000000000000..79979e20a6f10d2a2ddfe85ce4c8df145e98c3fd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_symbolic_multivariate.py @@ -0,0 +1,172 @@ +from sympy.stats import Expectation, Normal, Variance, Covariance +from sympy.testing.pytest import raises +from sympy.core.symbol import symbols +from sympy.matrices.exceptions import ShapeError +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import ZeroMatrix +from sympy.stats.rv import RandomMatrixSymbol +from sympy.stats.symbolic_multivariate_probability import (ExpectationMatrix, + VarianceMatrix, CrossCovarianceMatrix) + +j, k = symbols("j,k") + +A = MatrixSymbol("A", k, k) +B = MatrixSymbol("B", k, k) +C = MatrixSymbol("C", k, k) +D = MatrixSymbol("D", k, k) + +a = MatrixSymbol("a", k, 1) +b = MatrixSymbol("b", k, 1) + +A2 = MatrixSymbol("A2", 2, 2) +B2 = MatrixSymbol("B2", 2, 2) + +X = RandomMatrixSymbol("X", k, 1) +Y = RandomMatrixSymbol("Y", k, 1) +Z = RandomMatrixSymbol("Z", k, 1) +W = RandomMatrixSymbol("W", k, 1) + +R = RandomMatrixSymbol("R", k, k) + +X2 = RandomMatrixSymbol("X2", 2, 1) + +normal = Normal("normal", 0, 1) + +m1 = Matrix([ + [1, j*Normal("normal2", 2, 1)], + [normal, 0] +]) + +def test_multivariate_expectation(): + expr = Expectation(a) + assert expr == Expectation(a) == ExpectationMatrix(a) + assert expr.expand() == a + + expr = Expectation(X) + assert expr == Expectation(X) == ExpectationMatrix(X) + assert expr.shape == (k, 1) + assert expr.rows == k + assert expr.cols == 1 + assert isinstance(expr, ExpectationMatrix) + + expr = Expectation(A*X + b) + assert expr == ExpectationMatrix(A*X + b) + assert expr.expand() == A*ExpectationMatrix(X) + b + assert isinstance(expr, ExpectationMatrix) + assert expr.shape == (k, 1) + + expr = Expectation(m1*X2) + assert expr.expand() == expr + + expr = Expectation(A2*m1*B2*X2) + assert expr.args[0].args == (A2, m1, B2, X2) + assert expr.expand() == A2*ExpectationMatrix(m1*B2*X2) + + expr = Expectation((X + Y)*(X - Y).T) + assert expr.expand() == ExpectationMatrix(X*X.T) - ExpectationMatrix(X*Y.T) +\ + ExpectationMatrix(Y*X.T) - ExpectationMatrix(Y*Y.T) + + expr = Expectation(A*X + B*Y) + assert expr.expand() == A*ExpectationMatrix(X) + B*ExpectationMatrix(Y) + + assert Expectation(m1).doit() == Matrix([[1, 2*j], [0, 0]]) + + x1 = Matrix([ + [Normal('N11', 11, 1), Normal('N12', 12, 1)], + [Normal('N21', 21, 1), Normal('N22', 22, 1)] + ]) + x2 = Matrix([ + [Normal('M11', 1, 1), Normal('M12', 2, 1)], + [Normal('M21', 3, 1), Normal('M22', 4, 1)] + ]) + + assert Expectation(Expectation(x1 + x2)).doit(deep=False) == ExpectationMatrix(x1 + x2) + assert Expectation(Expectation(x1 + x2)).doit() == Matrix([[12, 14], [24, 26]]) + + +def test_multivariate_variance(): + raises(ShapeError, lambda: Variance(A)) + + expr = Variance(a) + assert expr == Variance(a) == VarianceMatrix(a) + assert expr.expand() == ZeroMatrix(k, k) + expr = Variance(a.T) + assert expr == Variance(a.T) == VarianceMatrix(a.T) + assert expr.expand() == ZeroMatrix(k, k) + + expr = Variance(X) + assert expr == Variance(X) == VarianceMatrix(X) + assert expr.shape == (k, k) + assert expr.rows == k + assert expr.cols == k + assert isinstance(expr, VarianceMatrix) + + expr = Variance(A*X) + assert expr == VarianceMatrix(A*X) + assert expr.expand() == A*VarianceMatrix(X)*A.T + assert isinstance(expr, VarianceMatrix) + assert expr.shape == (k, k) + + expr = Variance(A*B*X) + assert expr.expand() == A*B*VarianceMatrix(X)*B.T*A.T + + expr = Variance(m1*X2) + assert expr.expand() == expr + + expr = Variance(A2*m1*B2*X2) + assert expr.args[0].args == (A2, m1, B2, X2) + assert expr.expand() == expr + + expr = Variance(A*X + B*Y) + assert expr.expand() == 2*A*CrossCovarianceMatrix(X, Y)*B.T +\ + A*VarianceMatrix(X)*A.T + B*VarianceMatrix(Y)*B.T + +def test_multivariate_crosscovariance(): + raises(ShapeError, lambda: Covariance(X, Y.T)) + raises(ShapeError, lambda: Covariance(X, A)) + + + expr = Covariance(a.T, b.T) + assert expr.shape == (1, 1) + assert expr.expand() == ZeroMatrix(1, 1) + + expr = Covariance(a, b) + assert expr == Covariance(a, b) == CrossCovarianceMatrix(a, b) + assert expr.expand() == ZeroMatrix(k, k) + assert expr.shape == (k, k) + assert expr.rows == k + assert expr.cols == k + assert isinstance(expr, CrossCovarianceMatrix) + + expr = Covariance(A*X + a, b) + assert expr.expand() == ZeroMatrix(k, k) + + expr = Covariance(X, Y) + assert isinstance(expr, CrossCovarianceMatrix) + assert expr.expand() == expr + + expr = Covariance(X, X) + assert isinstance(expr, CrossCovarianceMatrix) + assert expr.expand() == VarianceMatrix(X) + + expr = Covariance(X + Y, Z) + assert isinstance(expr, CrossCovarianceMatrix) + assert expr.expand() == CrossCovarianceMatrix(X, Z) + CrossCovarianceMatrix(Y, Z) + + expr = Covariance(A*X, Y) + assert isinstance(expr, CrossCovarianceMatrix) + assert expr.expand() == A*CrossCovarianceMatrix(X, Y) + + expr = Covariance(X, B*Y) + assert isinstance(expr, CrossCovarianceMatrix) + assert expr.expand() == CrossCovarianceMatrix(X, Y)*B.T + + expr = Covariance(A*X + a, B.T*Y + b) + assert isinstance(expr, CrossCovarianceMatrix) + assert expr.expand() == A*CrossCovarianceMatrix(X, Y)*B + + expr = Covariance(A*X + B*Y + a, C.T*Z + D.T*W + b) + assert isinstance(expr, CrossCovarianceMatrix) + assert expr.expand() == A*CrossCovarianceMatrix(X, W)*D + A*CrossCovarianceMatrix(X, Z)*C \ + + B*CrossCovarianceMatrix(Y, W)*D + B*CrossCovarianceMatrix(Y, Z)*C diff --git a/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_symbolic_probability.py b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_symbolic_probability.py new file mode 100644 index 0000000000000000000000000000000000000000..edac942ac081c0d44cafd31761b77bc577b6a3fd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/stats/tests/test_symbolic_probability.py @@ -0,0 +1,175 @@ +from sympy.concrete.summations import Sum +from sympy.core.mul import Mul +from sympy.core.numbers import (oo, pi) +from sympy.core.relational import Eq +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.integrals.integrals import Integral +from sympy.core.expr import unchanged +from sympy.stats import (Normal, Poisson, variance, Covariance, Variance, + Probability, Expectation, Moment, CentralMoment) +from sympy.stats.rv import probability, expectation + + +def test_literal_probability(): + X = Normal('X', 2, 3) + Y = Normal('Y', 3, 4) + Z = Poisson('Z', 4) + W = Poisson('W', 3) + x = symbols('x', real=True) + y, w, z = symbols('y, w, z') + + assert Probability(X > 0).evaluate_integral() == probability(X > 0) + assert Probability(X > x).evaluate_integral() == probability(X > x) + assert Probability(X > 0).rewrite(Integral).doit() == probability(X > 0) + assert Probability(X > x).rewrite(Integral).doit() == probability(X > x) + + assert Expectation(X).evaluate_integral() == expectation(X) + assert Expectation(X).rewrite(Integral).doit() == expectation(X) + assert Expectation(X**2).evaluate_integral() == expectation(X**2) + assert Expectation(x*X).args == (x*X,) + assert Expectation(x*X).expand() == x*Expectation(X) + assert Expectation(2*X + 3*Y + z*X*Y).expand() == 2*Expectation(X) + 3*Expectation(Y) + z*Expectation(X*Y) + assert Expectation(2*X + 3*Y + z*X*Y).args == (2*X + 3*Y + z*X*Y,) + assert Expectation(sin(X)) == Expectation(sin(X)).expand() + assert Expectation(2*x*sin(X)*Y + y*X**2 + z*X*Y).expand() == 2*x*Expectation(sin(X)*Y) \ + + y*Expectation(X**2) + z*Expectation(X*Y) + assert Expectation(X + Y).expand() == Expectation(X) + Expectation(Y) + assert Expectation((X + Y)*(X - Y)).expand() == Expectation(X**2) - Expectation(Y**2) + assert Expectation((X + Y)*(X - Y)).expand().doit() == -12 + assert Expectation(X + Y, evaluate=True).doit() == 5 + assert Expectation(X + Expectation(Y)).doit() == 5 + assert Expectation(X + Expectation(Y)).doit(deep=False) == 2 + Expectation(Expectation(Y)) + assert Expectation(X + Expectation(Y + Expectation(2*X))).doit(deep=False) == 2 \ + + Expectation(Expectation(Y + Expectation(2*X))) + assert Expectation(X + Expectation(Y + Expectation(2*X))).doit() == 9 + assert Expectation(Expectation(2*X)).doit() == 4 + assert Expectation(Expectation(2*X)).doit(deep=False) == Expectation(2*X) + assert Expectation(4*Expectation(2*X)).doit(deep=False) == 4*Expectation(2*X) + assert Expectation((X + Y)**3).expand() == 3*Expectation(X*Y**2) +\ + 3*Expectation(X**2*Y) + Expectation(X**3) + Expectation(Y**3) + assert Expectation((X - Y)**3).expand() == 3*Expectation(X*Y**2) -\ + 3*Expectation(X**2*Y) + Expectation(X**3) - Expectation(Y**3) + assert Expectation((X - Y)**2).expand() == -2*Expectation(X*Y) +\ + Expectation(X**2) + Expectation(Y**2) + + assert Variance(w).args == (w,) + assert Variance(w).expand() == 0 + assert Variance(X).evaluate_integral() == Variance(X).rewrite(Integral).doit() == variance(X) + assert Variance(X + z).args == (X + z,) + assert Variance(X + z).expand() == Variance(X) + assert Variance(X*Y).args == (Mul(X, Y),) + assert type(Variance(X*Y)) == Variance + assert Variance(z*X).expand() == z**2*Variance(X) + assert Variance(X + Y).expand() == Variance(X) + Variance(Y) + 2*Covariance(X, Y) + assert Variance(X + Y + Z + W).expand() == (Variance(X) + Variance(Y) + Variance(Z) + Variance(W) + + 2 * Covariance(X, Y) + 2 * Covariance(X, Z) + 2 * Covariance(X, W) + + 2 * Covariance(Y, Z) + 2 * Covariance(Y, W) + 2 * Covariance(W, Z)) + assert Variance(X**2).evaluate_integral() == variance(X**2) + assert unchanged(Variance, X**2) + assert Variance(x*X**2).expand() == x**2*Variance(X**2) + assert Variance(sin(X)).args == (sin(X),) + assert Variance(sin(X)).expand() == Variance(sin(X)) + assert Variance(x*sin(X)).expand() == x**2*Variance(sin(X)) + + assert Covariance(w, z).args == (w, z) + assert Covariance(w, z).expand() == 0 + assert Covariance(X, w).expand() == 0 + assert Covariance(w, X).expand() == 0 + assert Covariance(X, Y).args == (X, Y) + assert type(Covariance(X, Y)) == Covariance + assert Covariance(z*X + 3, Y).expand() == z*Covariance(X, Y) + assert Covariance(X, X).args == (X, X) + assert Covariance(X, X).expand() == Variance(X) + assert Covariance(z*X + 3, w*Y + 4).expand() == w*z*Covariance(X,Y) + assert Covariance(X, Y) == Covariance(Y, X) + assert Covariance(X + Y, Z + W).expand() == Covariance(W, X) + Covariance(W, Y) + Covariance(X, Z) + Covariance(Y, Z) + assert Covariance(x*X + y*Y, z*Z + w*W).expand() == (x*w*Covariance(W, X) + w*y*Covariance(W, Y) + + x*z*Covariance(X, Z) + y*z*Covariance(Y, Z)) + assert Covariance(x*X**2 + y*sin(Y), z*Y*Z**2 + w*W).expand() == (w*x*Covariance(W, X**2) + w*y*Covariance(sin(Y), W) + + x*z*Covariance(Y*Z**2, X**2) + y*z*Covariance(Y*Z**2, sin(Y))) + assert Covariance(X, X**2).expand() == Covariance(X, X**2) + assert Covariance(X, sin(X)).expand() == Covariance(sin(X), X) + assert Covariance(X**2, sin(X)*Y).expand() == Covariance(sin(X)*Y, X**2) + assert Covariance(w, X).evaluate_integral() == 0 + + +def test_probability_rewrite(): + X = Normal('X', 2, 3) + Y = Normal('Y', 3, 4) + Z = Poisson('Z', 4) + W = Poisson('W', 3) + x, y, w, z = symbols('x, y, w, z') + + assert Variance(w).rewrite(Expectation) == 0 + assert Variance(X).rewrite(Expectation) == Expectation(X ** 2) - Expectation(X) ** 2 + assert Variance(X, condition=Y).rewrite(Expectation) == Expectation(X ** 2, Y) - Expectation(X, Y) ** 2 + assert Variance(X, Y) != Expectation(X**2) - Expectation(X)**2 + assert Variance(X + z).rewrite(Expectation) == Expectation((X + z) ** 2) - Expectation(X + z) ** 2 + assert Variance(X * Y).rewrite(Expectation) == Expectation(X ** 2 * Y ** 2) - Expectation(X * Y) ** 2 + + assert Covariance(w, X).rewrite(Expectation) == -w*Expectation(X) + Expectation(w*X) + assert Covariance(X, Y).rewrite(Expectation) == Expectation(X*Y) - Expectation(X)*Expectation(Y) + assert Covariance(X, Y, condition=W).rewrite(Expectation) == Expectation(X * Y, W) - Expectation(X, W) * Expectation(Y, W) + + w, x, z = symbols("W, x, z") + px = Probability(Eq(X, x)) + pz = Probability(Eq(Z, z)) + + assert Expectation(X).rewrite(Probability) == Integral(x*px, (x, -oo, oo)) + assert Expectation(Z).rewrite(Probability) == Sum(z*pz, (z, 0, oo)) + assert Variance(X).rewrite(Probability) == Integral(x**2*px, (x, -oo, oo)) - Integral(x*px, (x, -oo, oo))**2 + assert Variance(Z).rewrite(Probability) == Sum(z**2*pz, (z, 0, oo)) - Sum(z*pz, (z, 0, oo))**2 + assert Covariance(w, X).rewrite(Probability) == \ + -w*Integral(x*Probability(Eq(X, x)), (x, -oo, oo)) + Integral(w*x*Probability(Eq(X, x)), (x, -oo, oo)) + + # To test rewrite as sum function + assert Variance(X).rewrite(Sum) == Variance(X).rewrite(Integral) + assert Expectation(X).rewrite(Sum) == Expectation(X).rewrite(Integral) + + assert Covariance(w, X).rewrite(Sum) == 0 + + assert Covariance(w, X).rewrite(Integral) == 0 + + assert Variance(X, condition=Y).rewrite(Probability) == Integral(x**2*Probability(Eq(X, x), Y), (x, -oo, oo)) - \ + Integral(x*Probability(Eq(X, x), Y), (x, -oo, oo))**2 + + +def test_symbolic_Moment(): + mu = symbols('mu', real=True) + sigma = symbols('sigma', positive=True) + x = symbols('x') + X = Normal('X', mu, sigma) + M = Moment(X, 4, 2) + assert M.rewrite(Expectation) == Expectation((X - 2)**4) + assert M.rewrite(Probability) == Integral((x - 2)**4*Probability(Eq(X, x)), + (x, -oo, oo)) + k = Dummy('k') + expri = Integral(sqrt(2)*(k - 2)**4*exp(-(k - \ + mu)**2/(2*sigma**2))/(2*sqrt(pi)*sigma), (k, -oo, oo)) + assert M.rewrite(Integral).dummy_eq(expri) + assert M.doit() == (mu**4 - 8*mu**3 + 6*mu**2*sigma**2 + \ + 24*mu**2 - 24*mu*sigma**2 - 32*mu + 3*sigma**4 + 24*sigma**2 + 16) + M = Moment(2, 5) + assert M.doit() == 2**5 + + +def test_symbolic_CentralMoment(): + mu = symbols('mu', real=True) + sigma = symbols('sigma', positive=True) + x = symbols('x') + X = Normal('X', mu, sigma) + CM = CentralMoment(X, 6) + assert CM.rewrite(Expectation) == Expectation((X - Expectation(X))**6) + assert CM.rewrite(Probability) == Integral((x - Integral(x*Probability(True), + (x, -oo, oo)))**6*Probability(Eq(X, x)), (x, -oo, oo)) + k = Dummy('k') + expri = Integral(sqrt(2)*(k - Integral(sqrt(2)*k*exp(-(k - \ + mu)**2/(2*sigma**2))/(2*sqrt(pi)*sigma), (k, -oo, oo)))**6*exp(-(k - \ + mu)**2/(2*sigma**2))/(2*sqrt(pi)*sigma), (k, -oo, oo)) + assert CM.rewrite(Integral).dummy_eq(expri) + assert CM.doit().simplify() == 15*sigma**6 + CM = Moment(5, 5) + assert CM.doit() == 5**5 diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/__init__.py b/.venv/lib/python3.13/site-packages/sympy/tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a832614b1d48e26bf01e16f040f34dd412e8e32b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/__init__.py @@ -0,0 +1,23 @@ +"""A module to manipulate symbolic objects with indices including tensors + +""" +from .indexed import IndexedBase, Idx, Indexed +from .index_methods import get_contraction_structure, get_indices +from .functions import shape +from .array import (MutableDenseNDimArray, ImmutableDenseNDimArray, + MutableSparseNDimArray, ImmutableSparseNDimArray, NDimArray, tensorproduct, + tensorcontraction, tensordiagonal, derive_by_array, permutedims, Array, + DenseNDimArray, SparseNDimArray,) + +__all__ = [ + 'IndexedBase', 'Idx', 'Indexed', + + 'get_contraction_structure', 'get_indices', + + 'shape', + + 'MutableDenseNDimArray', 'ImmutableDenseNDimArray', + 'MutableSparseNDimArray', 'ImmutableSparseNDimArray', 'NDimArray', + 'tensorproduct', 'tensorcontraction', 'tensordiagonal', 'derive_by_array', 'permutedims', + 'Array', 'DenseNDimArray', 'SparseNDimArray', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/functions.py b/.venv/lib/python3.13/site-packages/sympy/tensor/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..f14599d69152db1713f21c9dd785683901c5eeb9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/functions.py @@ -0,0 +1,154 @@ +from collections.abc import Iterable +from functools import singledispatch + +from sympy.core.expr import Expr +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.core.sympify import sympify +from sympy.core.parameters import global_parameters + + +class TensorProduct(Expr): + """ + Generic class for tensor products. + """ + is_number = False + + def __new__(cls, *args, **kwargs): + from sympy.tensor.array import NDimArray, tensorproduct, Array + from sympy.matrices.expressions.matexpr import MatrixExpr + from sympy.matrices.matrixbase import MatrixBase + from sympy.strategies import flatten + + args = [sympify(arg) for arg in args] + evaluate = kwargs.get("evaluate", global_parameters.evaluate) + + if not evaluate: + obj = Expr.__new__(cls, *args) + return obj + + arrays = [] + other = [] + scalar = S.One + for arg in args: + if isinstance(arg, (Iterable, MatrixBase, NDimArray)): + arrays.append(Array(arg)) + elif isinstance(arg, (MatrixExpr,)): + other.append(arg) + else: + scalar *= arg + + coeff = scalar*tensorproduct(*arrays) + if len(other) == 0: + return coeff + if coeff != 1: + newargs = [coeff] + other + else: + newargs = other + obj = Expr.__new__(cls, *newargs, **kwargs) + return flatten(obj) + + def rank(self): + return len(self.shape) + + def _get_args_shapes(self): + from sympy.tensor.array import Array + return [i.shape if hasattr(i, "shape") else Array(i).shape for i in self.args] + + @property + def shape(self): + shape_list = self._get_args_shapes() + return sum(shape_list, ()) + + def __getitem__(self, index): + index = iter(index) + return Mul.fromiter( + arg.__getitem__(tuple(next(index) for i in shp)) + for arg, shp in zip(self.args, self._get_args_shapes()) + ) + + +@singledispatch +def shape(expr): + """ + Return the shape of the *expr* as a tuple. *expr* should represent + suitable object such as matrix or array. + + Parameters + ========== + + expr : SymPy object having ``MatrixKind`` or ``ArrayKind``. + + Raises + ====== + + NoShapeError : Raised when object with wrong kind is passed. + + Examples + ======== + + This function returns the shape of any object representing matrix or array. + + >>> from sympy import shape, Array, ImmutableDenseMatrix, Integral + >>> from sympy.abc import x + >>> A = Array([1, 2]) + >>> shape(A) + (2,) + >>> shape(Integral(A, x)) + (2,) + >>> M = ImmutableDenseMatrix([1, 2]) + >>> shape(M) + (2, 1) + >>> shape(Integral(M, x)) + (2, 1) + + You can support new type by dispatching. + + >>> from sympy import Expr + >>> class NewExpr(Expr): + ... pass + >>> @shape.register(NewExpr) + ... def _(expr): + ... return shape(expr.args[0]) + >>> shape(NewExpr(M)) + (2, 1) + + If unsuitable expression is passed, ``NoShapeError()`` will be raised. + + >>> shape(Integral(x, x)) + Traceback (most recent call last): + ... + sympy.tensor.functions.NoShapeError: shape() called on non-array object: Integral(x, x) + + Notes + ===== + + Array-like classes (such as ``Matrix`` or ``NDimArray``) has ``shape`` + property which returns its shape, but it cannot be used for non-array + classes containing array. This function returns the shape of any + registered object representing array. + + """ + if hasattr(expr, "shape"): + return expr.shape + raise NoShapeError( + "%s does not have shape, or its type is not registered to shape()." % expr) + + +class NoShapeError(Exception): + """ + Raised when ``shape()`` is called on non-array object. + + This error can be imported from ``sympy.tensor.functions``. + + Examples + ======== + + >>> from sympy import shape + >>> from sympy.abc import x + >>> shape(x) + Traceback (most recent call last): + ... + sympy.tensor.functions.NoShapeError: shape() called on non-array object: x + """ + pass diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/index_methods.py b/.venv/lib/python3.13/site-packages/sympy/tensor/index_methods.py new file mode 100644 index 0000000000000000000000000000000000000000..12f707b60b4ad0bcadc35a222d9abe0cc5e033fc --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/index_methods.py @@ -0,0 +1,469 @@ +"""Module with functions operating on IndexedBase, Indexed and Idx objects + + - Check shape conformance + - Determine indices in resulting expression + + etc. + + Methods in this module could be implemented by calling methods on Expr + objects instead. When things stabilize this could be a useful + refactoring. +""" + +from functools import reduce + +from sympy.core.function import Function +from sympy.functions import exp, Piecewise +from sympy.tensor.indexed import Idx, Indexed +from sympy.utilities import sift + +from collections import OrderedDict + +class IndexConformanceException(Exception): + pass + +def _unique_and_repeated(inds): + """ + Returns the unique and repeated indices. Also note, from the examples given below + that the order of indices is maintained as given in the input. + + Examples + ======== + + >>> from sympy.tensor.index_methods import _unique_and_repeated + >>> _unique_and_repeated([2, 3, 1, 3, 0, 4, 0]) + ([2, 1, 4], [3, 0]) + """ + uniq = OrderedDict() + for i in inds: + if i in uniq: + uniq[i] = 0 + else: + uniq[i] = 1 + return sift(uniq, lambda x: uniq[x], binary=True) + +def _remove_repeated(inds): + """ + Removes repeated objects from sequences + + Returns a set of the unique objects and a tuple of all that have been + removed. + + Examples + ======== + + >>> from sympy.tensor.index_methods import _remove_repeated + >>> l1 = [1, 2, 3, 2] + >>> _remove_repeated(l1) + ({1, 3}, (2,)) + + """ + u, r = _unique_and_repeated(inds) + return set(u), tuple(r) + + +def _get_indices_Mul(expr, return_dummies=False): + """Determine the outer indices of a Mul object. + + Examples + ======== + + >>> from sympy.tensor.index_methods import _get_indices_Mul + >>> from sympy.tensor.indexed import IndexedBase, Idx + >>> i, j, k = map(Idx, ['i', 'j', 'k']) + >>> x = IndexedBase('x') + >>> y = IndexedBase('y') + >>> _get_indices_Mul(x[i, k]*y[j, k]) + ({i, j}, {}) + >>> _get_indices_Mul(x[i, k]*y[j, k], return_dummies=True) + ({i, j}, {}, (k,)) + + """ + + inds = list(map(get_indices, expr.args)) + inds, syms = list(zip(*inds)) + + inds = list(map(list, inds)) + inds = list(reduce(lambda x, y: x + y, inds)) + inds, dummies = _remove_repeated(inds) + + symmetry = {} + for s in syms: + for pair in s: + if pair in symmetry: + symmetry[pair] *= s[pair] + else: + symmetry[pair] = s[pair] + + if return_dummies: + return inds, symmetry, dummies + else: + return inds, symmetry + + +def _get_indices_Pow(expr): + """Determine outer indices of a power or an exponential. + + A power is considered a universal function, so that the indices of a Pow is + just the collection of indices present in the expression. This may be + viewed as a bit inconsistent in the special case: + + x[i]**2 = x[i]*x[i] (1) + + The above expression could have been interpreted as the contraction of x[i] + with itself, but we choose instead to interpret it as a function + + lambda y: y**2 + + applied to each element of x (a universal function in numpy terms). In + order to allow an interpretation of (1) as a contraction, we need + contravariant and covariant Idx subclasses. (FIXME: this is not yet + implemented) + + Expressions in the base or exponent are subject to contraction as usual, + but an index that is present in the exponent, will not be considered + contractable with its own base. Note however, that indices in the same + exponent can be contracted with each other. + + Examples + ======== + + >>> from sympy.tensor.index_methods import _get_indices_Pow + >>> from sympy import Pow, exp, IndexedBase, Idx + >>> A = IndexedBase('A') + >>> x = IndexedBase('x') + >>> i, j, k = map(Idx, ['i', 'j', 'k']) + >>> _get_indices_Pow(exp(A[i, j]*x[j])) + ({i}, {}) + >>> _get_indices_Pow(Pow(x[i], x[i])) + ({i}, {}) + >>> _get_indices_Pow(Pow(A[i, j]*x[j], x[i])) + ({i}, {}) + + """ + base, exp = expr.as_base_exp() + binds, bsyms = get_indices(base) + einds, esyms = get_indices(exp) + + inds = binds | einds + + # FIXME: symmetries from power needs to check special cases, else nothing + symmetries = {} + + return inds, symmetries + + +def _get_indices_Add(expr): + """Determine outer indices of an Add object. + + In a sum, each term must have the same set of outer indices. A valid + expression could be + + x(i)*y(j) - x(j)*y(i) + + But we do not allow expressions like: + + x(i)*y(j) - z(j)*z(j) + + FIXME: Add support for Numpy broadcasting + + Examples + ======== + + >>> from sympy.tensor.index_methods import _get_indices_Add + >>> from sympy.tensor.indexed import IndexedBase, Idx + >>> i, j, k = map(Idx, ['i', 'j', 'k']) + >>> x = IndexedBase('x') + >>> y = IndexedBase('y') + >>> _get_indices_Add(x[i] + x[k]*y[i, k]) + ({i}, {}) + + """ + + inds = list(map(get_indices, expr.args)) + inds, syms = list(zip(*inds)) + + # allow broadcast of scalars + non_scalars = [x for x in inds if x != set()] + if not non_scalars: + return set(), {} + + if not all(x == non_scalars[0] for x in non_scalars[1:]): + raise IndexConformanceException("Indices are not consistent: %s" % expr) + if not reduce(lambda x, y: x != y or y, syms): + symmetries = syms[0] + else: + # FIXME: search for symmetries + symmetries = {} + + return non_scalars[0], symmetries + + +def get_indices(expr): + """Determine the outer indices of expression ``expr`` + + By *outer* we mean indices that are not summation indices. Returns a set + and a dict. The set contains outer indices and the dict contains + information about index symmetries. + + Examples + ======== + + >>> from sympy.tensor.index_methods import get_indices + >>> from sympy import symbols + >>> from sympy.tensor import IndexedBase + >>> x, y, A = map(IndexedBase, ['x', 'y', 'A']) + >>> i, j, a, z = symbols('i j a z', integer=True) + + The indices of the total expression is determined, Repeated indices imply a + summation, for instance the trace of a matrix A: + + >>> get_indices(A[i, i]) + (set(), {}) + + In the case of many terms, the terms are required to have identical + outer indices. Else an IndexConformanceException is raised. + + >>> get_indices(x[i] + A[i, j]*y[j]) + ({i}, {}) + + :Exceptions: + + An IndexConformanceException means that the terms ar not compatible, e.g. + + >>> get_indices(x[i] + y[j]) #doctest: +SKIP + (...) + IndexConformanceException: Indices are not consistent: x(i) + y(j) + + .. warning:: + The concept of *outer* indices applies recursively, starting on the deepest + level. This implies that dummies inside parenthesis are assumed to be + summed first, so that the following expression is handled gracefully: + + >>> get_indices((x[i] + A[i, j]*y[j])*x[j]) + ({i, j}, {}) + + This is correct and may appear convenient, but you need to be careful + with this as SymPy will happily .expand() the product, if requested. The + resulting expression would mix the outer ``j`` with the dummies inside + the parenthesis, which makes it a different expression. To be on the + safe side, it is best to avoid such ambiguities by using unique indices + for all contractions that should be held separate. + + """ + # We call ourself recursively to determine indices of sub expressions. + + # break recursion + if isinstance(expr, Indexed): + c = expr.indices + inds, dummies = _remove_repeated(c) + return inds, {} + elif expr is None: + return set(), {} + elif isinstance(expr, Idx): + return {expr}, {} + elif expr.is_Atom: + return set(), {} + + + # recurse via specialized functions + else: + if expr.is_Mul: + return _get_indices_Mul(expr) + elif expr.is_Add: + return _get_indices_Add(expr) + elif expr.is_Pow or isinstance(expr, exp): + return _get_indices_Pow(expr) + + elif isinstance(expr, Piecewise): + # FIXME: No support for Piecewise yet + return set(), {} + elif isinstance(expr, Function): + # Support ufunc like behaviour by returning indices from arguments. + # Functions do not interpret repeated indices across arguments + # as summation + ind0 = set() + for arg in expr.args: + ind, sym = get_indices(arg) + ind0 |= ind + return ind0, sym + + # this test is expensive, so it should be at the end + elif not expr.has(Indexed): + return set(), {} + raise NotImplementedError( + "FIXME: No specialized handling of type %s" % type(expr)) + + +def get_contraction_structure(expr): + """Determine dummy indices of ``expr`` and describe its structure + + By *dummy* we mean indices that are summation indices. + + The structure of the expression is determined and described as follows: + + 1) A conforming summation of Indexed objects is described with a dict where + the keys are summation indices and the corresponding values are sets + containing all terms for which the summation applies. All Add objects + in the SymPy expression tree are described like this. + + 2) For all nodes in the SymPy expression tree that are *not* of type Add, the + following applies: + + If a node discovers contractions in one of its arguments, the node + itself will be stored as a key in the dict. For that key, the + corresponding value is a list of dicts, each of which is the result of a + recursive call to get_contraction_structure(). The list contains only + dicts for the non-trivial deeper contractions, omitting dicts with None + as the one and only key. + + .. Note:: The presence of expressions among the dictionary keys indicates + multiple levels of index contractions. A nested dict displays nested + contractions and may itself contain dicts from a deeper level. In + practical calculations the summation in the deepest nested level must be + calculated first so that the outer expression can access the resulting + indexed object. + + Examples + ======== + + >>> from sympy.tensor.index_methods import get_contraction_structure + >>> from sympy import default_sort_key + >>> from sympy.tensor import IndexedBase, Idx + >>> x, y, A = map(IndexedBase, ['x', 'y', 'A']) + >>> i, j, k, l = map(Idx, ['i', 'j', 'k', 'l']) + >>> get_contraction_structure(x[i]*y[i] + A[j, j]) + {(i,): {x[i]*y[i]}, (j,): {A[j, j]}} + >>> get_contraction_structure(x[i]*y[j]) + {None: {x[i]*y[j]}} + + A multiplication of contracted factors results in nested dicts representing + the internal contractions. + + >>> d = get_contraction_structure(x[i, i]*y[j, j]) + >>> sorted(d.keys(), key=default_sort_key) + [None, x[i, i]*y[j, j]] + + In this case, the product has no contractions: + + >>> d[None] + {x[i, i]*y[j, j]} + + Factors are contracted "first": + + >>> sorted(d[x[i, i]*y[j, j]], key=default_sort_key) + [{(i,): {x[i, i]}}, {(j,): {y[j, j]}}] + + A parenthesized Add object is also returned as a nested dictionary. The + term containing the parenthesis is a Mul with a contraction among the + arguments, so it will be found as a key in the result. It stores the + dictionary resulting from a recursive call on the Add expression. + + >>> d = get_contraction_structure(x[i]*(y[i] + A[i, j]*x[j])) + >>> sorted(d.keys(), key=default_sort_key) + [(A[i, j]*x[j] + y[i])*x[i], (i,)] + >>> d[(i,)] + {(A[i, j]*x[j] + y[i])*x[i]} + >>> d[x[i]*(A[i, j]*x[j] + y[i])] + [{None: {y[i]}, (j,): {A[i, j]*x[j]}}] + + Powers with contractions in either base or exponent will also be found as + keys in the dictionary, mapping to a list of results from recursive calls: + + >>> d = get_contraction_structure(A[j, j]**A[i, i]) + >>> d[None] + {A[j, j]**A[i, i]} + >>> nested_contractions = d[A[j, j]**A[i, i]] + >>> nested_contractions[0] + {(j,): {A[j, j]}} + >>> nested_contractions[1] + {(i,): {A[i, i]}} + + The description of the contraction structure may appear complicated when + represented with a string in the above examples, but it is easy to iterate + over: + + >>> from sympy import Expr + >>> for key in d: + ... if isinstance(key, Expr): + ... continue + ... for term in d[key]: + ... if term in d: + ... # treat deepest contraction first + ... pass + ... # treat outermost contactions here + + """ + + # We call ourself recursively to inspect sub expressions. + + if isinstance(expr, Indexed): + junk, key = _remove_repeated(expr.indices) + return {key or None: {expr}} + elif expr.is_Atom: + return {None: {expr}} + elif expr.is_Mul: + junk, junk, key = _get_indices_Mul(expr, return_dummies=True) + result = {key or None: {expr}} + # recurse on every factor + nested = [] + for fac in expr.args: + facd = get_contraction_structure(fac) + if not (None in facd and len(facd) == 1): + nested.append(facd) + if nested: + result[expr] = nested + return result + elif expr.is_Pow or isinstance(expr, exp): + # recurse in base and exp separately. If either has internal + # contractions we must include ourselves as a key in the returned dict + b, e = expr.as_base_exp() + dbase = get_contraction_structure(b) + dexp = get_contraction_structure(e) + + dicts = [] + for d in dbase, dexp: + if not (None in d and len(d) == 1): + dicts.append(d) + result = {None: {expr}} + if dicts: + result[expr] = dicts + return result + elif expr.is_Add: + # Note: we just collect all terms with identical summation indices, We + # do nothing to identify equivalent terms here, as this would require + # substitutions or pattern matching in expressions of unknown + # complexity. + result = {} + for term in expr.args: + # recurse on every term + d = get_contraction_structure(term) + for key in d: + if key in result: + result[key] |= d[key] + else: + result[key] = d[key] + return result + + elif isinstance(expr, Piecewise): + # FIXME: No support for Piecewise yet + return {None: expr} + elif isinstance(expr, Function): + # Collect non-trivial contraction structures in each argument + # We do not report repeated indices in separate arguments as a + # contraction + deeplist = [] + for arg in expr.args: + deep = get_contraction_structure(arg) + if not (None in deep and len(deep) == 1): + deeplist.append(deep) + d = {None: {expr}} + if deeplist: + d[expr] = deeplist + return d + + # this test is expensive, so it should be at the end + elif not expr.has(Indexed): + return {None: {expr}} + raise NotImplementedError( + "FIXME: No specialized handling of type %s" % type(expr)) diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/indexed.py b/.venv/lib/python3.13/site-packages/sympy/tensor/indexed.py new file mode 100644 index 0000000000000000000000000000000000000000..feddad21e52bbab2e1243beafdb11f30b2eded4d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/indexed.py @@ -0,0 +1,793 @@ +r"""Module that defines indexed objects. + +The classes ``IndexedBase``, ``Indexed``, and ``Idx`` represent a +matrix element ``M[i, j]`` as in the following diagram:: + + 1) The Indexed class represents the entire indexed object. + | + ___|___ + ' ' + M[i, j] + / \__\______ + | | + | | + | 2) The Idx class represents indices; each Idx can + | optionally contain information about its range. + | + 3) IndexedBase represents the 'stem' of an indexed object, here `M`. + The stem used by itself is usually taken to represent the entire + array. + +There can be any number of indices on an Indexed object. No +transformation properties are implemented in these Base objects, but +implicit contraction of repeated indices is supported. + +Note that the support for complicated (i.e. non-atomic) integer +expressions as indices is limited. (This should be improved in +future releases.) + +Examples +======== + +To express the above matrix element example you would write: + +>>> from sympy import symbols, IndexedBase, Idx +>>> M = IndexedBase('M') +>>> i, j = symbols('i j', cls=Idx) +>>> M[i, j] +M[i, j] + +Repeated indices in a product implies a summation, so to express a +matrix-vector product in terms of Indexed objects: + +>>> x = IndexedBase('x') +>>> M[i, j]*x[j] +M[i, j]*x[j] + +If the indexed objects will be converted to component based arrays, e.g. +with the code printers or the autowrap framework, you also need to provide +(symbolic or numerical) dimensions. This can be done by passing an +optional shape parameter to IndexedBase upon construction: + +>>> dim1, dim2 = symbols('dim1 dim2', integer=True) +>>> A = IndexedBase('A', shape=(dim1, 2*dim1, dim2)) +>>> A.shape +(dim1, 2*dim1, dim2) +>>> A[i, j, 3].shape +(dim1, 2*dim1, dim2) + +If an IndexedBase object has no shape information, it is assumed that the +array is as large as the ranges of its indices: + +>>> n, m = symbols('n m', integer=True) +>>> i = Idx('i', m) +>>> j = Idx('j', n) +>>> M[i, j].shape +(m, n) +>>> M[i, j].ranges +[(0, m - 1), (0, n - 1)] + +The above can be compared with the following: + +>>> A[i, 2, j].shape +(dim1, 2*dim1, dim2) +>>> A[i, 2, j].ranges +[(0, m - 1), None, (0, n - 1)] + +To analyze the structure of indexed expressions, you can use the methods +get_indices() and get_contraction_structure(): + +>>> from sympy.tensor import get_indices, get_contraction_structure +>>> get_indices(A[i, j, j]) +({i}, {}) +>>> get_contraction_structure(A[i, j, j]) +{(j,): {A[i, j, j]}} + +See the appropriate docstrings for a detailed explanation of the output. +""" + +# TODO: (some ideas for improvement) +# +# o test and guarantee numpy compatibility +# - implement full support for broadcasting +# - strided arrays +# +# o more functions to analyze indexed expressions +# - identify standard constructs, e.g matrix-vector product in a subexpression +# +# o functions to generate component based arrays (numpy and sympy.Matrix) +# - generate a single array directly from Indexed +# - convert simple sub-expressions +# +# o sophisticated indexing (possibly in subclasses to preserve simplicity) +# - Idx with range smaller than dimension of Indexed +# - Idx with stepsize != 1 +# - Idx with step determined by function call +from collections.abc import Iterable + +from sympy.core.numbers import Number +from sympy.core.assumptions import StdFactKB +from sympy.core import Expr, Tuple, sympify, S +from sympy.core.symbol import _filter_assumptions, Symbol +from sympy.core.logic import fuzzy_bool, fuzzy_not +from sympy.core.sympify import _sympify +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.multipledispatch import dispatch +from sympy.utilities.iterables import is_sequence, NotIterable +from sympy.utilities.misc import filldedent + + +class IndexException(Exception): + pass + + +class Indexed(Expr): + """Represents a mathematical object with indices. + + >>> from sympy import Indexed, IndexedBase, Idx, symbols + >>> i, j = symbols('i j', cls=Idx) + >>> Indexed('A', i, j) + A[i, j] + + It is recommended that ``Indexed`` objects be created by indexing ``IndexedBase``: + ``IndexedBase('A')[i, j]`` instead of ``Indexed(IndexedBase('A'), i, j)``. + + >>> A = IndexedBase('A') + >>> a_ij = A[i, j] # Prefer this, + >>> b_ij = Indexed(A, i, j) # over this. + >>> a_ij == b_ij + True + + """ + is_Indexed = True + is_symbol = True + is_Atom = True + + def __new__(cls, base, *args, **kw_args): + from sympy.tensor.array.ndim_array import NDimArray + from sympy.matrices.matrixbase import MatrixBase + + if not args: + raise IndexException("Indexed needs at least one index.") + if isinstance(base, (str, Symbol)): + base = IndexedBase(base) + elif not hasattr(base, '__getitem__') and not isinstance(base, IndexedBase): + raise TypeError(filldedent(""" + The base can only be replaced with a string, Symbol, + IndexedBase or an object with a method for getting + items (i.e. an object with a `__getitem__` method). + """)) + args = list(map(sympify, args)) + if isinstance(base, (NDimArray, Iterable, Tuple, MatrixBase)) and all(i.is_number for i in args): + if len(args) == 1: + return base[args[0]] + else: + return base[args] + + base = _sympify(base) + + obj = Expr.__new__(cls, base, *args, **kw_args) + + IndexedBase._set_assumptions(obj, base.assumptions0) + + return obj + + def _hashable_content(self): + return super()._hashable_content() + tuple(sorted(self.assumptions0.items())) + + @property + def name(self): + return str(self) + + @property + def _diff_wrt(self): + """Allow derivatives with respect to an ``Indexed`` object.""" + return True + + def _eval_derivative(self, wrt): + from sympy.tensor.array.ndim_array import NDimArray + + if isinstance(wrt, Indexed) and wrt.base == self.base: + if len(self.indices) != len(wrt.indices): + msg = "Different # of indices: d({!s})/d({!s})".format(self, + wrt) + raise IndexException(msg) + result = S.One + for index1, index2 in zip(self.indices, wrt.indices): + result *= KroneckerDelta(index1, index2) + return result + elif isinstance(self.base, NDimArray): + from sympy.tensor.array import derive_by_array + return Indexed(derive_by_array(self.base, wrt), *self.args[1:]) + else: + if Tuple(self.indices).has(wrt): + return S.NaN + return S.Zero + + @property + def assumptions0(self): + return {k: v for k, v in self._assumptions.items() if v is not None} + + @property + def base(self): + """Returns the ``IndexedBase`` of the ``Indexed`` object. + + Examples + ======== + + >>> from sympy import Indexed, IndexedBase, Idx, symbols + >>> i, j = symbols('i j', cls=Idx) + >>> Indexed('A', i, j).base + A + >>> B = IndexedBase('B') + >>> B == B[i, j].base + True + + """ + return self.args[0] + + @property + def indices(self): + """ + Returns the indices of the ``Indexed`` object. + + Examples + ======== + + >>> from sympy import Indexed, Idx, symbols + >>> i, j = symbols('i j', cls=Idx) + >>> Indexed('A', i, j).indices + (i, j) + + """ + return self.args[1:] + + @property + def rank(self): + """ + Returns the rank of the ``Indexed`` object. + + Examples + ======== + + >>> from sympy import Indexed, Idx, symbols + >>> i, j, k, l, m = symbols('i:m', cls=Idx) + >>> Indexed('A', i, j).rank + 2 + >>> q = Indexed('A', i, j, k, l, m) + >>> q.rank + 5 + >>> q.rank == len(q.indices) + True + + """ + return len(self.args) - 1 + + @property + def shape(self): + """Returns a list with dimensions of each index. + + Dimensions is a property of the array, not of the indices. Still, if + the ``IndexedBase`` does not define a shape attribute, it is assumed + that the ranges of the indices correspond to the shape of the array. + + >>> from sympy import IndexedBase, Idx, symbols + >>> n, m = symbols('n m', integer=True) + >>> i = Idx('i', m) + >>> j = Idx('j', m) + >>> A = IndexedBase('A', shape=(n, n)) + >>> B = IndexedBase('B') + >>> A[i, j].shape + (n, n) + >>> B[i, j].shape + (m, m) + """ + + if self.base.shape: + return self.base.shape + sizes = [] + for i in self.indices: + upper = getattr(i, 'upper', None) + lower = getattr(i, 'lower', None) + if None in (upper, lower): + raise IndexException(filldedent(""" + Range is not defined for all indices in: %s""" % self)) + try: + size = upper - lower + 1 + except TypeError: + raise IndexException(filldedent(""" + Shape cannot be inferred from Idx with + undefined range: %s""" % self)) + sizes.append(size) + return Tuple(*sizes) + + @property + def ranges(self): + """Returns a list of tuples with lower and upper range of each index. + + If an index does not define the data members upper and lower, the + corresponding slot in the list contains ``None`` instead of a tuple. + + Examples + ======== + + >>> from sympy import Indexed,Idx, symbols + >>> Indexed('A', Idx('i', 2), Idx('j', 4), Idx('k', 8)).ranges + [(0, 1), (0, 3), (0, 7)] + >>> Indexed('A', Idx('i', 3), Idx('j', 3), Idx('k', 3)).ranges + [(0, 2), (0, 2), (0, 2)] + >>> x, y, z = symbols('x y z', integer=True) + >>> Indexed('A', x, y, z).ranges + [None, None, None] + + """ + ranges = [] + sentinel = object() + for i in self.indices: + upper = getattr(i, 'upper', sentinel) + lower = getattr(i, 'lower', sentinel) + if sentinel not in (upper, lower): + ranges.append((lower, upper)) + else: + ranges.append(None) + return ranges + + def _sympystr(self, p): + indices = list(map(p.doprint, self.indices)) + return "%s[%s]" % (p.doprint(self.base), ", ".join(indices)) + + @property + def free_symbols(self): + base_free_symbols = self.base.free_symbols + indices_free_symbols = { + fs for i in self.indices for fs in i.free_symbols} + if base_free_symbols: + return {self} | base_free_symbols | indices_free_symbols + else: + return indices_free_symbols + + @property + def expr_free_symbols(self): + from sympy.utilities.exceptions import sympy_deprecation_warning + sympy_deprecation_warning(""" + The expr_free_symbols property is deprecated. Use free_symbols to get + the free symbols of an expression. + """, + deprecated_since_version="1.9", + active_deprecations_target="deprecated-expr-free-symbols") + + return {self} + + +class IndexedBase(Expr, NotIterable): + """Represent the base or stem of an indexed object + + The IndexedBase class represent an array that contains elements. The main purpose + of this class is to allow the convenient creation of objects of the Indexed + class. The __getitem__ method of IndexedBase returns an instance of + Indexed. Alone, without indices, the IndexedBase class can be used as a + notation for e.g. matrix equations, resembling what you could do with the + Symbol class. But, the IndexedBase class adds functionality that is not + available for Symbol instances: + + - An IndexedBase object can optionally store shape information. This can + be used in to check array conformance and conditions for numpy + broadcasting. (TODO) + - An IndexedBase object implements syntactic sugar that allows easy symbolic + representation of array operations, using implicit summation of + repeated indices. + - The IndexedBase object symbolizes a mathematical structure equivalent + to arrays, and is recognized as such for code generation and automatic + compilation and wrapping. + + >>> from sympy.tensor import IndexedBase, Idx + >>> from sympy import symbols + >>> A = IndexedBase('A'); A + A + >>> type(A) + + + When an IndexedBase object receives indices, it returns an array with named + axes, represented by an Indexed object: + + >>> i, j = symbols('i j', integer=True) + >>> A[i, j, 2] + A[i, j, 2] + >>> type(A[i, j, 2]) + + + The IndexedBase constructor takes an optional shape argument. If given, + it overrides any shape information in the indices. (But not the index + ranges!) + + >>> m, n, o, p = symbols('m n o p', integer=True) + >>> i = Idx('i', m) + >>> j = Idx('j', n) + >>> A[i, j].shape + (m, n) + >>> B = IndexedBase('B', shape=(o, p)) + >>> B[i, j].shape + (o, p) + + Assumptions can be specified with keyword arguments the same way as for Symbol: + + >>> A_real = IndexedBase('A', real=True) + >>> A_real.is_real + True + >>> A != A_real + True + + Assumptions can also be inherited if a Symbol is used to initialize the IndexedBase: + + >>> I = symbols('I', integer=True) + >>> C_inherit = IndexedBase(I) + >>> C_explicit = IndexedBase('I', integer=True) + >>> C_inherit == C_explicit + True + """ + is_symbol = True + is_Atom = True + + @staticmethod + def _set_assumptions(obj, assumptions): + """Set assumptions on obj, making sure to apply consistent values.""" + tmp_asm_copy = assumptions.copy() + is_commutative = fuzzy_bool(assumptions.get('commutative', True)) + assumptions['commutative'] = is_commutative + obj._assumptions = StdFactKB(assumptions) + obj._assumptions._generator = tmp_asm_copy # Issue #8873 + + def __new__(cls, label, shape=None, *, offset=S.Zero, strides=None, **kw_args): + from sympy.matrices.matrixbase import MatrixBase + from sympy.tensor.array.ndim_array import NDimArray + + assumptions, kw_args = _filter_assumptions(kw_args) + if isinstance(label, str): + label = Symbol(label, **assumptions) + elif isinstance(label, Symbol): + assumptions = label._merge(assumptions) + elif isinstance(label, (MatrixBase, NDimArray)): + return label + elif isinstance(label, Iterable): + return _sympify(label) + else: + label = _sympify(label) + + if is_sequence(shape): + shape = Tuple(*shape) + elif shape is not None: + shape = Tuple(shape) + + if shape is not None: + obj = Expr.__new__(cls, label, shape) + else: + obj = Expr.__new__(cls, label) + obj._shape = shape + obj._offset = offset + obj._strides = strides + obj._name = str(label) + + IndexedBase._set_assumptions(obj, assumptions) + return obj + + @property + def name(self): + return self._name + + def _hashable_content(self): + return super()._hashable_content() + tuple(sorted(self.assumptions0.items())) + + @property + def assumptions0(self): + return {k: v for k, v in self._assumptions.items() if v is not None} + + def __getitem__(self, indices, **kw_args): + if is_sequence(indices): + # Special case needed because M[*my_tuple] is a syntax error. + if self.shape and len(self.shape) != len(indices): + raise IndexException("Rank mismatch.") + return Indexed(self, *indices, **kw_args) + else: + if self.shape and len(self.shape) != 1: + raise IndexException("Rank mismatch.") + return Indexed(self, indices, **kw_args) + + @property + def shape(self): + """Returns the shape of the ``IndexedBase`` object. + + Examples + ======== + + >>> from sympy import IndexedBase, Idx + >>> from sympy.abc import x, y + >>> IndexedBase('A', shape=(x, y)).shape + (x, y) + + Note: If the shape of the ``IndexedBase`` is specified, it will override + any shape information given by the indices. + + >>> A = IndexedBase('A', shape=(x, y)) + >>> B = IndexedBase('B') + >>> i = Idx('i', 2) + >>> j = Idx('j', 1) + >>> A[i, j].shape + (x, y) + >>> B[i, j].shape + (2, 1) + + """ + return self._shape + + @property + def strides(self): + """Returns the strided scheme for the ``IndexedBase`` object. + + Normally this is a tuple denoting the number of + steps to take in the respective dimension when traversing + an array. For code generation purposes strides='C' and + strides='F' can also be used. + + strides='C' would mean that code printer would unroll + in row-major order and 'F' means unroll in column major + order. + + """ + + return self._strides + + @property + def offset(self): + """Returns the offset for the ``IndexedBase`` object. + + This is the value added to the resulting index when the + 2D Indexed object is unrolled to a 1D form. Used in code + generation. + + Examples + ========== + >>> from sympy.printing import ccode + >>> from sympy.tensor import IndexedBase, Idx + >>> from sympy import symbols + >>> l, m, n, o = symbols('l m n o', integer=True) + >>> A = IndexedBase('A', strides=(l, m, n), offset=o) + >>> i, j, k = map(Idx, 'ijk') + >>> ccode(A[i, j, k]) + 'A[l*i + m*j + n*k + o]' + + """ + return self._offset + + @property + def label(self): + """Returns the label of the ``IndexedBase`` object. + + Examples + ======== + + >>> from sympy import IndexedBase + >>> from sympy.abc import x, y + >>> IndexedBase('A', shape=(x, y)).label + A + + """ + return self.args[0] + + def _sympystr(self, p): + return p.doprint(self.label) + + +class Idx(Expr): + """Represents an integer index as an ``Integer`` or integer expression. + + There are a number of ways to create an ``Idx`` object. The constructor + takes two arguments: + + ``label`` + An integer or a symbol that labels the index. + ``range`` + Optionally you can specify a range as either + + * ``Symbol`` or integer: This is interpreted as a dimension. Lower and + upper bounds are set to ``0`` and ``range - 1``, respectively. + * ``tuple``: The two elements are interpreted as the lower and upper + bounds of the range, respectively. + + Note: bounds of the range are assumed to be either integer or infinite (oo + and -oo are allowed to specify an unbounded range). If ``n`` is given as a + bound, then ``n.is_integer`` must not return false. + + For convenience, if the label is given as a string it is automatically + converted to an integer symbol. (Note: this conversion is not done for + range or dimension arguments.) + + Examples + ======== + + >>> from sympy import Idx, symbols, oo + >>> n, i, L, U = symbols('n i L U', integer=True) + + If a string is given for the label an integer ``Symbol`` is created and the + bounds are both ``None``: + + >>> idx = Idx('qwerty'); idx + qwerty + >>> idx.lower, idx.upper + (None, None) + + Both upper and lower bounds can be specified: + + >>> idx = Idx(i, (L, U)); idx + i + >>> idx.lower, idx.upper + (L, U) + + When only a single bound is given it is interpreted as the dimension + and the lower bound defaults to 0: + + >>> idx = Idx(i, n); idx.lower, idx.upper + (0, n - 1) + >>> idx = Idx(i, 4); idx.lower, idx.upper + (0, 3) + >>> idx = Idx(i, oo); idx.lower, idx.upper + (0, oo) + + """ + + is_integer = True + is_finite = True + is_real = True + is_symbol = True + is_Atom = True + _diff_wrt = True + + def __new__(cls, label, range=None, **kw_args): + + if isinstance(label, str): + label = Symbol(label, integer=True) + label, range = list(map(sympify, (label, range))) + + if label.is_Number: + if not label.is_integer: + raise TypeError("Index is not an integer number.") + return label + + if not label.is_integer: + raise TypeError("Idx object requires an integer label.") + + elif is_sequence(range): + if len(range) != 2: + raise ValueError(filldedent(""" + Idx range tuple must have length 2, but got %s""" % len(range))) + for bound in range: + if (bound.is_integer is False and bound is not S.Infinity + and bound is not S.NegativeInfinity): + raise TypeError("Idx object requires integer bounds.") + args = label, Tuple(*range) + elif isinstance(range, Expr): + if range is not S.Infinity and fuzzy_not(range.is_integer): + raise TypeError("Idx object requires an integer dimension.") + args = label, Tuple(0, range - 1) + elif range: + raise TypeError(filldedent(""" + The range must be an ordered iterable or + integer SymPy expression.""")) + else: + args = label, + + obj = Expr.__new__(cls, *args, **kw_args) + obj._assumptions["finite"] = True + obj._assumptions["real"] = True + return obj + + @property + def label(self): + """Returns the label (Integer or integer expression) of the Idx object. + + Examples + ======== + + >>> from sympy import Idx, Symbol + >>> x = Symbol('x', integer=True) + >>> Idx(x).label + x + >>> j = Symbol('j', integer=True) + >>> Idx(j).label + j + >>> Idx(j + 1).label + j + 1 + + """ + return self.args[0] + + @property + def lower(self): + """Returns the lower bound of the ``Idx``. + + Examples + ======== + + >>> from sympy import Idx + >>> Idx('j', 2).lower + 0 + >>> Idx('j', 5).lower + 0 + >>> Idx('j').lower is None + True + + """ + try: + return self.args[1][0] + except IndexError: + return + + @property + def upper(self): + """Returns the upper bound of the ``Idx``. + + Examples + ======== + + >>> from sympy import Idx + >>> Idx('j', 2).upper + 1 + >>> Idx('j', 5).upper + 4 + >>> Idx('j').upper is None + True + + """ + try: + return self.args[1][1] + except IndexError: + return + + def _sympystr(self, p): + return p.doprint(self.label) + + @property + def name(self): + return self.label.name if self.label.is_Symbol else str(self.label) + + @property + def free_symbols(self): + return {self} + + +@dispatch(Idx, Idx) +def _eval_is_ge(lhs, rhs): # noqa:F811 + + other_upper = rhs if rhs.upper is None else rhs.upper + other_lower = rhs if rhs.lower is None else rhs.lower + + if lhs.lower is not None and (lhs.lower >= other_upper) == True: + return True + if lhs.upper is not None and (lhs.upper < other_lower) == True: + return False + return None + + +@dispatch(Idx, Number) # type:ignore +def _eval_is_ge(lhs, rhs): # noqa:F811 + + other_upper = rhs + other_lower = rhs + + if lhs.lower is not None and (lhs.lower >= other_upper) == True: + return True + if lhs.upper is not None and (lhs.upper < other_lower) == True: + return False + return None + + +@dispatch(Number, Idx) # type:ignore +def _eval_is_ge(lhs, rhs): # noqa:F811 + + other_upper = lhs + other_lower = lhs + + if rhs.upper is not None and (rhs.upper <= other_lower) == True: + return True + if rhs.lower is not None and (rhs.lower > other_upper) == True: + return False + return None diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/tensor.py b/.venv/lib/python3.13/site-packages/sympy/tensor/tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..579e7c7a86c2a1f18ab889af32ce0053a729ff5f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/tensor.py @@ -0,0 +1,5265 @@ +""" +This module defines tensors with abstract index notation. + +The abstract index notation has been first formalized by Penrose. + +Tensor indices are formal objects, with a tensor type; there is no +notion of index range, it is only possible to assign the dimension, +used to trace the Kronecker delta; the dimension can be a Symbol. + +The Einstein summation convention is used. +The covariant indices are indicated with a minus sign in front of the index. + +For instance the tensor ``t = p(a)*A(b,c)*q(-c)`` has the index ``c`` +contracted. + +A tensor expression ``t`` can be called; called with its +indices in sorted order it is equal to itself: +in the above example ``t(a, b) == t``; +one can call ``t`` with different indices; ``t(c, d) == p(c)*A(d,a)*q(-a)``. + +The contracted indices are dummy indices, internally they have no name, +the indices being represented by a graph-like structure. + +Tensors are put in canonical form using ``canon_bp``, which uses +the Butler-Portugal algorithm for canonicalization using the monoterm +symmetries of the tensors. + +If there is a (anti)symmetric metric, the indices can be raised and +lowered when the tensor is put in canonical form. +""" + +from __future__ import annotations +from typing import Any +from functools import reduce +from math import prod + +from abc import abstractmethod, ABC +from collections import defaultdict +import operator +import itertools + +from sympy.core.numbers import (Integer, Rational) +from sympy.combinatorics import Permutation +from sympy.combinatorics.tensor_can import get_symmetric_group_sgs, \ + bsgs_direct_product, canonicalize, riemann_bsgs +from sympy.core import Basic, Expr, sympify, Add, Mul, S +from sympy.core.cache import clear_cache +from sympy.core.containers import Tuple, Dict +from sympy.core.function import WildFunction +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import Symbol, symbols, Wild +from sympy.core.sympify import CantSympify, _sympify +from sympy.core.operations import AssocOp +from sympy.external.gmpy import SYMPY_INTS +from sympy.matrices import eye +from sympy.utilities.exceptions import (sympy_deprecation_warning, + SymPyDeprecationWarning, + ignore_warnings) +from sympy.utilities.decorator import memoize_property, deprecated +from sympy.utilities.iterables import sift + + +def deprecate_data(): + sympy_deprecation_warning( + """ + The data attribute of TensorIndexType is deprecated. Use The + replace_with_arrays() method instead. + """, + deprecated_since_version="1.4", + active_deprecations_target="deprecated-tensorindextype-attrs", + stacklevel=4, + ) + +def deprecate_fun_eval(): + sympy_deprecation_warning( + """ + The Tensor.fun_eval() method is deprecated. Use + Tensor.substitute_indices() instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensor-fun-eval", + stacklevel=4, + ) + + +def deprecate_call(): + sympy_deprecation_warning( + """ + Calling a tensor like Tensor(*indices) is deprecated. Use + Tensor.substitute_indices() instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensor-fun-eval", + stacklevel=4, + ) + + +class _IndexStructure(CantSympify): + """ + This class handles the indices (free and dummy ones). It contains the + algorithms to manage the dummy indices replacements and contractions of + free indices under multiplications of tensor expressions, as well as stuff + related to canonicalization sorting, getting the permutation of the + expression and so on. It also includes tools to get the ``TensorIndex`` + objects corresponding to the given index structure. + """ + + def __init__(self, free, dum, index_types, indices, canon_bp=False): + self.free = free + self.dum = dum + self.index_types = index_types + self.indices = indices + self._ext_rank = len(self.free) + 2*len(self.dum) + self.dum.sort(key=lambda x: x[0]) + + @staticmethod + def from_indices(*indices): + """ + Create a new ``_IndexStructure`` object from a list of ``indices``. + + Explanation + =========== + + ``indices`` ``TensorIndex`` objects, the indices. Contractions are + detected upon construction. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, _IndexStructure + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> m0, m1, m2, m3 = tensor_indices('m0,m1,m2,m3', Lorentz) + >>> _IndexStructure.from_indices(m0, m1, -m1, m3) + _IndexStructure([(m0, 0), (m3, 3)], [(1, 2)], [Lorentz, Lorentz, Lorentz, Lorentz]) + """ + + free, dum = _IndexStructure._free_dum_from_indices(*indices) + index_types = [i.tensor_index_type for i in indices] + indices = _IndexStructure._replace_dummy_names(indices, free, dum) + return _IndexStructure(free, dum, index_types, indices) + + @staticmethod + def from_components_free_dum(components, free, dum): + index_types = [] + for component in components: + index_types.extend(component.index_types) + indices = _IndexStructure.generate_indices_from_free_dum_index_types(free, dum, index_types) + return _IndexStructure(free, dum, index_types, indices) + + @staticmethod + def _free_dum_from_indices(*indices): + """ + Convert ``indices`` into ``free``, ``dum`` for single component tensor. + + Explanation + =========== + + ``free`` list of tuples ``(index, pos, 0)``, + where ``pos`` is the position of index in + the list of indices formed by the component tensors + + ``dum`` list of tuples ``(pos_contr, pos_cov, 0, 0)`` + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, \ + _IndexStructure + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> m0, m1, m2, m3 = tensor_indices('m0,m1,m2,m3', Lorentz) + >>> _IndexStructure._free_dum_from_indices(m0, m1, -m1, m3) + ([(m0, 0), (m3, 3)], [(1, 2)]) + """ + n = len(indices) + if n == 1: + return [(indices[0], 0)], [] + + # find the positions of the free indices and of the dummy indices + free = [True]*len(indices) + index_dict = {} + dum = [] + for i, index in enumerate(indices): + name = index.name + typ = index.tensor_index_type + contr = index.is_up + if (name, typ) in index_dict: + # found a pair of dummy indices + is_contr, pos = index_dict[(name, typ)] + # check consistency and update free + if is_contr: + if contr: + raise ValueError('two equal contravariant indices in slots %d and %d' %(pos, i)) + else: + free[pos] = False + free[i] = False + else: + if contr: + free[pos] = False + free[i] = False + else: + raise ValueError('two equal covariant indices in slots %d and %d' %(pos, i)) + if contr: + dum.append((i, pos)) + else: + dum.append((pos, i)) + else: + index_dict[(name, typ)] = index.is_up, i + + free = [(index, i) for i, index in enumerate(indices) if free[i]] + free.sort() + return free, dum + + def get_indices(self): + """ + Get a list of indices, creating new tensor indices to complete dummy indices. + """ + return self.indices[:] + + @staticmethod + def generate_indices_from_free_dum_index_types(free, dum, index_types): + indices = [None]*(len(free)+2*len(dum)) + for idx, pos in free: + indices[pos] = idx + + generate_dummy_name = _IndexStructure._get_generator_for_dummy_indices(free) + for pos1, pos2 in dum: + typ1 = index_types[pos1] + indname = generate_dummy_name(typ1) + indices[pos1] = TensorIndex(indname, typ1, True) + indices[pos2] = TensorIndex(indname, typ1, False) + + return _IndexStructure._replace_dummy_names(indices, free, dum) + + @staticmethod + def _get_generator_for_dummy_indices(free): + cdt = defaultdict(int) + # if the free indices have names with dummy_name, start with an + # index higher than those for the dummy indices + # to avoid name collisions + for indx, ipos in free: + if indx.name.split('_')[0] == indx.tensor_index_type.dummy_name: + cdt[indx.tensor_index_type] = max(cdt[indx.tensor_index_type], int(indx.name.split('_')[1]) + 1) + + def dummy_name_gen(tensor_index_type): + nd = str(cdt[tensor_index_type]) + cdt[tensor_index_type] += 1 + return tensor_index_type.dummy_name + '_' + nd + + return dummy_name_gen + + @staticmethod + def _replace_dummy_names(indices, free, dum): + dum.sort(key=lambda x: x[0]) + new_indices = list(indices) + assert len(indices) == len(free) + 2*len(dum) + generate_dummy_name = _IndexStructure._get_generator_for_dummy_indices(free) + for ipos1, ipos2 in dum: + typ1 = new_indices[ipos1].tensor_index_type + indname = generate_dummy_name(typ1) + new_indices[ipos1] = TensorIndex(indname, typ1, True) + new_indices[ipos2] = TensorIndex(indname, typ1, False) + return new_indices + + def get_free_indices(self) -> list[TensorIndex]: + """ + Get a list of free indices. + """ + # get sorted indices according to their position: + free = sorted(self.free, key=lambda x: x[1]) + return [i[0] for i in free] + + def __str__(self): + return "_IndexStructure({}, {}, {})".format(self.free, self.dum, self.index_types) + + def __repr__(self): + return self.__str__() + + def _get_sorted_free_indices_for_canon(self): + sorted_free = self.free[:] + sorted_free.sort(key=lambda x: x[0]) + return sorted_free + + def _get_sorted_dum_indices_for_canon(self): + return sorted(self.dum, key=lambda x: x[0]) + + def _get_lexicographically_sorted_index_types(self): + permutation = self.indices_canon_args()[0] + index_types = [None]*self._ext_rank + for i, it in enumerate(self.index_types): + index_types[permutation(i)] = it + return index_types + + def _get_lexicographically_sorted_indices(self): + permutation = self.indices_canon_args()[0] + indices = [None]*self._ext_rank + for i, it in enumerate(self.indices): + indices[permutation(i)] = it + return indices + + def perm2tensor(self, g, is_canon_bp=False): + """ + Returns a ``_IndexStructure`` instance corresponding to the permutation ``g``. + + Explanation + =========== + + ``g`` permutation corresponding to the tensor in the representation + used in canonicalization + + ``is_canon_bp`` if True, then ``g`` is the permutation + corresponding to the canonical form of the tensor + """ + sorted_free = [i[0] for i in self._get_sorted_free_indices_for_canon()] + lex_index_types = self._get_lexicographically_sorted_index_types() + lex_indices = self._get_lexicographically_sorted_indices() + nfree = len(sorted_free) + rank = self._ext_rank + dum = [[None]*2 for i in range((rank - nfree)//2)] + free = [] + + index_types = [None]*rank + indices = [None]*rank + for i in range(rank): + gi = g[i] + index_types[i] = lex_index_types[gi] + indices[i] = lex_indices[gi] + if gi < nfree: + ind = sorted_free[gi] + assert index_types[i] == sorted_free[gi].tensor_index_type + free.append((ind, i)) + else: + j = gi - nfree + idum, cov = divmod(j, 2) + if cov: + dum[idum][1] = i + else: + dum[idum][0] = i + dum = [tuple(x) for x in dum] + + return _IndexStructure(free, dum, index_types, indices) + + def indices_canon_args(self): + """ + Returns ``(g, dummies, msym, v)``, the entries of ``canonicalize`` + + See ``canonicalize`` in ``tensor_can.py`` in combinatorics module. + """ + # to be called after sorted_components + from sympy.combinatorics.permutations import _af_new + n = self._ext_rank + g = [None]*n + [n, n+1] + + # Converts the symmetry of the metric into msym from .canonicalize() + # method in the combinatorics module + def metric_symmetry_to_msym(metric): + if metric is None: + return None + sym = metric.symmetry + if sym == TensorSymmetry.fully_symmetric(2): + return 0 + if sym == TensorSymmetry.fully_symmetric(-2): + return 1 + return None + + # ordered indices: first the free indices, ordered by types + # then the dummy indices, ordered by types and contravariant before + # covariant + # g[position in tensor] = position in ordered indices + for i, (indx, ipos) in enumerate(self._get_sorted_free_indices_for_canon()): + g[ipos] = i + pos = len(self.free) + j = len(self.free) + dummies = [] + prev = None + a = [] + msym = [] + for ipos1, ipos2 in self._get_sorted_dum_indices_for_canon(): + g[ipos1] = j + g[ipos2] = j + 1 + j += 2 + typ = self.index_types[ipos1] + if typ != prev: + if a: + dummies.append(a) + a = [pos, pos + 1] + prev = typ + msym.append(metric_symmetry_to_msym(typ.metric)) + else: + a.extend([pos, pos + 1]) + pos += 2 + if a: + dummies.append(a) + + return _af_new(g), dummies, msym + + +def components_canon_args(components): + numtyp = [] + prev = None + for t in components: + if t == prev: + numtyp[-1][1] += 1 + else: + prev = t + numtyp.append([prev, 1]) + v = [] + for h, n in numtyp: + if h.comm in (0, 1): + comm = h.comm + else: + comm = TensorManager.get_comm(h.comm, h.comm) + v.append((h.symmetry.base, h.symmetry.generators, n, comm)) + return v + + +class _TensorDataLazyEvaluator(CantSympify): + """ + EXPERIMENTAL: do not rely on this class, it may change without deprecation + warnings in future versions of SymPy. + + Explanation + =========== + + This object contains the logic to associate components data to a tensor + expression. Components data are set via the ``.data`` property of tensor + expressions, is stored inside this class as a mapping between the tensor + expression and the ``ndarray``. + + Computations are executed lazily: whereas the tensor expressions can have + contractions, tensor products, and additions, components data are not + computed until they are accessed by reading the ``.data`` property + associated to the tensor expression. + """ + _substitutions_dict: dict[Any, Any] = {} + _substitutions_dict_tensmul: dict[Any, Any] = {} + + def __getitem__(self, key): + dat = self._get(key) + if dat is None: + return None + + from .array import NDimArray + if not isinstance(dat, NDimArray): + return dat + + if dat.rank() == 0: + return dat[()] + elif dat.rank() == 1 and len(dat) == 1: + return dat[0] + return dat + + def _get(self, key): + """ + Retrieve ``data`` associated with ``key``. + + Explanation + =========== + + This algorithm looks into ``self._substitutions_dict`` for all + ``TensorHead`` in the ``TensExpr`` (or just ``TensorHead`` if key is a + TensorHead instance). It reconstructs the components data that the + tensor expression should have by performing on components data the + operations that correspond to the abstract tensor operations applied. + + Metric tensor is handled in a different manner: it is pre-computed in + ``self._substitutions_dict_tensmul``. + """ + if key in self._substitutions_dict: + return self._substitutions_dict[key] + + if isinstance(key, TensorHead): + return None + + if isinstance(key, Tensor): + # special case to handle metrics. Metric tensors cannot be + # constructed through contraction by the metric, their + # components show if they are a matrix or its inverse. + signature = tuple([i.is_up for i in key.get_indices()]) + srch = (key.component,) + signature + if srch in self._substitutions_dict_tensmul: + return self._substitutions_dict_tensmul[srch] + array_list = [self.data_from_tensor(key)] + return self.data_contract_dum(array_list, key.dum, key.ext_rank) + + if isinstance(key, TensMul): + tensmul_args = key.args + if len(tensmul_args) == 1 and len(tensmul_args[0].components) == 1: + # special case to handle metrics. Metric tensors cannot be + # constructed through contraction by the metric, their + # components show if they are a matrix or its inverse. + signature = tuple([i.is_up for i in tensmul_args[0].get_indices()]) + srch = (tensmul_args[0].components[0],) + signature + if srch in self._substitutions_dict_tensmul: + return self._substitutions_dict_tensmul[srch] + #data_list = [self.data_from_tensor(i) for i in tensmul_args if isinstance(i, TensExpr)] + data_list = [self.data_from_tensor(i) if isinstance(i, Tensor) else i.data for i in tensmul_args if isinstance(i, TensExpr)] + coeff = prod([i for i in tensmul_args if not isinstance(i, TensExpr)]) + if all(i is None for i in data_list): + return None + if any(i is None for i in data_list): + raise ValueError("Mixing tensors with associated components "\ + "data with tensors without components data") + data_result = self.data_contract_dum(data_list, key.dum, key.ext_rank) + return coeff*data_result + + if isinstance(key, TensAdd): + data_list = [] + free_args_list = [] + for arg in key.args: + if isinstance(arg, TensExpr): + data_list.append(arg.data) + free_args_list.append([x[0] for x in arg.free]) + else: + data_list.append(arg) + free_args_list.append([]) + if all(i is None for i in data_list): + return None + if any(i is None for i in data_list): + raise ValueError("Mixing tensors with associated components "\ + "data with tensors without components data") + + sum_list = [] + from .array import permutedims + for data, free_args in zip(data_list, free_args_list): + if len(free_args) < 2: + sum_list.append(data) + else: + free_args_pos = {y: x for x, y in enumerate(free_args)} + axes = [free_args_pos[arg] for arg in key.free_args] + sum_list.append(permutedims(data, axes)) + return reduce(lambda x, y: x+y, sum_list) + + return None + + @staticmethod + def data_contract_dum(ndarray_list, dum, ext_rank): + from .array import tensorproduct, tensorcontraction, MutableDenseNDimArray + arrays = list(map(MutableDenseNDimArray, ndarray_list)) + prodarr = tensorproduct(*arrays) + return tensorcontraction(prodarr, *dum) + + def data_tensorhead_from_tensmul(self, data, tensmul, tensorhead): + """ + This method is used when assigning components data to a ``TensMul`` + object, it converts components data to a fully contravariant ndarray, + which is then stored according to the ``TensorHead`` key. + """ + if data is None: + return None + + return self._correct_signature_from_indices( + data, + tensmul.get_indices(), + tensmul.free, + tensmul.dum, + True) + + def data_from_tensor(self, tensor): + """ + This method corrects the components data to the right signature + (covariant/contravariant) using the metric associated with each + ``TensorIndexType``. + """ + tensorhead = tensor.component + + if tensorhead.data is None: + return None + + return self._correct_signature_from_indices( + tensorhead.data, + tensor.get_indices(), + tensor.free, + tensor.dum) + + def _assign_data_to_tensor_expr(self, key, data): + if isinstance(key, TensAdd): + raise ValueError('cannot assign data to TensAdd') + # here it is assumed that `key` is a `TensMul` instance. + if len(key.components) != 1: + raise ValueError('cannot assign data to TensMul with multiple components') + tensorhead = key.components[0] + newdata = self.data_tensorhead_from_tensmul(data, key, tensorhead) + return tensorhead, newdata + + def _check_permutations_on_data(self, tens, data): + from .array import permutedims + from .array.arrayop import Flatten + + if isinstance(tens, TensorHead): + rank = tens.rank + generators = tens.symmetry.generators + elif isinstance(tens, Tensor): + rank = tens.rank + generators = tens.components[0].symmetry.generators + elif isinstance(tens, TensorIndexType): + rank = tens.metric.rank + generators = tens.metric.symmetry.generators + + # Every generator is a permutation, check that by permuting the array + # by that permutation, the array will be the same, except for a + # possible sign change if the permutation admits it. + for gener in generators: + sign_change = +1 if (gener(rank) == rank) else -1 + data_swapped = data + last_data = data + permute_axes = list(map(gener, range(rank))) + # the order of a permutation is the number of times to get the + # identity by applying that permutation. + for i in range(gener.order()-1): + data_swapped = permutedims(data_swapped, permute_axes) + # if any value in the difference array is non-zero, raise an error: + if any(Flatten(last_data - sign_change*data_swapped)): + raise ValueError("Component data symmetry structure error") + last_data = data_swapped + + def __setitem__(self, key, value): + """ + Set the components data of a tensor object/expression. + + Explanation + =========== + + Components data are transformed to the all-contravariant form and stored + with the corresponding ``TensorHead`` object. If a ``TensorHead`` object + cannot be uniquely identified, it will raise an error. + """ + data = _TensorDataLazyEvaluator.parse_data(value) + self._check_permutations_on_data(key, data) + + # TensorHead and TensorIndexType can be assigned data directly, while + # TensMul must first convert data to a fully contravariant form, and + # assign it to its corresponding TensorHead single component. + if not isinstance(key, (TensorHead, TensorIndexType)): + key, data = self._assign_data_to_tensor_expr(key, data) + + if isinstance(key, TensorHead): + for dim, indextype in zip(data.shape, key.index_types): + if indextype.data is None: + raise ValueError("index type {} has no components data"\ + " associated (needed to raise/lower index)".format(indextype)) + if not indextype.dim.is_number: + continue + if dim != indextype.dim: + raise ValueError("wrong dimension of ndarray") + self._substitutions_dict[key] = data + + def __delitem__(self, key): + del self._substitutions_dict[key] + + def __contains__(self, key): + return key in self._substitutions_dict + + def add_metric_data(self, metric, data): + """ + Assign data to the ``metric`` tensor. The metric tensor behaves in an + anomalous way when raising and lowering indices. + + Explanation + =========== + + A fully covariant metric is the inverse transpose of the fully + contravariant metric (it is meant matrix inverse). If the metric is + symmetric, the transpose is not necessary and mixed + covariant/contravariant metrics are Kronecker deltas. + """ + # hard assignment, data should not be added to `TensorHead` for metric: + # the problem with `TensorHead` is that the metric is anomalous, i.e. + # raising and lowering the index means considering the metric or its + # inverse, this is not the case for other tensors. + self._substitutions_dict_tensmul[metric, True, True] = data + inverse_transpose = self.inverse_transpose_matrix(data) + # in symmetric spaces, the transpose is the same as the original matrix, + # the full covariant metric tensor is the inverse transpose, so this + # code will be able to handle non-symmetric metrics. + self._substitutions_dict_tensmul[metric, False, False] = inverse_transpose + # now mixed cases, these are identical to the unit matrix if the metric + # is symmetric. + m = data.tomatrix() + invt = inverse_transpose.tomatrix() + self._substitutions_dict_tensmul[metric, True, False] = m * invt + self._substitutions_dict_tensmul[metric, False, True] = invt * m + + @staticmethod + def _flip_index_by_metric(data, metric, pos): + from .array import tensorproduct, tensorcontraction + + mdim = metric.rank() + ddim = data.rank() + + if pos == 0: + data = tensorcontraction( + tensorproduct( + metric, + data + ), + (1, mdim+pos) + ) + else: + data = tensorcontraction( + tensorproduct( + data, + metric + ), + (pos, ddim) + ) + return data + + @staticmethod + def inverse_matrix(ndarray): + m = ndarray.tomatrix().inv() + return _TensorDataLazyEvaluator.parse_data(m) + + @staticmethod + def inverse_transpose_matrix(ndarray): + m = ndarray.tomatrix().inv().T + return _TensorDataLazyEvaluator.parse_data(m) + + @staticmethod + def _correct_signature_from_indices(data, indices, free, dum, inverse=False): + """ + Utility function to correct the values inside the components data + ndarray according to whether indices are covariant or contravariant. + + It uses the metric matrix to lower values of covariant indices. + """ + # change the ndarray values according covariantness/contravariantness of the indices + # use the metric + for i, indx in enumerate(indices): + if not indx.is_up and not inverse: + data = _TensorDataLazyEvaluator._flip_index_by_metric(data, indx.tensor_index_type.data, i) + elif not indx.is_up and inverse: + data = _TensorDataLazyEvaluator._flip_index_by_metric( + data, + _TensorDataLazyEvaluator.inverse_matrix(indx.tensor_index_type.data), + i + ) + return data + + @staticmethod + def _sort_data_axes(old, new): + from .array import permutedims + + new_data = old.data.copy() + + old_free = [i[0] for i in old.free] + new_free = [i[0] for i in new.free] + + for i in range(len(new_free)): + for j in range(i, len(old_free)): + if old_free[j] == new_free[i]: + old_free[i], old_free[j] = old_free[j], old_free[i] + new_data = permutedims(new_data, (i, j)) + break + return new_data + + @staticmethod + def add_rearrange_tensmul_parts(new_tensmul, old_tensmul): + def sorted_compo(): + return _TensorDataLazyEvaluator._sort_data_axes(old_tensmul, new_tensmul) + + _TensorDataLazyEvaluator._substitutions_dict[new_tensmul] = sorted_compo() + + @staticmethod + def parse_data(data): + """ + Transform ``data`` to array. The parameter ``data`` may + contain data in various formats, e.g. nested lists, SymPy ``Matrix``, + and so on. + + Examples + ======== + + >>> from sympy.tensor.tensor import _TensorDataLazyEvaluator + >>> _TensorDataLazyEvaluator.parse_data([1, 3, -6, 12]) + [1, 3, -6, 12] + + >>> _TensorDataLazyEvaluator.parse_data([[1, 2], [4, 7]]) + [[1, 2], [4, 7]] + """ + from .array import MutableDenseNDimArray + + if not isinstance(data, MutableDenseNDimArray): + if len(data) == 2 and hasattr(data[0], '__call__'): + data = MutableDenseNDimArray(data[0], data[1]) + else: + data = MutableDenseNDimArray(data) + return data + +_tensor_data_substitution_dict = _TensorDataLazyEvaluator() + + +class _TensorManager: + """ + Class to manage tensor properties. + + Notes + ===== + + Tensors belong to tensor commutation groups; each group has a label + ``comm``; there are predefined labels: + + ``0`` tensors commuting with any other tensor + + ``1`` tensors anticommuting among themselves + + ``2`` tensors not commuting, apart with those with ``comm=0`` + + Other groups can be defined using ``set_comm``; tensors in those + groups commute with those with ``comm=0``; by default they + do not commute with any other group. + """ + def __init__(self): + self._comm_init() + + def _comm_init(self): + self._comm = [{} for i in range(3)] + for i in range(3): + self._comm[0][i] = 0 + self._comm[i][0] = 0 + self._comm[1][1] = 1 + self._comm[2][1] = None + self._comm[1][2] = None + self._comm_symbols2i = {0:0, 1:1, 2:2} + self._comm_i2symbol = {0:0, 1:1, 2:2} + + @property + def comm(self): + return self._comm + + def comm_symbols2i(self, i): + """ + Get the commutation group number corresponding to ``i``. + + ``i`` can be a symbol or a number or a string. + + If ``i`` is not already defined its commutation group number + is set. + """ + if i not in self._comm_symbols2i: + n = len(self._comm) + self._comm.append({}) + self._comm[n][0] = 0 + self._comm[0][n] = 0 + self._comm_symbols2i[i] = n + self._comm_i2symbol[n] = i + return n + return self._comm_symbols2i[i] + + def comm_i2symbol(self, i): + """ + Returns the symbol corresponding to the commutation group number. + """ + return self._comm_i2symbol[i] + + def set_comm(self, i, j, c): + """ + Set the commutation parameter ``c`` for commutation groups ``i, j``. + + Parameters + ========== + + i, j : symbols representing commutation groups + + c : group commutation number + + Notes + ===== + + ``i, j`` can be symbols, strings or numbers, + apart from ``0, 1`` and ``2`` which are reserved respectively + for commuting, anticommuting tensors and tensors not commuting + with any other group apart with the commuting tensors. + For the remaining cases, use this method to set the commutation rules; + by default ``c=None``. + + The group commutation number ``c`` is assigned in correspondence + to the group commutation symbols; it can be + + 0 commuting + + 1 anticommuting + + None no commutation property + + Examples + ======== + + ``G`` and ``GH`` do not commute with themselves and commute with + each other; A is commuting. + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorHead, TensorManager, TensorSymmetry + >>> Lorentz = TensorIndexType('Lorentz') + >>> i0,i1,i2,i3,i4 = tensor_indices('i0:5', Lorentz) + >>> A = TensorHead('A', [Lorentz]) + >>> G = TensorHead('G', [Lorentz], TensorSymmetry.no_symmetry(1), 'Gcomm') + >>> GH = TensorHead('GH', [Lorentz], TensorSymmetry.no_symmetry(1), 'GHcomm') + >>> TensorManager.set_comm('Gcomm', 'GHcomm', 0) + >>> (GH(i1)*G(i0)).canon_bp() + G(i0)*GH(i1) + >>> (G(i1)*G(i0)).canon_bp() + G(i1)*G(i0) + >>> (G(i1)*A(i0)).canon_bp() + A(i0)*G(i1) + """ + if c not in (0, 1, None): + raise ValueError('`c` can assume only the values 0, 1 or None') + + i = sympify(i) + j = sympify(j) + + if i not in self._comm_symbols2i: + n = len(self._comm) + self._comm.append({}) + self._comm[n][0] = 0 + self._comm[0][n] = 0 + self._comm_symbols2i[i] = n + self._comm_i2symbol[n] = i + if j not in self._comm_symbols2i: + n = len(self._comm) + self._comm.append({}) + self._comm[0][n] = 0 + self._comm[n][0] = 0 + self._comm_symbols2i[j] = n + self._comm_i2symbol[n] = j + ni = self._comm_symbols2i[i] + nj = self._comm_symbols2i[j] + self._comm[ni][nj] = c + self._comm[nj][ni] = c + + """ + Cached sympy functions (e.g. expand) may have cached the results of + expressions involving tensors, but those results may not be valid after + changing the commutation properties. To stay on the safe side, we clear + the cache of all functions. + """ + clear_cache() + + def set_comms(self, *args): + """ + Set the commutation group numbers ``c`` for symbols ``i, j``. + + Parameters + ========== + + args : sequence of ``(i, j, c)`` + """ + for i, j, c in args: + self.set_comm(i, j, c) + + def get_comm(self, i, j): + """ + Return the commutation parameter for commutation group numbers ``i, j`` + + see ``_TensorManager.set_comm`` + """ + return self._comm[i].get(j, 0 if i == 0 or j == 0 else None) + + def clear(self): + """ + Clear the TensorManager. + """ + self._comm_init() + + +TensorManager = _TensorManager() + + +class TensorIndexType(Basic): + """ + A TensorIndexType is characterized by its name and its metric. + + Parameters + ========== + + name : name of the tensor type + dummy_name : name of the head of dummy indices + dim : dimension, it can be a symbol or an integer or ``None`` + eps_dim : dimension of the epsilon tensor + metric_symmetry : integer that denotes metric symmetry or ``None`` for no metric + metric_name : string with the name of the metric tensor + + Attributes + ========== + + ``metric`` : the metric tensor + ``delta`` : ``Kronecker delta`` + ``epsilon`` : the ``Levi-Civita epsilon`` tensor + ``data`` : (deprecated) a property to add ``ndarray`` values, to work in a specified basis. + + Notes + ===== + + The possible values of the ``metric_symmetry`` parameter are: + + ``1`` : metric tensor is fully symmetric + ``0`` : metric tensor possesses no index symmetry + ``-1`` : metric tensor is fully antisymmetric + ``None``: there is no metric tensor (metric equals to ``None``) + + The metric is assumed to be symmetric by default. It can also be set + to a custom tensor by the ``.set_metric()`` method. + + If there is a metric the metric is used to raise and lower indices. + + In the case of non-symmetric metric, the following raising and + lowering conventions will be adopted: + + ``psi(a) = g(a, b)*psi(-b); chi(-a) = chi(b)*g(-b, -a)`` + + From these it is easy to find: + + ``g(-a, b) = delta(-a, b)`` + + where ``delta(-a, b) = delta(b, -a)`` is the ``Kronecker delta`` + (see ``TensorIndex`` for the conventions on indices). + For antisymmetric metrics there is also the following equality: + + ``g(a, -b) = -delta(a, -b)`` + + If there is no metric it is not possible to raise or lower indices; + e.g. the index of the defining representation of ``SU(N)`` + is 'covariant' and the conjugate representation is + 'contravariant'; for ``N > 2`` they are linearly independent. + + ``eps_dim`` is by default equal to ``dim``, if the latter is an integer; + else it can be assigned (for use in naive dimensional regularization); + if ``eps_dim`` is not an integer ``epsilon`` is ``None``. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> Lorentz.metric + metric(Lorentz,Lorentz) + """ + + def __new__(cls, name, dummy_name=None, dim=None, eps_dim=None, + metric_symmetry=1, metric_name='metric', **kwargs): + if 'dummy_fmt' in kwargs: + dummy_fmt = kwargs['dummy_fmt'] + sympy_deprecation_warning( + f""" + The dummy_fmt keyword to TensorIndexType is deprecated. Use + dummy_name={dummy_fmt} instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensorindextype-dummy-fmt", + ) + dummy_name = dummy_fmt + + if isinstance(name, str): + name = Symbol(name) + + if dummy_name is None: + dummy_name = str(name)[0] + if isinstance(dummy_name, str): + dummy_name = Symbol(dummy_name) + + if dim is None: + dim = Symbol("dim_" + dummy_name.name) + else: + dim = sympify(dim) + + if eps_dim is None: + eps_dim = dim + else: + eps_dim = sympify(eps_dim) + + metric_symmetry = sympify(metric_symmetry) + + if isinstance(metric_name, str): + metric_name = Symbol(metric_name) + + if 'metric' in kwargs: + SymPyDeprecationWarning( + """ + The 'metric' keyword argument to TensorIndexType is + deprecated. Use the 'metric_symmetry' keyword argument or the + TensorIndexType.set_metric() method instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensorindextype-metric", + ) + metric = kwargs.get('metric') + if metric is not None: + if metric in (True, False, 0, 1): + metric_name = 'metric' + #metric_antisym = metric + else: + metric_name = metric.name + #metric_antisym = metric.antisym + + if metric: + metric_symmetry = -1 + else: + metric_symmetry = 1 + + obj = Basic.__new__(cls, name, dummy_name, dim, eps_dim, + metric_symmetry, metric_name) + + obj._autogenerated = [] + return obj + + @property + def name(self): + return self.args[0].name + + @property + def dummy_name(self): + return self.args[1].name + + @property + def dim(self): + return self.args[2] + + @property + def eps_dim(self): + return self.args[3] + + @memoize_property + def metric(self): + metric_symmetry = self.args[4] + metric_name = self.args[5] + if metric_symmetry is None: + return None + + if metric_symmetry == 0: + symmetry = TensorSymmetry.no_symmetry(2) + elif metric_symmetry == 1: + symmetry = TensorSymmetry.fully_symmetric(2) + elif metric_symmetry == -1: + symmetry = TensorSymmetry.fully_symmetric(-2) + + return TensorHead(metric_name, [self]*2, symmetry) + + @memoize_property + def delta(self): + return TensorHead('KD', [self]*2, TensorSymmetry.fully_symmetric(2)) + + @memoize_property + def epsilon(self): + if not isinstance(self.eps_dim, (SYMPY_INTS, Integer)): + return None + symmetry = TensorSymmetry.fully_symmetric(-self.eps_dim) + return TensorHead('Eps', [self]*self.eps_dim, symmetry) + + def set_metric(self, tensor): + self._metric = tensor + + def __lt__(self, other): + return self.name < other.name + + def __str__(self): + return self.name + + __repr__ = __str__ + + # Everything below this line is deprecated + + @property + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return _tensor_data_substitution_dict[self] + + @data.setter + def data(self, data): + deprecate_data() + # This assignment is a bit controversial, should metric components be assigned + # to the metric only or also to the TensorIndexType object? The advantage here + # is the ability to assign a 1D array and transform it to a 2D diagonal array. + from .array import MutableDenseNDimArray + + data = _TensorDataLazyEvaluator.parse_data(data) + if data.rank() > 2: + raise ValueError("data have to be of rank 1 (diagonal metric) or 2.") + if data.rank() == 1: + if self.dim.is_number: + nda_dim = data.shape[0] + if nda_dim != self.dim: + raise ValueError("Dimension mismatch") + + dim = data.shape[0] + newndarray = MutableDenseNDimArray.zeros(dim, dim) + for i, val in enumerate(data): + newndarray[i, i] = val + data = newndarray + dim1, dim2 = data.shape + if dim1 != dim2: + raise ValueError("Non-square matrix tensor.") + if self.dim.is_number: + if self.dim != dim1: + raise ValueError("Dimension mismatch") + _tensor_data_substitution_dict[self] = data + _tensor_data_substitution_dict.add_metric_data(self.metric, data) + with ignore_warnings(SymPyDeprecationWarning): + delta = self.get_kronecker_delta() + i1 = TensorIndex('i1', self) + i2 = TensorIndex('i2', self) + with ignore_warnings(SymPyDeprecationWarning): + delta(i1, -i2).data = _TensorDataLazyEvaluator.parse_data(eye(dim1)) + + @data.deleter + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + if self in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self] + if self.metric in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self.metric] + + @deprecated( + """ + The TensorIndexType.get_kronecker_delta() method is deprecated. Use + the TensorIndexType.delta attribute instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensorindextype-methods", + ) + def get_kronecker_delta(self): + sym2 = TensorSymmetry(get_symmetric_group_sgs(2)) + delta = TensorHead('KD', [self]*2, sym2) + return delta + + @deprecated( + """ + The TensorIndexType.get_epsilon() method is deprecated. Use + the TensorIndexType.epsilon attribute instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensorindextype-methods", + ) + def get_epsilon(self): + if not isinstance(self._eps_dim, (SYMPY_INTS, Integer)): + return None + sym = TensorSymmetry(get_symmetric_group_sgs(self._eps_dim, 1)) + epsilon = TensorHead('Eps', [self]*self._eps_dim, sym) + return epsilon + + def _components_data_full_destroy(self): + """ + EXPERIMENTAL: do not rely on this API method. + + This destroys components data associated to the ``TensorIndexType``, if + any, specifically: + + * metric tensor data + * Kronecker tensor data + """ + if self in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self] + + def delete_tensmul_data(key): + if key in _tensor_data_substitution_dict._substitutions_dict_tensmul: + del _tensor_data_substitution_dict._substitutions_dict_tensmul[key] + + # delete metric data: + delete_tensmul_data((self.metric, True, True)) + delete_tensmul_data((self.metric, True, False)) + delete_tensmul_data((self.metric, False, True)) + delete_tensmul_data((self.metric, False, False)) + + # delete delta tensor data: + delta = self.get_kronecker_delta() + if delta in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[delta] + + +class TensorIndex(Basic): + """ + Represents a tensor index + + Parameters + ========== + + name : name of the index, or ``True`` if you want it to be automatically assigned + tensor_index_type : ``TensorIndexType`` of the index + is_up : flag for contravariant index (is_up=True by default) + + Attributes + ========== + + ``name`` + ``tensor_index_type`` + ``is_up`` + + Notes + ===== + + Tensor indices are contracted with the Einstein summation convention. + + An index can be in contravariant or in covariant form; in the latter + case it is represented prepending a ``-`` to the index name. Adding + ``-`` to a covariant (is_up=False) index makes it contravariant. + + Dummy indices have a name with head given by + ``tensor_inde_type.dummy_name`` with underscore and a number. + + Similar to ``symbols`` multiple contravariant indices can be created + at once using ``tensor_indices(s, typ)``, where ``s`` is a string + of names. + + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, TensorIndex, TensorHead, tensor_indices + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> mu = TensorIndex('mu', Lorentz, is_up=False) + >>> nu, rho = tensor_indices('nu, rho', Lorentz) + >>> A = TensorHead('A', [Lorentz, Lorentz]) + >>> A(mu, nu) + A(-mu, nu) + >>> A(-mu, -rho) + A(mu, -rho) + >>> A(mu, -mu) + A(-L_0, L_0) + """ + def __new__(cls, name, tensor_index_type, is_up=True): + if isinstance(name, str): + name_symbol = Symbol(name) + elif isinstance(name, Symbol): + name_symbol = name + elif name is True: + name = "_i{}".format(len(tensor_index_type._autogenerated)) + name_symbol = Symbol(name) + tensor_index_type._autogenerated.append(name_symbol) + else: + raise ValueError("invalid name") + + is_up = sympify(is_up) + return Basic.__new__(cls, name_symbol, tensor_index_type, is_up) + + @property + def name(self): + return self.args[0].name + + @property + def tensor_index_type(self): + return self.args[1] + + @property + def is_up(self): + return self.args[2] + + def _print(self): + s = self.name + if not self.is_up: + s = '-%s' % s + return s + + def __lt__(self, other): + return ((self.tensor_index_type, self.name) < + (other.tensor_index_type, other.name)) + + def __neg__(self): + t1 = TensorIndex(self.name, self.tensor_index_type, + (not self.is_up)) + return t1 + + +def tensor_indices(s, typ): + """ + Returns list of tensor indices given their names and their types. + + Parameters + ========== + + s : string of comma separated names of indices + + typ : ``TensorIndexType`` of the indices + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> a, b, c, d = tensor_indices('a,b,c,d', Lorentz) + """ + if isinstance(s, str): + a = [x.name for x in symbols(s, seq=True)] + else: + raise ValueError('expecting a string') + + tilist = [TensorIndex(i, typ) for i in a] + if len(tilist) == 1: + return tilist[0] + return tilist + + +class TensorSymmetry(Basic): + """ + Monoterm symmetry of a tensor (i.e. any symmetric or anti-symmetric + index permutation). For the relevant terminology see ``tensor_can.py`` + section of the combinatorics module. + + Parameters + ========== + + bsgs : tuple ``(base, sgs)`` BSGS of the symmetry of the tensor + + Attributes + ========== + + ``base`` : base of the BSGS + ``generators`` : generators of the BSGS + ``rank`` : rank of the tensor + + Notes + ===== + + A tensor can have an arbitrary monoterm symmetry provided by its BSGS. + Multiterm symmetries, like the cyclic symmetry of the Riemann tensor + (i.e., Bianchi identity), are not covered. See combinatorics module for + information on how to generate BSGS for a general index permutation group. + Simple symmetries can be generated using built-in methods. + + See Also + ======== + + sympy.combinatorics.tensor_can.get_symmetric_group_sgs + + Examples + ======== + + Define a symmetric tensor of rank 2 + + >>> from sympy.tensor.tensor import TensorIndexType, TensorSymmetry, get_symmetric_group_sgs, TensorHead + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> sym = TensorSymmetry(get_symmetric_group_sgs(2)) + >>> T = TensorHead('T', [Lorentz]*2, sym) + + Note, that the same can also be done using built-in TensorSymmetry methods + + >>> sym2 = TensorSymmetry.fully_symmetric(2) + >>> sym == sym2 + True + """ + def __new__(cls, *args, **kw_args): + if len(args) == 1: + base, generators = args[0] + elif len(args) == 2: + base, generators = args + else: + raise TypeError("bsgs required, either two separate parameters or one tuple") + + if not isinstance(base, Tuple): + base = Tuple(*base) + if not isinstance(generators, Tuple): + generators = Tuple(*generators) + + return Basic.__new__(cls, base, generators, **kw_args) + + @property + def base(self): + return self.args[0] + + @property + def generators(self): + return self.args[1] + + @property + def rank(self): + return self.generators[0].size - 2 + + @classmethod + def fully_symmetric(cls, rank): + """ + Returns a fully symmetric (antisymmetric if ``rank``<0) + TensorSymmetry object for ``abs(rank)`` indices. + """ + if rank > 0: + bsgs = get_symmetric_group_sgs(rank, False) + elif rank < 0: + bsgs = get_symmetric_group_sgs(-rank, True) + elif rank == 0: + bsgs = ([], [Permutation(1)]) + return TensorSymmetry(bsgs) + + @classmethod + def direct_product(cls, *args): + """ + Returns a TensorSymmetry object that is being a direct product of + fully (anti-)symmetric index permutation groups. + + Notes + ===== + + Some examples for different values of ``(*args)``: + ``(1)`` vector, equivalent to ``TensorSymmetry.fully_symmetric(1)`` + ``(2)`` tensor with 2 symmetric indices, equivalent to ``.fully_symmetric(2)`` + ``(-2)`` tensor with 2 antisymmetric indices, equivalent to ``.fully_symmetric(-2)`` + ``(2, -2)`` tensor with the first 2 indices commuting and the last 2 anticommuting + ``(1, 1, 1)`` tensor with 3 indices without any symmetry + """ + base, sgs = [], [Permutation(1)] + for arg in args: + if arg > 0: + bsgs2 = get_symmetric_group_sgs(arg, False) + elif arg < 0: + bsgs2 = get_symmetric_group_sgs(-arg, True) + else: + continue + base, sgs = bsgs_direct_product(base, sgs, *bsgs2) + + return TensorSymmetry(base, sgs) + + @classmethod + def riemann(cls): + """ + Returns a monotorem symmetry of the Riemann tensor + """ + return TensorSymmetry(riemann_bsgs) + + @classmethod + def no_symmetry(cls, rank): + """ + TensorSymmetry object for ``rank`` indices with no symmetry + """ + return TensorSymmetry([], [Permutation(rank+1)]) + + +@deprecated( + """ + The tensorsymmetry() function is deprecated. Use the TensorSymmetry + constructor instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensorsymmetry", +) +def tensorsymmetry(*args): + """ + Returns a ``TensorSymmetry`` object. This method is deprecated, use + ``TensorSymmetry.direct_product()`` or ``.riemann()`` instead. + + Explanation + =========== + + One can represent a tensor with any monoterm slot symmetry group + using a BSGS. + + ``args`` can be a BSGS + ``args[0]`` base + ``args[1]`` sgs + + Usually tensors are in (direct products of) representations + of the symmetric group; + ``args`` can be a list of lists representing the shapes of Young tableaux + + Notes + ===== + + For instance: + ``[[1]]`` vector + ``[[1]*n]`` symmetric tensor of rank ``n`` + ``[[n]]`` antisymmetric tensor of rank ``n`` + ``[[2, 2]]`` monoterm slot symmetry of the Riemann tensor + ``[[1],[1]]`` vector*vector + ``[[2],[1],[1]`` (antisymmetric tensor)*vector*vector + + Notice that with the shape ``[2, 2]`` we associate only the monoterm + symmetries of the Riemann tensor; this is an abuse of notation, + since the shape ``[2, 2]`` corresponds usually to the irreducible + representation characterized by the monoterm symmetries and by the + cyclic symmetry. + """ + from sympy.combinatorics import Permutation + + def tableau2bsgs(a): + if len(a) == 1: + # antisymmetric vector + n = a[0] + bsgs = get_symmetric_group_sgs(n, 1) + else: + if all(x == 1 for x in a): + # symmetric vector + n = len(a) + bsgs = get_symmetric_group_sgs(n) + elif a == [2, 2]: + bsgs = riemann_bsgs + else: + raise NotImplementedError + return bsgs + + if not args: + return TensorSymmetry(Tuple(), Tuple(Permutation(1))) + + if len(args) == 2 and isinstance(args[1][0], Permutation): + return TensorSymmetry(args) + base, sgs = tableau2bsgs(args[0]) + for a in args[1:]: + basex, sgsx = tableau2bsgs(a) + base, sgs = bsgs_direct_product(base, sgs, basex, sgsx) + return TensorSymmetry(Tuple(base, sgs)) + +@deprecated( + "TensorType is deprecated. Use tensor_heads() instead.", + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensortype", +) +class TensorType(Basic): + """ + Class of tensor types. Deprecated, use tensor_heads() instead. + + Parameters + ========== + + index_types : list of ``TensorIndexType`` of the tensor indices + symmetry : ``TensorSymmetry`` of the tensor + + Attributes + ========== + + ``index_types`` + ``symmetry`` + ``types`` : list of ``TensorIndexType`` without repetitions + """ + is_commutative = False + + def __new__(cls, index_types, symmetry, **kw_args): + assert symmetry.rank == len(index_types) + obj = Basic.__new__(cls, Tuple(*index_types), symmetry, **kw_args) + return obj + + @property + def index_types(self): + return self.args[0] + + @property + def symmetry(self): + return self.args[1] + + @property + def types(self): + return sorted(set(self.index_types), key=lambda x: x.name) + + def __str__(self): + return 'TensorType(%s)' % ([str(x) for x in self.index_types]) + + def __call__(self, s, comm=0): + """ + Return a TensorHead object or a list of TensorHead objects. + + Parameters + ========== + + s : name or string of names. + + comm : Commutation group. + + see ``_TensorManager.set_comm`` + """ + if isinstance(s, str): + names = [x.name for x in symbols(s, seq=True)] + else: + raise ValueError('expecting a string') + if len(names) == 1: + return TensorHead(names[0], self.index_types, self.symmetry, comm) + else: + return [TensorHead(name, self.index_types, self.symmetry, comm) for name in names] + + +@deprecated( + """ + The tensorhead() function is deprecated. Use tensor_heads() instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensorhead", +) +def tensorhead(name, typ, sym=None, comm=0): + """ + Function generating tensorhead(s). This method is deprecated, + use TensorHead constructor or tensor_heads() instead. + + Parameters + ========== + + name : name or sequence of names (as in ``symbols``) + + typ : index types + + sym : same as ``*args`` in ``tensorsymmetry`` + + comm : commutation group number + see ``_TensorManager.set_comm`` + """ + if sym is None: + sym = [[1] for i in range(len(typ))] + with ignore_warnings(SymPyDeprecationWarning): + sym = tensorsymmetry(*sym) + return TensorHead(name, typ, sym, comm) + + +class TensorHead(Basic): + """ + Tensor head of the tensor. + + Parameters + ========== + + name : name of the tensor + index_types : list of TensorIndexType + symmetry : TensorSymmetry of the tensor + comm : commutation group number + + Attributes + ========== + + ``name`` + ``index_types`` + ``rank`` : total number of indices + ``symmetry`` + ``comm`` : commutation group + + Notes + ===== + + Similar to ``symbols`` multiple TensorHeads can be created using + ``tensorhead(s, typ, sym=None, comm=0)`` function, where ``s`` + is the string of names and ``sym`` is the monoterm tensor symmetry + (see ``tensorsymmetry``). + + A ``TensorHead`` belongs to a commutation group, defined by a + symbol on number ``comm`` (see ``_TensorManager.set_comm``); + tensors in a commutation group have the same commutation properties; + by default ``comm`` is ``0``, the group of the commuting tensors. + + Examples + ======== + + Define a fully antisymmetric tensor of rank 2: + + >>> from sympy.tensor.tensor import TensorIndexType, TensorHead, TensorSymmetry + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> asym2 = TensorSymmetry.fully_symmetric(-2) + >>> A = TensorHead('A', [Lorentz, Lorentz], asym2) + + Examples with ndarray values, the components data assigned to the + ``TensorHead`` object are assumed to be in a fully-contravariant + representation. In case it is necessary to assign components data which + represents the values of a non-fully covariant tensor, see the other + examples. + + >>> from sympy.tensor.tensor import tensor_indices + >>> from sympy import diag + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> i0, i1 = tensor_indices('i0:2', Lorentz) + + Specify a replacement dictionary to keep track of the arrays to use for + replacements in the tensorial expression. The ``TensorIndexType`` is + associated to the metric used for contractions (in fully covariant form): + + >>> repl = {Lorentz: diag(1, -1, -1, -1)} + + Let's see some examples of working with components with the electromagnetic + tensor: + + >>> from sympy import symbols + >>> Ex, Ey, Ez, Bx, By, Bz = symbols('E_x E_y E_z B_x B_y B_z') + >>> c = symbols('c', positive=True) + + Let's define `F`, an antisymmetric tensor: + + >>> F = TensorHead('F', [Lorentz, Lorentz], asym2) + + Let's update the dictionary to contain the matrix to use in the + replacements: + + >>> repl.update({F(-i0, -i1): [ + ... [0, Ex/c, Ey/c, Ez/c], + ... [-Ex/c, 0, -Bz, By], + ... [-Ey/c, Bz, 0, -Bx], + ... [-Ez/c, -By, Bx, 0]]}) + + Now it is possible to retrieve the contravariant form of the Electromagnetic + tensor: + + >>> F(i0, i1).replace_with_arrays(repl, [i0, i1]) + [[0, -E_x/c, -E_y/c, -E_z/c], [E_x/c, 0, -B_z, B_y], [E_y/c, B_z, 0, -B_x], [E_z/c, -B_y, B_x, 0]] + + and the mixed contravariant-covariant form: + + >>> F(i0, -i1).replace_with_arrays(repl, [i0, -i1]) + [[0, E_x/c, E_y/c, E_z/c], [E_x/c, 0, B_z, -B_y], [E_y/c, -B_z, 0, B_x], [E_z/c, B_y, -B_x, 0]] + + Energy-momentum of a particle may be represented as: + + >>> from sympy import symbols + >>> P = TensorHead('P', [Lorentz], TensorSymmetry.no_symmetry(1)) + >>> E, px, py, pz = symbols('E p_x p_y p_z', positive=True) + >>> repl.update({P(i0): [E, px, py, pz]}) + + The contravariant and covariant components are, respectively: + + >>> P(i0).replace_with_arrays(repl, [i0]) + [E, p_x, p_y, p_z] + >>> P(-i0).replace_with_arrays(repl, [-i0]) + [E, -p_x, -p_y, -p_z] + + The contraction of a 1-index tensor by itself: + + >>> expr = P(i0)*P(-i0) + >>> expr.replace_with_arrays(repl, []) + E**2 - p_x**2 - p_y**2 - p_z**2 + """ + is_commutative = False + + def __new__(cls, name, index_types, symmetry=None, comm=0): + if isinstance(name, str): + name_symbol = Symbol(name) + elif isinstance(name, Symbol): + name_symbol = name + else: + raise ValueError("invalid name") + + if symmetry is None: + symmetry = TensorSymmetry.no_symmetry(len(index_types)) + else: + assert symmetry.rank == len(index_types) + + obj = Basic.__new__(cls, name_symbol, Tuple(*index_types), symmetry, sympify(comm)) + return obj + + @property + def name(self): + return self.args[0].name + + @property + def index_types(self): + return list(self.args[1]) + + @property + def symmetry(self): + return self.args[2] + + @property + def comm(self): + return TensorManager.comm_symbols2i(self.args[3]) + + @property + def rank(self): + return len(self.index_types) + + def __lt__(self, other): + return (self.name, self.index_types) < (other.name, other.index_types) + + def commutes_with(self, other): + """ + Returns ``0`` if ``self`` and ``other`` commute, ``1`` if they anticommute. + + Returns ``None`` if ``self`` and ``other`` neither commute nor anticommute. + """ + r = TensorManager.get_comm(self.comm, other.comm) + return r + + def _print(self): + return '%s(%s)' %(self.name, ','.join([str(x) for x in self.index_types])) + + def __call__(self, *indices, **kw_args): + """ + Returns a tensor with indices. + + Explanation + =========== + + There is a special behavior in case of indices denoted by ``True``, + they are considered auto-matrix indices, their slots are automatically + filled, and confer to the tensor the behavior of a matrix or vector + upon multiplication with another tensor containing auto-matrix indices + of the same ``TensorIndexType``. This means indices get summed over the + same way as in matrix multiplication. For matrix behavior, define two + auto-matrix indices, for vector behavior define just one. + + Indices can also be strings, in which case the attribute + ``index_types`` is used to convert them to proper ``TensorIndex``. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorSymmetry, TensorHead + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> a, b = tensor_indices('a,b', Lorentz) + >>> A = TensorHead('A', [Lorentz]*2, TensorSymmetry.no_symmetry(2)) + >>> t = A(a, -b) + >>> t + A(a, -b) + + """ + + updated_indices = [] + for idx, typ in zip(indices, self.index_types): + if isinstance(idx, str): + idx = idx.strip().replace(" ", "") + if idx.startswith('-'): + updated_indices.append(TensorIndex(idx[1:], typ, + is_up=False)) + else: + updated_indices.append(TensorIndex(idx, typ)) + else: + updated_indices.append(idx) + + updated_indices += indices[len(updated_indices):] + + tensor = Tensor(self, updated_indices, **kw_args) + return tensor.doit() + + # Everything below this line is deprecated + + def __pow__(self, other): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + if self.data is None: + raise ValueError("No power on abstract tensors.") + from .array import tensorproduct, tensorcontraction + metrics = [_.data for _ in self.index_types] + + marray = self.data + marraydim = marray.rank() + for metric in metrics: + marray = tensorproduct(marray, metric, marray) + marray = tensorcontraction(marray, (0, marraydim), (marraydim+1, marraydim+2)) + + return marray ** (other * S.Half) + + @property + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return _tensor_data_substitution_dict[self] + + @data.setter + def data(self, data): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + _tensor_data_substitution_dict[self] = data + + @data.deleter + def data(self): + deprecate_data() + if self in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self] + + def __iter__(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return self.data.__iter__() + + def _components_data_full_destroy(self): + """ + EXPERIMENTAL: do not rely on this API method. + + Destroy components data associated to the ``TensorHead`` object, this + checks for attached components data, and destroys components data too. + """ + # do not garbage collect Kronecker tensor (it should be done by + # ``TensorIndexType`` garbage collection) + deprecate_data() + if self.name == "KD": + return + + # the data attached to a tensor must be deleted only by the TensorHead + # destructor. If the TensorHead is deleted, it means that there are no + # more instances of that tensor anywhere. + if self in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self] + + +def tensor_heads(s, index_types, symmetry=None, comm=0): + """ + Returns a sequence of TensorHeads from a string `s` + """ + if isinstance(s, str): + names = [x.name for x in symbols(s, seq=True)] + else: + raise ValueError('expecting a string') + + thlist = [TensorHead(name, index_types, symmetry, comm) for name in names] + if len(thlist) == 1: + return thlist[0] + return thlist + + +class TensExpr(Expr, ABC): + """ + Abstract base class for tensor expressions + + Notes + ===== + + A tensor expression is an expression formed by tensors; + currently the sums of tensors are distributed. + + A ``TensExpr`` can be a ``TensAdd`` or a ``TensMul``. + + ``TensMul`` objects are formed by products of component tensors, + and include a coefficient, which is a SymPy expression. + + + In the internal representation contracted indices are represented + by ``(ipos1, ipos2, icomp1, icomp2)``, where ``icomp1`` is the position + of the component tensor with contravariant index, ``ipos1`` is the + slot which the index occupies in that component tensor. + + Contracted indices are therefore nameless in the internal representation. + """ + + _op_priority = 12.0 + is_commutative = False + + def __neg__(self): + return self*S.NegativeOne + + def __abs__(self): + raise NotImplementedError + + def __add__(self, other): + return TensAdd(self, other).doit(deep=False) + + def __radd__(self, other): + return TensAdd(other, self).doit(deep=False) + + def __sub__(self, other): + return TensAdd(self, -other).doit(deep=False) + + def __rsub__(self, other): + return TensAdd(other, -self).doit(deep=False) + + def __mul__(self, other): + """ + Multiply two tensors using Einstein summation convention. + + Explanation + =========== + + If the two tensors have an index in common, one contravariant + and the other covariant, in their product the indices are summed + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> m0, m1, m2 = tensor_indices('m0,m1,m2', Lorentz) + >>> g = Lorentz.metric + >>> p, q = tensor_heads('p,q', [Lorentz]) + >>> t1 = p(m0) + >>> t2 = q(-m0) + >>> t1*t2 + p(L_0)*q(-L_0) + """ + return TensMul(self, other).doit(deep=False) + + def __rmul__(self, other): + return TensMul(other, self).doit(deep=False) + + def __truediv__(self, other): + other = _sympify(other) + if isinstance(other, TensExpr): + raise ValueError('cannot divide by a tensor') + return TensMul(self, S.One/other).doit(deep=False) + + def __rtruediv__(self, other): + raise ValueError('cannot divide by a tensor') + + def __pow__(self, other): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + if self.data is None: + raise ValueError("No power without ndarray data.") + from .array import tensorproduct, tensorcontraction + free = self.free + marray = self.data + mdim = marray.rank() + for metric in free: + marray = tensorcontraction( + tensorproduct( + marray, + metric[0].tensor_index_type.data, + marray), + (0, mdim), (mdim+1, mdim+2) + ) + return marray ** (other * S.Half) + + def __rpow__(self, other): + raise NotImplementedError + + @property + @abstractmethod + def nocoeff(self): + raise NotImplementedError("abstract method") + + @property + @abstractmethod + def coeff(self): + raise NotImplementedError("abstract method") + + @abstractmethod + def get_indices(self): + raise NotImplementedError("abstract method") + + @abstractmethod + def get_free_indices(self) -> list[TensorIndex]: + raise NotImplementedError("abstract method") + + @abstractmethod + def _replace_indices(self, repl: dict[TensorIndex, TensorIndex]) -> TensExpr: + raise NotImplementedError("abstract method") + + def fun_eval(self, *index_tuples): + deprecate_fun_eval() + return self.substitute_indices(*index_tuples) + + def get_matrix(self): + """ + DEPRECATED: do not use. + + Returns ndarray components data as a matrix, if components data are + available and ndarray dimension does not exceed 2. + """ + from sympy.matrices.dense import Matrix + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + if 0 < self.rank <= 2: + rows = self.data.shape[0] + columns = self.data.shape[1] if self.rank == 2 else 1 + if self.rank == 2: + mat_list = [] * rows + for i in range(rows): + mat_list.append([]) + for j in range(columns): + mat_list[i].append(self[i, j]) + else: + mat_list = [None] * rows + for i in range(rows): + mat_list[i] = self[i] + return Matrix(mat_list) + else: + raise NotImplementedError( + "missing multidimensional reduction to matrix.") + + @staticmethod + def _get_indices_permutation(indices1, indices2): + return [indices1.index(i) for i in indices2] + + def _get_free_indices_set(self): + indset = set() + for arg in self.args: + if isinstance(arg, TensExpr): + indset.update(arg._get_free_indices_set()) + return indset + + def _get_dummy_indices_set(self): + indset = set() + for arg in self.args: + if isinstance(arg, TensExpr): + indset.update(arg._get_dummy_indices_set()) + return indset + + def _get_indices_set(self): + indset = set() + for arg in self.args: + if isinstance(arg, TensExpr): + indset.update(arg._get_indices_set()) + return indset + + @property + def _iterate_dummy_indices(self): + dummy_set = self._get_dummy_indices_set() + + def recursor(expr, pos): + if isinstance(expr, TensorIndex): + if expr in dummy_set: + yield (expr, pos) + elif isinstance(expr, (Tuple, TensExpr)): + for p, arg in enumerate(expr.args): + yield from recursor(arg, pos+(p,)) + + return recursor(self, ()) + + @property + def _iterate_free_indices(self): + free_set = self._get_free_indices_set() + + def recursor(expr, pos): + if isinstance(expr, TensorIndex): + if expr in free_set: + yield (expr, pos) + elif isinstance(expr, (Tuple, TensExpr)): + for p, arg in enumerate(expr.args): + yield from recursor(arg, pos+(p,)) + + return recursor(self, ()) + + @property + def _iterate_indices(self): + def recursor(expr, pos): + if isinstance(expr, TensorIndex): + yield (expr, pos) + elif isinstance(expr, (Tuple, TensExpr)): + for p, arg in enumerate(expr.args): + yield from recursor(arg, pos+(p,)) + + return recursor(self, ()) + + @staticmethod + def _contract_and_permute_with_metric(metric, array, pos, dim): + # TODO: add possibility of metric after (spinors) + from .array import tensorcontraction, tensorproduct, permutedims + + array = tensorcontraction(tensorproduct(metric, array), (1, 2+pos)) + permu = list(range(dim)) + permu[0], permu[pos] = permu[pos], permu[0] + return permutedims(array, permu) + + @staticmethod + def _match_indices_with_other_tensor(array, free_ind1, free_ind2, replacement_dict): + from .array import permutedims + + index_types1 = [i.tensor_index_type for i in free_ind1] + + # Check if variance of indices needs to be fixed: + pos2up = [] + pos2down = [] + free2remaining = free_ind2[:] + for pos1, index1 in enumerate(free_ind1): + if index1 in free2remaining: + pos2 = free2remaining.index(index1) + free2remaining[pos2] = None + continue + if -index1 in free2remaining: + pos2 = free2remaining.index(-index1) + free2remaining[pos2] = None + free_ind2[pos2] = index1 + if index1.is_up: + pos2up.append(pos2) + else: + pos2down.append(pos2) + else: + index2 = free2remaining[pos1] + if index2 is None: + raise ValueError("incompatible indices: %s and %s" % (free_ind1, free_ind2)) + free2remaining[pos1] = None + free_ind2[pos1] = index1 + if index1.is_up ^ index2.is_up: + if index1.is_up: + pos2up.append(pos1) + else: + pos2down.append(pos1) + + if len(set(free_ind1) & set(free_ind2)) < len(free_ind1): + raise ValueError("incompatible indices: %s and %s" % (free_ind1, free_ind2)) + + # Raise indices: + for pos in pos2up: + index_type_pos = index_types1[pos] + if index_type_pos not in replacement_dict: + raise ValueError("No metric provided to lower index") + metric = replacement_dict[index_type_pos] + metric_inverse = _TensorDataLazyEvaluator.inverse_matrix(metric) + array = TensExpr._contract_and_permute_with_metric(metric_inverse, array, pos, len(free_ind1)) + # Lower indices: + for pos in pos2down: + index_type_pos = index_types1[pos] + if index_type_pos not in replacement_dict: + raise ValueError("No metric provided to lower index") + metric = replacement_dict[index_type_pos] + array = TensExpr._contract_and_permute_with_metric(metric, array, pos, len(free_ind1)) + + if free_ind1: + permutation = TensExpr._get_indices_permutation(free_ind2, free_ind1) + array = permutedims(array, permutation) + + if hasattr(array, "rank") and array.rank() == 0: + array = array[()] + + return free_ind2, array + + def replace_with_arrays(self, replacement_dict, indices=None): + """ + Replace the tensorial expressions with arrays. The final array will + correspond to the N-dimensional array with indices arranged according + to ``indices``. + + Parameters + ========== + + replacement_dict + dictionary containing the replacement rules for tensors. + indices + the index order with respect to which the array is read. The + original index order will be used if no value is passed. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices + >>> from sympy.tensor.tensor import TensorHead + >>> from sympy import symbols, diag + + >>> L = TensorIndexType("L") + >>> i, j = tensor_indices("i j", L) + >>> A = TensorHead("A", [L]) + >>> A(i).replace_with_arrays({A(i): [1, 2]}, [i]) + [1, 2] + + Since 'indices' is optional, we can also call replace_with_arrays by + this way if no specific index order is needed: + + >>> A(i).replace_with_arrays({A(i): [1, 2]}) + [1, 2] + + >>> expr = A(i)*A(j) + >>> expr.replace_with_arrays({A(i): [1, 2]}) + [[1, 2], [2, 4]] + + For contractions, specify the metric of the ``TensorIndexType``, which + in this case is ``L``, in its covariant form: + + >>> expr = A(i)*A(-i) + >>> expr.replace_with_arrays({A(i): [1, 2], L: diag(1, -1)}) + -3 + + Symmetrization of an array: + + >>> H = TensorHead("H", [L, L]) + >>> a, b, c, d = symbols("a b c d") + >>> expr = H(i, j)/2 + H(j, i)/2 + >>> expr.replace_with_arrays({H(i, j): [[a, b], [c, d]]}) + [[a, b/2 + c/2], [b/2 + c/2, d]] + + Anti-symmetrization of an array: + + >>> expr = H(i, j)/2 - H(j, i)/2 + >>> repl = {H(i, j): [[a, b], [c, d]]} + >>> expr.replace_with_arrays(repl) + [[0, b/2 - c/2], [-b/2 + c/2, 0]] + + The same expression can be read as the transpose by inverting ``i`` and + ``j``: + + >>> expr.replace_with_arrays(repl, [j, i]) + [[0, -b/2 + c/2], [b/2 - c/2, 0]] + """ + from .array import Array + + indices = indices or [] + remap = {k.args[0] if k.is_up else -k.args[0]: k for k in self.get_free_indices()} + for i, index in enumerate(indices): + if isinstance(index, (Symbol, Mul)): + if index in remap: + indices[i] = remap[index] + else: + indices[i] = -remap[-index] + + replacement_dict = {tensor: Array(array) for tensor, array in replacement_dict.items()} + + # Check dimensions of replaced arrays: + for tensor, array in replacement_dict.items(): + if isinstance(tensor, TensorIndexType): + expected_shape = [tensor.dim for i in range(2)] + else: + expected_shape = [index_type.dim for index_type in tensor.index_types] + if len(expected_shape) != array.rank() or (not all(dim1 == dim2 if + dim1.is_number else True for dim1, dim2 in zip(expected_shape, + array.shape))): + raise ValueError("shapes for tensor %s expected to be %s, "\ + "replacement array shape is %s" % (tensor, expected_shape, + array.shape)) + + ret_indices, array = self._extract_data(replacement_dict) + + last_indices, array = self._match_indices_with_other_tensor(array, indices, ret_indices, replacement_dict) + return array + + def _check_add_Sum(self, expr, index_symbols): + from sympy.concrete.summations import Sum + indices = self.get_indices() + dum = self.dum + sum_indices = [ (index_symbols[i], 0, + indices[i].tensor_index_type.dim-1) for i, j in dum] + if sum_indices: + expr = Sum(expr, *sum_indices) + return expr + + def _expand_partial_derivative(self): + # simply delegate the _expand_partial_derivative() to + # its arguments to expand a possibly found PartialDerivative + return self.func(*[ + a._expand_partial_derivative() + if isinstance(a, TensExpr) else a + for a in self.args]) + + def _matches_simple(self, expr, repl_dict=None, old=False): + """ + Matches assuming there are no wild objects in self. + """ + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + if not isinstance(expr, TensExpr): + if len(self.get_free_indices()) > 0: + #self has indices, but expr does not. + return None + elif set(self.get_free_indices()) != set(expr.get_free_indices()): + #If there are no wilds and the free indices are not the same, they cannot match. + return None + + if canon_bp(self - expr) == S.Zero: + return repl_dict + else: + return None + + +class TensAdd(TensExpr, AssocOp): + """ + Sum of tensors. + + Parameters + ========== + + free_args : list of the free indices + + Attributes + ========== + + ``args`` : tuple of addends + ``rank`` : rank of the tensor + ``free_args`` : list of the free indices in sorted order + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_heads, tensor_indices + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> a, b = tensor_indices('a,b', Lorentz) + >>> p, q = tensor_heads('p,q', [Lorentz]) + >>> t = p(a) + q(a); t + p(a) + q(a) + + Examples with components data added to the tensor expression: + + >>> from sympy import symbols, diag + >>> x, y, z, t = symbols("x y z t") + >>> repl = {} + >>> repl[Lorentz] = diag(1, -1, -1, -1) + >>> repl[p(a)] = [1, 2, 3, 4] + >>> repl[q(a)] = [x, y, z, t] + + The following are: 2**2 - 3**2 - 2**2 - 7**2 ==> -58 + + >>> expr = p(a) + q(a) + >>> expr.replace_with_arrays(repl, [a]) + [x + 1, y + 2, z + 3, t + 4] + """ + + def __new__(cls, *args, **kw_args): + args = [_sympify(x) for x in args if x] + args = TensAdd._tensAdd_flatten(args) + args.sort(key=default_sort_key) + if not args: + return S.Zero + if len(args) == 1: + return args[0] + + return Basic.__new__(cls, *args, **kw_args) + + @property + def coeff(self): + return S.One + + @property + def nocoeff(self): + return self + + def get_free_indices(self) -> list[TensorIndex]: + return self.free_indices + + def _replace_indices(self, repl: dict[TensorIndex, TensorIndex]) -> TensExpr: + newargs = [arg._replace_indices(repl) if isinstance(arg, TensExpr) else arg for arg in self.args] + return self.func(*newargs) + + @memoize_property + def rank(self): + if isinstance(self.args[0], TensExpr): + return self.args[0].rank + else: + return 0 + + @memoize_property + def free_args(self): + if isinstance(self.args[0], TensExpr): + return self.args[0].free_args + else: + return [] + + @memoize_property + def free_indices(self): + if isinstance(self.args[0], TensExpr): + return self.args[0].get_free_indices() + else: + return set() + + def doit(self, **hints) -> Expr: + deep = hints.get('deep', True) + if deep: + args = [arg.doit(**hints) for arg in self.args] + else: + args = self.args # type: ignore + + # if any of the args are zero (after doit), drop them. Otherwise, _tensAdd_check will complain about non-matching indices, even though the TensAdd is correctly formed. + args = [arg for arg in args if arg != S.Zero] + + if len(args) == 0: + return S.Zero + elif len(args) == 1: + return args[0] + + # now check that all addends have the same indices: + TensAdd._tensAdd_check(args) + + # Collect terms appearing more than once, differing by their coefficients: + args = TensAdd._tensAdd_collect_terms(args) + + # collect canonicalized terms + def sort_key(t): + if not isinstance(t, TensExpr): + return [], [], [] + if hasattr(t, "_index_structure") and hasattr(t, "components"): + x = get_index_structure(t) + return t.components, x.free, x.dum + return [], [], [] + args.sort(key=sort_key) + + if not args: + return S.Zero + # it there is only a component tensor return it + if len(args) == 1: + return args[0] + + obj = self.func(*args) + return obj + + @staticmethod + def _tensAdd_flatten(args): + # flatten TensAdd, coerce terms which are not tensors to tensors + a = [] + for x in args: + if isinstance(x, (Add, TensAdd)): + a.extend(list(x.args)) + else: + a.append(x) + args = [x for x in a if x.coeff] + return args + + @staticmethod + def _tensAdd_check(args): + # check that all addends have the same free indices + + def get_indices_set(x: Expr) -> set[TensorIndex]: + if isinstance(x, TensExpr): + return set(x.get_free_indices()) + return set() + + indices0 = get_indices_set(args[0]) + list_indices = [get_indices_set(arg) for arg in args[1:]] + if not all(x == indices0 for x in list_indices): + raise ValueError('all tensors must have the same indices') + + @staticmethod + def _tensAdd_collect_terms(args): + # collect TensMul terms differing at most by their coefficient + terms_dict = defaultdict(list) + scalars = S.Zero + if isinstance(args[0], TensExpr): + free_indices = set(args[0].get_free_indices()) + else: + free_indices = set() + + for arg in args: + if not isinstance(arg, TensExpr): + if free_indices != set(): + raise ValueError("wrong valence") + scalars += arg + continue + if free_indices != set(arg.get_free_indices()): + raise ValueError("wrong valence") + # TODO: what is the part which is not a coeff? + # needs an implementation similar to .as_coeff_Mul() + terms_dict[arg.nocoeff].append(arg.coeff) + + new_args = [TensMul(Add(*coeff), t).doit(deep=False) for t, coeff in terms_dict.items() if Add(*coeff) != 0] + if isinstance(scalars, Add): + new_args = list(scalars.args) + new_args + elif scalars != 0: + new_args = [scalars] + new_args + return new_args + + def get_indices(self): + indices = [] + for arg in self.args: + indices.extend([i for i in get_indices(arg) if i not in indices]) + return indices + + + def __call__(self, *indices): + deprecate_call() + free_args = self.free_args + indices = list(indices) + if [x.tensor_index_type for x in indices] != [x.tensor_index_type for x in free_args]: + raise ValueError('incompatible types') + if indices == free_args: + return self + index_tuples = list(zip(free_args, indices)) + a = [x.func(*x.substitute_indices(*index_tuples).args) for x in self.args] + res = TensAdd(*a).doit(deep=False) + return res + + def canon_bp(self): + """ + Canonicalize using the Butler-Portugal algorithm for canonicalization + under monoterm symmetries. + """ + expr = self.expand() + if isinstance(expr, self.func): + args = [canon_bp(x) for x in expr.args] + res = TensAdd(*args).doit(deep=False) + return res + else: + return canon_bp(expr) + + def equals(self, other): + other = _sympify(other) + if isinstance(other, TensMul) and other.coeff == 0: + return all(x.coeff == 0 for x in self.args) + if isinstance(other, TensExpr): + if self.rank != other.rank: + return False + if isinstance(other, TensAdd): + if set(self.args) != set(other.args): + return False + else: + return True + t = self - other + if not isinstance(t, TensExpr): + return t == 0 + else: + if isinstance(t, TensMul): + return t.coeff == 0 + else: + return all(x.coeff == 0 for x in t.args) + + def __getitem__(self, item): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return self.data[item] + + def contract_delta(self, delta): + args = [x.contract_delta(delta) if isinstance(x, TensExpr) else x for x in self.args] + t = TensAdd(*args).doit(deep=False) + return canon_bp(t) + + def contract_metric(self, g): + """ + Raise or lower indices with the metric ``g``. + + Parameters + ========== + + g : metric + + contract_all : if True, eliminate all ``g`` which are contracted + + Notes + ===== + + see the ``TensorIndexType`` docstring for the contraction conventions + """ + + args = [contract_metric(x, g) for x in self.args] + t = TensAdd(*args).doit(deep=False) + return canon_bp(t) + + def substitute_indices(self, *index_tuples): + new_args = [] + for arg in self.args: + if isinstance(arg, TensExpr): + arg = arg.substitute_indices(*index_tuples) + new_args.append(arg) + return TensAdd(*new_args).doit(deep=False) + + def _print(self): + a = [] + args = self.args + for x in args: + a.append(str(x)) + s = ' + '.join(a) + s = s.replace('+ -', '- ') + return s + + def _extract_data(self, replacement_dict): + from sympy.tensor.array import Array, permutedims + args_indices, arrays = zip(*[ + arg._extract_data(replacement_dict) if + isinstance(arg, TensExpr) else ([], arg) for arg in self.args + ]) + arrays = [Array(i) for i in arrays] + ref_indices = args_indices[0] + for i in range(1, len(args_indices)): + indices = args_indices[i] + array = arrays[i] + permutation = TensMul._get_indices_permutation(indices, ref_indices) + arrays[i] = permutedims(array, permutation) + return ref_indices, sum(arrays, Array.zeros(*array.shape)) + + @property + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return _tensor_data_substitution_dict[self.expand()] + + @data.setter + def data(self, data): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + _tensor_data_substitution_dict[self] = data + + @data.deleter + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + if self in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self] + + def __iter__(self): + deprecate_data() + if not self.data: + raise ValueError("No iteration on abstract tensors") + return self.data.flatten().__iter__() + + def _eval_rewrite_as_Indexed(self, *args, **kwargs): + return Add.fromiter(args) + + def _eval_partial_derivative(self, s): + # Evaluation like Add + list_addends = [] + for a in self.args: + if isinstance(a, TensExpr): + list_addends.append(a._eval_partial_derivative(s)) + # do not call diff if s is no symbol + elif s._diff_wrt: + list_addends.append(a._eval_derivative(s)) + + return self.func(*list_addends).doit(deep=False) + + def matches(self, expr, repl_dict=None, old=False): + expr = sympify(expr) + + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + if not isinstance(expr, TensAdd): + return None + + if len(_get_wilds(self)) == 0: + return self._matches_simple(expr, repl_dict, old) + + def siftkey(arg): + wildatoms = _get_wilds(arg) + wildatom_types = sift(wildatoms, type) + if len(wildatoms) == 0: + return "nonwild" + elif WildTensor in wildatom_types.keys(): + for w in wildatom_types["WildTensor"]: + if len(w.get_indices()) == 0: + return "indexless_wildtensor" + return "wildtensor" + else: + return "otherwild" + + query_sifted = sift(self.args, siftkey) + expr_sifted = sift(expr.args, siftkey) + + #First try to match the terms without WildTensors + matched_e_tensors = [] #Used to make sure that the same tensor in expr is not matched with more than one tensor in self. + for q_tensor in query_sifted["nonwild"]: + matched_this_q = False + for e_tensor in expr_sifted["nonwild"]: + if e_tensor in matched_e_tensors: + continue + + m = q_tensor.matches(e_tensor, repl_dict=repl_dict, old=old) + if m is None: + continue + else: + matched_this_q = True + repl_dict.update(m) + matched_e_tensors.append(e_tensor) + break + + if not matched_this_q: + return None + + remaining_e_tensors = [t for t in expr_sifted["nonwild"] if t not in matched_e_tensors] + for w in query_sifted["otherwild"]: + for e in remaining_e_tensors: + m = w.matches(e) + if m is not None: + matched_e_tensors.append(e) + if w in repl_dict.keys(): + repl_dict[w] += m.pop(w) + repl_dict.update(m) + + remaining_e_tensors = [t for t in expr_sifted["nonwild"] if t not in matched_e_tensors] + for w in query_sifted["wildtensor"]: + for e in remaining_e_tensors: + m = w.matches(e) + if m is not None: + matched_e_tensors.append(e) + if w.component in repl_dict.keys(): + repl_dict[w.component] += m.pop(w.component) + repl_dict.update(m) + + remaining_e_tensors = [t for t in expr_sifted["nonwild"] if t not in matched_e_tensors] + for w in query_sifted["indexless_wildtensor"]: + for e in remaining_e_tensors: + m = w.matches(e) + if m is not None: + matched_e_tensors.append(e) + if w.component in repl_dict.keys(): + repl_dict[w.component] += m.pop(w.component) + repl_dict.update(m) + + remaining_e_tensors = [t for t in expr_sifted["nonwild"] if t not in matched_e_tensors] + if len(remaining_e_tensors) > 0: + return None + else: + return repl_dict + + +class Tensor(TensExpr): + """ + Base tensor class, i.e. this represents a tensor, the single unit to be + put into an expression. + + Explanation + =========== + + This object is usually created from a ``TensorHead``, by attaching indices + to it. Indices preceded by a minus sign are considered contravariant, + otherwise covariant. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorHead + >>> Lorentz = TensorIndexType("Lorentz", dummy_name="L") + >>> mu, nu = tensor_indices('mu nu', Lorentz) + >>> A = TensorHead("A", [Lorentz, Lorentz]) + >>> A(mu, -nu) + A(mu, -nu) + >>> A(mu, -mu) + A(L_0, -L_0) + + It is also possible to use symbols instead of inidices (appropriate indices + are then generated automatically). + + >>> from sympy import Symbol + >>> x = Symbol('x') + >>> A(x, mu) + A(x, mu) + >>> A(x, -x) + A(L_0, -L_0) + + """ + + is_commutative = False + + _index_structure: _IndexStructure + args: tuple[TensorHead, Tuple] + + def __new__(cls, tensor_head, indices, *, is_canon_bp=False, **kw_args): + indices = cls._parse_indices(tensor_head, indices) + obj = Basic.__new__(cls, tensor_head, Tuple(*indices), **kw_args) + obj._index_structure = _IndexStructure.from_indices(*indices) + obj._free = obj._index_structure.free[:] + obj._dum = obj._index_structure.dum[:] + obj._ext_rank = obj._index_structure._ext_rank + obj._coeff = S.One + obj._nocoeff = obj + obj._component = tensor_head + obj._components = [tensor_head] + if tensor_head.rank != len(indices): + raise ValueError("wrong number of indices") + obj.is_canon_bp = is_canon_bp + obj._index_map = Tensor._build_index_map(indices, obj._index_structure) + return obj + + @property + def free(self): + return self._free + + @property + def dum(self): + return self._dum + + @property + def ext_rank(self): + return self._ext_rank + + @property + def coeff(self): + return self._coeff + + @property + def nocoeff(self): + return self._nocoeff + + @property + def component(self): + return self._component + + @property + def components(self): + return self._components + + @property + def head(self): + return self.args[0] + + @property + def indices(self): + return self.args[1] + + @property + def free_indices(self): + return set(self._index_structure.get_free_indices()) + + @property + def index_types(self): + return self.head.index_types + + @property + def rank(self): + return len(self.free_indices) + + @staticmethod + def _build_index_map(indices, index_structure): + index_map = {} + for idx in indices: + index_map[idx] = (indices.index(idx),) + return index_map + + def doit(self, **hints): + args, indices, free, dum = TensMul._tensMul_contract_indices([self]) + return args[0] + + @staticmethod + def _parse_indices(tensor_head, indices): + if not isinstance(indices, (tuple, list, Tuple)): + raise TypeError("indices should be an array, got %s" % type(indices)) + indices = list(indices) + for i, index in enumerate(indices): + if isinstance(index, Symbol): + indices[i] = TensorIndex(index, tensor_head.index_types[i], True) + elif isinstance(index, Mul): + c, e = index.as_coeff_Mul() + if c == -1 and isinstance(e, Symbol): + indices[i] = TensorIndex(e, tensor_head.index_types[i], False) + else: + raise ValueError("index not understood: %s" % index) + elif not isinstance(index, TensorIndex): + raise TypeError("wrong type for index: %s is %s" % (index, type(index))) + return indices + + def _set_new_index_structure(self, im, is_canon_bp=False): + indices = im.get_indices() + return self._set_indices(*indices, is_canon_bp=is_canon_bp) + + def _set_indices(self, *indices, is_canon_bp=False, **kw_args): + if len(indices) != self.ext_rank: + raise ValueError("indices length mismatch") + return self.func(self.args[0], indices, is_canon_bp=is_canon_bp).doit() + + def _get_free_indices_set(self): + return {i[0] for i in self._index_structure.free} + + def _get_dummy_indices_set(self): + dummy_pos = set(itertools.chain(*self._index_structure.dum)) + return {idx for i, idx in enumerate(self.args[1]) if i in dummy_pos} + + def _get_indices_set(self): + return set(self.args[1].args) + + @property + def free_in_args(self): + return [(ind, pos, 0) for ind, pos in self.free] + + @property + def dum_in_args(self): + return [(p1, p2, 0, 0) for p1, p2 in self.dum] + + @property + def free_args(self): + return sorted([x[0] for x in self.free]) + + def commutes_with(self, other): + """ + :param other: + :return: + 0 commute + 1 anticommute + None neither commute nor anticommute + """ + if not isinstance(other, TensExpr): + return 0 + elif isinstance(other, Tensor): + return self.component.commutes_with(other.component) + return NotImplementedError + + def perm2tensor(self, g, is_canon_bp=False): + """ + Returns the tensor corresponding to the permutation ``g``. + + For further details, see the method in ``TIDS`` with the same name. + """ + return perm2tensor(self, g, is_canon_bp) + + def canon_bp(self): + if self.is_canon_bp: + return self + expr = self.expand() + g, dummies, msym = expr._index_structure.indices_canon_args() + v = components_canon_args([expr.component]) + can = canonicalize(g, dummies, msym, *v) + if can == 0: + return S.Zero + tensor = self.perm2tensor(can, True) + return tensor + + def split(self): + return [self] + + def sorted_components(self): + return self + + def get_indices(self) -> list[TensorIndex]: + """ + Get a list of indices, corresponding to those of the tensor. + """ + return list(self.args[1]) + + def get_free_indices(self) -> list[TensorIndex]: + """ + Get a list of free indices, corresponding to those of the tensor. + """ + return self._index_structure.get_free_indices() + + def _replace_indices(self, repl: dict[TensorIndex, TensorIndex]) -> TensExpr: + # TODO: this could be optimized by only swapping the indices + # instead of visiting the whole expression tree: + return self.xreplace(repl) + + def as_base_exp(self): + return self, S.One + + def substitute_indices(self, *index_tuples): + """ + Return a tensor with free indices substituted according to ``index_tuples``. + + ``index_types`` list of tuples ``(old_index, new_index)``. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads, TensorSymmetry + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> i, j, k, l = tensor_indices('i,j,k,l', Lorentz) + >>> A, B = tensor_heads('A,B', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + >>> t = A(i, k)*B(-k, -j); t + A(i, L_0)*B(-L_0, -j) + >>> t.substitute_indices((i, k),(-j, l)) + A(k, L_0)*B(-L_0, l) + """ + indices = [] + for index in self.indices: + for ind_old, ind_new in index_tuples: + if (index.name == ind_old.name and index.tensor_index_type == + ind_old.tensor_index_type): + if index.is_up == ind_old.is_up: + indices.append(ind_new) + else: + indices.append(-ind_new) + break + else: + indices.append(index) + return self.head(*indices) + + def _get_symmetrized_forms(self): + """ + Return a list giving all possible permutations of self that are allowed by its symmetries. + """ + comp = self.component + gens = comp.symmetry.generators + rank = comp.rank + + old_perms = None + new_perms = {self} + while new_perms != old_perms: + old_perms = new_perms.copy() + for tens in old_perms: + for gen in gens: + inds = tens.get_indices() + per = [gen.apply(i) for i in range(0,rank)] + sign = (-1)**(gen.apply(rank) - rank) + ind_map = dict(zip(inds, [inds[i] for i in per])) + new_perms.add( sign * tens._replace_indices(ind_map) ) + + return new_perms + + def matches(self, expr, repl_dict=None, old=False): + expr = sympify(expr) + + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + #simple checks + if self == expr: + return repl_dict + if not isinstance(expr, Tensor): + return None + if self.head != expr.head: + return None + + #Now consider all index symmetries of expr, and see if any of them allow a match. + for new_expr in expr._get_symmetrized_forms(): + m = self._matches(new_expr, repl_dict, old=old) + if m is not None: + repl_dict.update(m) + return repl_dict + + return None + + def _matches(self, expr, repl_dict=None, old=False): + """ + This does not account for index symmetries of expr + """ + expr = sympify(expr) + + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + #simple checks + if self == expr: + return repl_dict + if not isinstance(expr, Tensor): + return None + if self.head != expr.head: + return None + + s_indices = self.get_indices() + e_indices = expr.get_indices() + + if len(s_indices) != len(e_indices): + return None + + for i in range(len(s_indices)): + s_ind = s_indices[i] + m = s_ind.matches(e_indices[i]) + if m is None: + return None + elif -s_ind in repl_dict.keys() and -repl_dict[-s_ind] != m[s_ind]: + return None + else: + repl_dict.update(m) + + return repl_dict + + def __call__(self, *indices): + deprecate_call() + free_args = self.free_args + indices = list(indices) + if [x.tensor_index_type for x in indices] != [x.tensor_index_type for x in free_args]: + raise ValueError('incompatible types') + if indices == free_args: + return self + t = self.substitute_indices(*list(zip(free_args, indices))) + + # object is rebuilt in order to make sure that all contracted indices + # get recognized as dummies, but only if there are contracted indices. + if len({i if i.is_up else -i for i in indices}) != len(indices): + return t.func(*t.args) + return t + + # TODO: put this into TensExpr? + def __iter__(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return self.data.__iter__() + + # TODO: put this into TensExpr? + def __getitem__(self, item): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return self.data[item] + + def _extract_data(self, replacement_dict): + from .array import Array + for k, v in replacement_dict.items(): + if isinstance(k, Tensor) and k.args[0] == self.args[0]: + other = k + array = v + break + else: + raise ValueError("%s not found in %s" % (self, replacement_dict)) + + # TODO: inefficient, this should be done at root level only: + replacement_dict = {k: Array(v) for k, v in replacement_dict.items()} + array = Array(array) + + dum1 = self.dum + dum2 = other.dum + + if len(dum2) > 0: + for pair in dum2: + # allow `dum2` if the contained values are also in `dum1`. + if pair not in dum1: + raise NotImplementedError("%s with contractions is not implemented" % other) + # Remove elements in `dum2` from `dum1`: + dum1 = [pair for pair in dum1 if pair not in dum2] + if len(dum1) > 0: + indices1 = self.get_indices() + indices2 = other.get_indices() + repl = {} + for p1, p2 in dum1: + repl[indices2[p2]] = -indices2[p1] + for pos in (p1, p2): + if indices1[pos].is_up ^ indices2[pos].is_up: + metric = replacement_dict[indices1[pos].tensor_index_type] + if indices1[pos].is_up: + metric = _TensorDataLazyEvaluator.inverse_matrix(metric) + array = self._contract_and_permute_with_metric(metric, array, pos, len(indices2)) + other = other.xreplace(repl).doit() + array = _TensorDataLazyEvaluator.data_contract_dum([array], dum1, len(indices2)) + + free_ind1 = self.get_free_indices() + free_ind2 = other.get_free_indices() + + return self._match_indices_with_other_tensor(array, free_ind1, free_ind2, replacement_dict) + + @property + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return _tensor_data_substitution_dict[self] + + @data.setter + def data(self, data): + deprecate_data() + # TODO: check data compatibility with properties of tensor. + with ignore_warnings(SymPyDeprecationWarning): + _tensor_data_substitution_dict[self] = data + + @data.deleter + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + if self in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self] + if self.metric in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self.metric] + + def _print(self): + indices = [str(ind) for ind in self.indices] + component = self.component + if component.rank > 0: + return ('%s(%s)' % (component.name, ', '.join(indices))) + else: + return ('%s' % component.name) + + def equals(self, other): + if other == 0: + return self.coeff == 0 + other = _sympify(other) + if not isinstance(other, TensExpr): + assert not self.components + return S.One == other + + def _get_compar_comp(self): + t = self.canon_bp() + r = (t.coeff, tuple(t.components), \ + tuple(sorted(t.free)), tuple(sorted(t.dum))) + return r + + return _get_compar_comp(self) == _get_compar_comp(other) + + def contract_metric(self, g): + # if metric is not the same, ignore this step: + if self.component != g: + return self + # in case there are free components, do not perform anything: + if len(self.free) != 0: + return self + + #antisym = g.index_types[0].metric_antisym + if g.symmetry == TensorSymmetry.fully_symmetric(-2): + antisym = 1 + elif g.symmetry == TensorSymmetry.fully_symmetric(2): + antisym = 0 + elif g.symmetry == TensorSymmetry.no_symmetry(2): + antisym = None + else: + raise NotImplementedError + sign = S.One + typ = g.index_types[0] + + if not antisym: + # g(i, -i) + sign = sign*typ.dim + else: + # g(i, -i) + sign = sign*typ.dim + + dp0, dp1 = self.dum[0] + if dp0 < dp1: + # g(i, -i) = -D with antisymmetric metric + sign = -sign + + return sign + + def contract_delta(self, metric): + return self.contract_metric(metric) + + def _eval_rewrite_as_Indexed(self, tens, indices, **kwargs): + from sympy.tensor.indexed import Indexed + # TODO: replace .args[0] with .name: + index_symbols = [i.args[0] for i in self.get_indices()] + expr = Indexed(tens.args[0], *index_symbols) + return self._check_add_Sum(expr, index_symbols) + + def _eval_partial_derivative(self, s: Tensor) -> Expr: + + if not isinstance(s, Tensor): + return S.Zero + else: + + # @a_i/@a_k = delta_i^k + # @a_i/@a^k = g_ij delta^j_k + # @a^i/@a^k = delta^i_k + # @a^i/@a_k = g^ij delta_j^k + # TODO: if there is no metric present, the derivative should be zero? + + if self.head != s.head: + return S.Zero + + # if heads are the same, provide delta and/or metric products + # for every free index pair in the appropriate tensor + # assumed that the free indices are in proper order + # A contravariante index in the derivative becomes covariant + # after performing the derivative and vice versa + + kronecker_delta_list = [1] + + # not guarantee a correct index order + + for (count, (iself, iother)) in enumerate(zip(self.get_free_indices(), s.get_free_indices())): + if iself.tensor_index_type != iother.tensor_index_type: + raise ValueError("index types not compatible") + else: + tensor_index_type = iself.tensor_index_type + tensor_metric = tensor_index_type.metric + dummy = TensorIndex("d_" + str(count), tensor_index_type, + is_up=iself.is_up) + if iself.is_up == iother.is_up: + kroneckerdelta = tensor_index_type.delta(iself, -iother) + else: + kroneckerdelta = ( + TensMul(tensor_metric(iself, dummy), + tensor_index_type.delta(-dummy, -iother)) + ) + kronecker_delta_list.append(kroneckerdelta) + return TensMul.fromiter(kronecker_delta_list).doit(deep=False) + # doit necessary to rename dummy indices accordingly + + +class TensMul(TensExpr, AssocOp): + """ + Product of tensors. + + Parameters + ========== + + coeff : SymPy coefficient of the tensor + args + + Attributes + ========== + + ``components`` : list of ``TensorHead`` of the component tensors + ``types`` : list of nonrepeated ``TensorIndexType`` + ``free`` : list of ``(ind, ipos, icomp)``, see Notes + ``dum`` : list of ``(ipos1, ipos2, icomp1, icomp2)``, see Notes + ``ext_rank`` : rank of the tensor counting the dummy indices + ``rank`` : rank of the tensor + ``coeff`` : SymPy coefficient of the tensor + ``free_args`` : list of the free indices in sorted order + ``is_canon_bp`` : ``True`` if the tensor in in canonical form + + Notes + ===== + + ``args[0]`` list of ``TensorHead`` of the component tensors. + + ``args[1]`` list of ``(ind, ipos, icomp)`` + where ``ind`` is a free index, ``ipos`` is the slot position + of ``ind`` in the ``icomp``-th component tensor. + + ``args[2]`` list of tuples representing dummy indices. + ``(ipos1, ipos2, icomp1, icomp2)`` indicates that the contravariant + dummy index is the ``ipos1``-th slot position in the ``icomp1``-th + component tensor; the corresponding covariant index is + in the ``ipos2`` slot position in the ``icomp2``-th component tensor. + + """ + identity = S.One + + _index_structure: _IndexStructure + + def __new__(cls, *args, **kw_args): + is_canon_bp = kw_args.get('is_canon_bp', False) + args = list(map(_sympify, args)) + + """ + If the internal dummy indices in one arg conflict with the free indices + of the remaining args, we need to rename those internal dummy indices. + """ + free = [get_free_indices(arg) for arg in args] + free = set(itertools.chain(*free)) #flatten free + newargs = [] + for arg in args: + dum_this = set(get_dummy_indices(arg)) + dum_other = [get_dummy_indices(a) for a in newargs] + dum_other = set(itertools.chain(*dum_other)) #flatten dum_other + free_this = set(get_free_indices(arg)) + if len(dum_this.intersection(free)) > 0: + exclude = free_this.union(free, dum_other) + newarg = TensMul._dedupe_indices(arg, exclude) + else: + newarg = arg + newargs.append(newarg) + + args = newargs + + # Flatten: + args = [i for arg in args for i in (arg.args if isinstance(arg, (TensMul, Mul)) else [arg])] + + args, indices, free, dum = TensMul._tensMul_contract_indices(args, replace_indices=False) + + # Data for indices: + index_types = [i.tensor_index_type for i in indices] + index_structure = _IndexStructure(free, dum, index_types, indices, canon_bp=is_canon_bp) + + obj = TensExpr.__new__(cls, *args) + obj._indices = indices + obj._index_types = index_types.copy() + obj._index_structure = index_structure + obj._free = index_structure.free[:] + obj._dum = index_structure.dum[:] + obj._free_indices = {x[0] for x in obj.free} + obj._rank = len(obj.free) + obj._ext_rank = len(obj._index_structure.free) + 2*len(obj._index_structure.dum) + obj._coeff = S.One + obj._is_canon_bp = is_canon_bp + return obj + + index_types = property(lambda self: self._index_types) + free = property(lambda self: self._free) + dum = property(lambda self: self._dum) + free_indices = property(lambda self: self._free_indices) + rank = property(lambda self: self._rank) + ext_rank = property(lambda self: self._ext_rank) + + @staticmethod + def _indices_to_free_dum(args_indices): + free2pos1 = {} + free2pos2 = {} + dummy_data = [] + indices = [] + + # Notation for positions (to better understand the code): + # `pos1`: position in the `args`. + # `pos2`: position in the indices. + + # Example: + # A(i, j)*B(k, m, n)*C(p) + # `pos1` of `n` is 1 because it's in `B` (second `args` of TensMul). + # `pos2` of `n` is 4 because it's the fifth overall index. + + # Counter for the index position wrt the whole expression: + pos2 = 0 + + for pos1, arg_indices in enumerate(args_indices): + + for index in arg_indices: + if not isinstance(index, TensorIndex): + raise TypeError("expected TensorIndex") + if -index in free2pos1: + # Dummy index detected: + other_pos1 = free2pos1.pop(-index) + other_pos2 = free2pos2.pop(-index) + if index.is_up: + dummy_data.append((index, pos1, other_pos1, pos2, other_pos2)) + else: + dummy_data.append((-index, other_pos1, pos1, other_pos2, pos2)) + indices.append(index) + elif index in free2pos1: + raise ValueError("Repeated index: %s" % index) + else: + free2pos1[index] = pos1 + free2pos2[index] = pos2 + indices.append(index) + pos2 += 1 + + free = list(free2pos2.items()) + free_names = [i.name for i in free2pos2.keys()] + + dummy_data.sort(key=lambda x: x[3]) + return indices, free, free_names, dummy_data + + @staticmethod + def _dummy_data_to_dum(dummy_data): + return [(p2a, p2b) for (i, p1a, p1b, p2a, p2b) in dummy_data] + + @staticmethod + def _tensMul_contract_indices(args, replace_indices=True): + replacements = [{} for _ in args] + + #_index_order = all(_has_index_order(arg) for arg in args) + + args_indices = [get_indices(arg) for arg in args] + indices, free, free_names, dummy_data = TensMul._indices_to_free_dum(args_indices) + + cdt = defaultdict(int) + + def dummy_name_gen(tensor_index_type): + nd = str(cdt[tensor_index_type]) + cdt[tensor_index_type] += 1 + return tensor_index_type.dummy_name + '_' + nd + + if replace_indices: + for old_index, pos1cov, pos1contra, pos2cov, pos2contra in dummy_data: + index_type = old_index.tensor_index_type + while True: + dummy_name = dummy_name_gen(index_type) + if dummy_name not in free_names: + break + dummy = old_index.func(dummy_name, index_type, *old_index.args[2:]) + replacements[pos1cov][old_index] = dummy + replacements[pos1contra][-old_index] = -dummy + indices[pos2cov] = dummy + indices[pos2contra] = -dummy + args = [ + arg._replace_indices(repl) if isinstance(arg, TensExpr) else arg + for arg, repl in zip(args, replacements)] + + """ + The order of indices might've changed due to the replacements (e.g. if one of the args is a TensAdd, replacing an index can change the sort order of the terms, thus changing the order of indices returned by its get_indices() method). + To stay on the safe side, we calculate these quantities again. + """ + args_indices = [get_indices(arg) for arg in args] + indices, free, free_names, dummy_data = TensMul._indices_to_free_dum(args_indices) + + dum = TensMul._dummy_data_to_dum(dummy_data) + return args, indices, free, dum + + @staticmethod + def _get_components_from_args(args): + """ + Get a list of ``Tensor`` objects having the same ``TIDS`` if multiplied + by one another. + """ + components = [] + for arg in args: + if not isinstance(arg, TensExpr): + continue + if isinstance(arg, TensAdd): + continue + components.extend(arg.components) + return components + + @staticmethod + def _rebuild_tensors_list(args, index_structure): + indices = index_structure.get_indices() + #tensors = [None for i in components] # pre-allocate list + ind_pos = 0 + for i, arg in enumerate(args): + if not isinstance(arg, TensExpr): + continue + prev_pos = ind_pos + ind_pos += arg.ext_rank + args[i] = Tensor(arg.component, indices[prev_pos:ind_pos]) + + def doit(self, **hints): + is_canon_bp = self._is_canon_bp + deep = hints.get('deep', True) + if deep: + args = [arg.doit(**hints) for arg in self.args] + + """ + There may now be conflicts between dummy indices of different args + (each arg's doit method does not have any information about which + dummy indices are already used in the other args), so we + deduplicate them. + """ + rule = dict(zip(self.args, args)) + rule = self._dedupe_indices_in_rule(rule) + args = [rule[a] for a in self.args] + + else: + args = self.args + + args = [arg for arg in args if arg != self.identity] + + # Extract non-tensor coefficients: + coeff = reduce(lambda a, b: a*b, [arg for arg in args if not isinstance(arg, TensExpr)], S.One) + args = [arg for arg in args if isinstance(arg, TensExpr)] + + if len(args) == 0: + return coeff + + if coeff != self.identity: + args = [coeff] + args + if coeff == 0: + return S.Zero + + if len(args) == 1: + return args[0] + + args, indices, free, dum = TensMul._tensMul_contract_indices(args) + + # Data for indices: + index_types = [i.tensor_index_type for i in indices] + index_structure = _IndexStructure(free, dum, index_types, indices, canon_bp=is_canon_bp) + + obj = self.func(*args) + obj._index_types = index_types + obj._index_structure = index_structure + obj._ext_rank = len(obj._index_structure.free) + 2*len(obj._index_structure.dum) + obj._coeff = coeff + obj._is_canon_bp = is_canon_bp + return obj + + # TODO: this method should be private + # TODO: should this method be renamed _from_components_free_dum ? + @staticmethod + def from_data(coeff, components, free, dum, **kw_args): + return TensMul(coeff, *TensMul._get_tensors_from_components_free_dum(components, free, dum), **kw_args).doit(deep=False) + + @staticmethod + def _get_tensors_from_components_free_dum(components, free, dum): + """ + Get a list of ``Tensor`` objects by distributing ``free`` and ``dum`` indices on the ``components``. + """ + index_structure = _IndexStructure.from_components_free_dum(components, free, dum) + indices = index_structure.get_indices() + tensors = [None for i in components] # pre-allocate list + + # distribute indices on components to build a list of tensors: + ind_pos = 0 + for i, component in enumerate(components): + prev_pos = ind_pos + ind_pos += component.rank + tensors[i] = Tensor(component, indices[prev_pos:ind_pos]) + return tensors + + def _get_free_indices_set(self): + return {i[0] for i in self.free} + + def _get_dummy_indices_set(self): + dummy_pos = set(itertools.chain(*self.dum)) + return {idx for i, idx in enumerate(self._index_structure.get_indices()) if i in dummy_pos} + + def _get_position_offset_for_indices(self): + arg_offset = [None for i in range(self.ext_rank)] + counter = 0 + for arg in self.args: + if not isinstance(arg, TensExpr): + continue + for j in range(arg.ext_rank): + arg_offset[j + counter] = counter + counter += arg.ext_rank + return arg_offset + + @property + def free_args(self): + return sorted([x[0] for x in self.free]) + + @property + def components(self): + return self._get_components_from_args(self.args) + + @property + def free_in_args(self): + arg_offset = self._get_position_offset_for_indices() + argpos = self._get_indices_to_args_pos() + return [(ind, pos-arg_offset[pos], argpos[pos]) for (ind, pos) in self.free] + + @property + def coeff(self): + # return Mul.fromiter([c for c in self.args if not isinstance(c, TensExpr)]) + return self._coeff + + @property + def nocoeff(self): + return self.func(*self.args, 1/self.coeff).doit(deep=False) + + @property + def dum_in_args(self): + arg_offset = self._get_position_offset_for_indices() + argpos = self._get_indices_to_args_pos() + return [(p1-arg_offset[p1], p2-arg_offset[p2], argpos[p1], argpos[p2]) for p1, p2 in self.dum] + + def equals(self, other): + if other == 0: + return self.coeff == 0 + other = _sympify(other) + if not isinstance(other, TensExpr): + assert not self.components + return self.coeff == other + + return self.canon_bp() == other.canon_bp() + + def get_indices(self): + """ + Returns the list of indices of the tensor. + + Explanation + =========== + + The indices are listed in the order in which they appear in the + component tensors. + The dummy indices are given a name which does not collide with + the names of the free indices. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> m0, m1, m2 = tensor_indices('m0,m1,m2', Lorentz) + >>> g = Lorentz.metric + >>> p, q = tensor_heads('p,q', [Lorentz]) + >>> t = p(m1)*g(m0,m2) + >>> t.get_indices() + [m1, m0, m2] + >>> t2 = p(m1)*g(-m1, m2) + >>> t2.get_indices() + [L_0, -L_0, m2] + """ + return self._indices + + def get_free_indices(self) -> list[TensorIndex]: + """ + Returns the list of free indices of the tensor. + + Explanation + =========== + + The indices are listed in the order in which they appear in the + component tensors. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> m0, m1, m2 = tensor_indices('m0,m1,m2', Lorentz) + >>> g = Lorentz.metric + >>> p, q = tensor_heads('p,q', [Lorentz]) + >>> t = p(m1)*g(m0,m2) + >>> t.get_free_indices() + [m1, m0, m2] + >>> t2 = p(m1)*g(-m1, m2) + >>> t2.get_free_indices() + [m2] + """ + return self._index_structure.get_free_indices() + + def _replace_indices(self, repl: dict[TensorIndex, TensorIndex]) -> TensExpr: + return self.func(*[arg._replace_indices(repl) if isinstance(arg, TensExpr) else arg for arg in self.args]) + + def split(self): + """ + Returns a list of tensors, whose product is ``self``. + + Explanation + =========== + + Dummy indices contracted among different tensor components + become free indices with the same name as the one used to + represent the dummy indices. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads, TensorSymmetry + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> a, b, c, d = tensor_indices('a,b,c,d', Lorentz) + >>> A, B = tensor_heads('A,B', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + >>> t = A(a,b)*B(-b,c) + >>> t + A(a, L_0)*B(-L_0, c) + >>> t.split() + [A(a, L_0), B(-L_0, c)] + """ + if self.args == (): + return [self] + splitp = [] + res = 1 + for arg in self.args: + if isinstance(arg, Tensor): + splitp.append(res*arg) + res = 1 + else: + res *= arg + return splitp + + def _eval_expand_mul(self, **hints): + args1 = [arg.args if isinstance(arg, (Add, TensAdd)) else (arg,) for arg in self.args] + return TensAdd(*[ + TensMul(*i).doit(deep=False) for i in itertools.product(*args1)] + ) + + def __neg__(self): + return TensMul(S.NegativeOne, self, is_canon_bp=self._is_canon_bp).doit(deep=False) + + def __getitem__(self, item): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return self.data[item] + + def _get_args_for_traditional_printer(self): + args = list(self.args) + if self.coeff.could_extract_minus_sign(): + # expressions like "-A(a)" + sign = "-" + if args[0] == S.NegativeOne: + args = args[1:] + else: + args[0] = -args[0] + else: + sign = "" + return sign, args + + def _sort_args_for_sorted_components(self): + """ + Returns the ``args`` sorted according to the components commutation + properties. + + Explanation + =========== + + The sorting is done taking into account the commutation group + of the component tensors. + """ + cv = [arg for arg in self.args if isinstance(arg, TensExpr)] + sign = 1 + n = len(cv) - 1 + for i in range(n): + for j in range(n, i, -1): + c = cv[j-1].commutes_with(cv[j]) + # if `c` is `None`, it does neither commute nor anticommute, skip: + if c not in (0, 1): + continue + typ1 = sorted(set(cv[j-1].component.index_types), key=lambda x: x.name) + typ2 = sorted(set(cv[j].component.index_types), key=lambda x: x.name) + if (typ1, cv[j-1].component.name) > (typ2, cv[j].component.name): + cv[j-1], cv[j] = cv[j], cv[j-1] + # if `c` is 1, the anticommute, so change sign: + if c: + sign = -sign + + coeff = sign * self.coeff + if coeff != 1: + return [coeff] + cv + return cv + + def sorted_components(self): + """ + Returns a tensor product with sorted components. + """ + return TensMul(*self._sort_args_for_sorted_components()).doit(deep=False) + + def perm2tensor(self, g, is_canon_bp=False): + """ + Returns the tensor corresponding to the permutation ``g`` + + For further details, see the method in ``TIDS`` with the same name. + """ + return perm2tensor(self, g, is_canon_bp=is_canon_bp) + + def canon_bp(self): + """ + Canonicalize using the Butler-Portugal algorithm for canonicalization + under monoterm symmetries. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorHead, TensorSymmetry + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> m0, m1, m2 = tensor_indices('m0,m1,m2', Lorentz) + >>> A = TensorHead('A', [Lorentz]*2, TensorSymmetry.fully_symmetric(-2)) + >>> t = A(m0,-m1)*A(m1,-m0) + >>> t.canon_bp() + -A(L_0, L_1)*A(-L_0, -L_1) + >>> t = A(m0,-m1)*A(m1,-m2)*A(m2,-m0) + >>> t.canon_bp() + 0 + """ + if self._is_canon_bp: + return self + expr = self.expand() + if isinstance(expr, TensAdd): + return expr.canon_bp() + if not expr.components: + return expr + expr = expr.doit(deep=False) #make sure self.coeff is populated correctly + t = expr.sorted_components() + g, dummies, msym = t._index_structure.indices_canon_args() + v = components_canon_args(t.components) + can = canonicalize(g, dummies, msym, *v) + if can == 0: + return S.Zero + tmul = t.perm2tensor(can, True) + return tmul + + def contract_delta(self, delta): + t = self.contract_metric(delta) + return t + + def _get_indices_to_args_pos(self): + """ + Get a dict mapping the index position to TensMul's argument number. + """ + pos_map = {} + pos_counter = 0 + for arg_i, arg in enumerate(self.args): + if not isinstance(arg, TensExpr): + continue + assert isinstance(arg, Tensor) + for i in range(arg.ext_rank): + pos_map[pos_counter] = arg_i + pos_counter += 1 + return pos_map + + def contract_metric(self, g): + """ + Raise or lower indices with the metric ``g``. + + Parameters + ========== + + g : metric + + Notes + ===== + + See the ``TensorIndexType`` docstring for the contraction conventions. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> m0, m1, m2 = tensor_indices('m0,m1,m2', Lorentz) + >>> g = Lorentz.metric + >>> p, q = tensor_heads('p,q', [Lorentz]) + >>> t = p(m0)*q(m1)*g(-m0, -m1) + >>> t.canon_bp() + metric(L_0, L_1)*p(-L_0)*q(-L_1) + >>> t.contract_metric(g).canon_bp() + p(L_0)*q(-L_0) + """ + expr = self.expand().doit(deep=False) + if self != expr: + expr = canon_bp(expr) + return contract_metric(expr, g) + pos_map = self._get_indices_to_args_pos() + args = list(self.args) + + #antisym = g.index_types[0].metric_antisym + if g.symmetry == TensorSymmetry.fully_symmetric(-2): + antisym = 1 + elif g.symmetry == TensorSymmetry.fully_symmetric(2): + antisym = 0 + elif g.symmetry == TensorSymmetry.no_symmetry(2): + antisym = None + else: + raise NotImplementedError + + # list of positions of the metric ``g`` inside ``args`` + gpos = [i for i, x in enumerate(self.args) if isinstance(x, Tensor) and x.component == g] + if not gpos: + return self + + # Sign is either 1 or -1, to correct the sign after metric contraction + # (for spinor indices). + sign = 1 + dum = self.dum[:] + free = self.free[:] + elim = set() + for gposx in gpos: + if gposx in elim: + continue + free1 = [x for x in free if pos_map[x[1]] == gposx] + dum1 = [x for x in dum if pos_map[x[0]] == gposx or pos_map[x[1]] == gposx] + if not dum1: + continue + elim.add(gposx) + # subs with the multiplication neutral element, that is, remove it: + args[gposx] = 1 + if len(dum1) == 2: + if not antisym: + dum10, dum11 = dum1 + if pos_map[dum10[1]] == gposx: + # the index with pos p0 contravariant + p0 = dum10[0] + else: + # the index with pos p0 is covariant + p0 = dum10[1] + if pos_map[dum11[1]] == gposx: + # the index with pos p1 is contravariant + p1 = dum11[0] + else: + # the index with pos p1 is covariant + p1 = dum11[1] + + dum.append((p0, p1)) + else: + dum10, dum11 = dum1 + # change the sign to bring the indices of the metric to contravariant + # form; change the sign if dum10 has the metric index in position 0 + if pos_map[dum10[1]] == gposx: + # the index with pos p0 is contravariant + p0 = dum10[0] + if dum10[1] == 1: + sign = -sign + else: + # the index with pos p0 is covariant + p0 = dum10[1] + if dum10[0] == 0: + sign = -sign + if pos_map[dum11[1]] == gposx: + # the index with pos p1 is contravariant + p1 = dum11[0] + sign = -sign + else: + # the index with pos p1 is covariant + p1 = dum11[1] + + dum.append((p0, p1)) + + elif len(dum1) == 1: + if not antisym: + dp0, dp1 = dum1[0] + if pos_map[dp0] == pos_map[dp1]: + # g(i, -i) + typ = g.index_types[0] + sign = sign*typ.dim + + else: + # g(i0, i1)*p(-i1) + if pos_map[dp0] == gposx: + p1 = dp1 + else: + p1 = dp0 + + ind, p = free1[0] + free.append((ind, p1)) + else: + dp0, dp1 = dum1[0] + if pos_map[dp0] == pos_map[dp1]: + # g(i, -i) + typ = g.index_types[0] + sign = sign*typ.dim + + if dp0 < dp1: + # g(i, -i) = -D with antisymmetric metric + sign = -sign + else: + # g(i0, i1)*p(-i1) + if pos_map[dp0] == gposx: + p1 = dp1 + if dp0 == 0: + sign = -sign + else: + p1 = dp0 + ind, p = free1[0] + free.append((ind, p1)) + dum = [x for x in dum if x not in dum1] + free = [x for x in free if x not in free1] + + # shift positions: + shift = 0 + shifts = [0]*len(args) + for i in range(len(args)): + if i in elim: + shift += 2 + continue + shifts[i] = shift + free = [(ind, p - shifts[pos_map[p]]) for (ind, p) in free if pos_map[p] not in elim] + dum = [(p0 - shifts[pos_map[p0]], p1 - shifts[pos_map[p1]]) for p0, p1 in dum if pos_map[p0] not in elim and pos_map[p1] not in elim] + + res = ( sign*TensMul(*args) ).doit(deep=False) + if not isinstance(res, TensExpr): + return res + im = _IndexStructure.from_components_free_dum(res.components, free, dum) + return res._set_new_index_structure(im) + + def _set_new_index_structure(self, im, is_canon_bp=False): + indices = im.get_indices() + return self._set_indices(*indices, is_canon_bp=is_canon_bp) + + def _set_indices(self, *indices, is_canon_bp=False, **kw_args): + if len(indices) != self.ext_rank: + raise ValueError("indices length mismatch") + args = list(self.args) + pos = 0 + for i, arg in enumerate(args): + if not isinstance(arg, TensExpr): + continue + assert isinstance(arg, Tensor) + ext_rank = arg.ext_rank + args[i] = arg._set_indices(*indices[pos:pos+ext_rank]) + pos += ext_rank + return TensMul(*args, is_canon_bp=is_canon_bp).doit(deep=False) + + @staticmethod + def _index_replacement_for_contract_metric(args, free, dum): + for arg in args: + if not isinstance(arg, TensExpr): + continue + assert isinstance(arg, Tensor) + + def substitute_indices(self, *index_tuples): + new_args = [] + for arg in self.args: + if isinstance(arg, TensExpr): + arg = arg.substitute_indices(*index_tuples) + new_args.append(arg) + return TensMul(*new_args).doit(deep=False) + + def __call__(self, *indices): + deprecate_call() + free_args = self.free_args + indices = list(indices) + if [x.tensor_index_type for x in indices] != [x.tensor_index_type for x in free_args]: + raise ValueError('incompatible types') + if indices == free_args: + return self + t = self.substitute_indices(*list(zip(free_args, indices))) + + # object is rebuilt in order to make sure that all contracted indices + # get recognized as dummies, but only if there are contracted indices. + if len({i if i.is_up else -i for i in indices}) != len(indices): + return t.func(*t.args) + return t + + def _extract_data(self, replacement_dict): + args_indices, arrays = zip(*[arg._extract_data(replacement_dict) for arg in self.args if isinstance(arg, TensExpr)]) + coeff = reduce(operator.mul, [a for a in self.args if not isinstance(a, TensExpr)], S.One) + indices, free, free_names, dummy_data = TensMul._indices_to_free_dum(args_indices) + dum = TensMul._dummy_data_to_dum(dummy_data) + ext_rank = self.ext_rank + free.sort(key=lambda x: x[1]) + free_indices = [i[0] for i in free] + return free_indices, coeff*_TensorDataLazyEvaluator.data_contract_dum(arrays, dum, ext_rank) + + @property + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + dat = _tensor_data_substitution_dict[self.expand()] + return dat + + @data.setter + def data(self, data): + deprecate_data() + raise ValueError("Not possible to set component data to a tensor expression") + + @data.deleter + def data(self): + deprecate_data() + raise ValueError("Not possible to delete component data to a tensor expression") + + def __iter__(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + if self.data is None: + raise ValueError("No iteration on abstract tensors") + return self.data.__iter__() + + @staticmethod + def _dedupe_indices(new, exclude): + """ + exclude: set + new: TensExpr + + If ``new`` has any dummy indices that are in ``exclude``, return a version + of new with those indices replaced. If no replacements are needed, + return None + + """ + exclude = set(exclude) + dums_new = set(get_dummy_indices(new)) + free_new = set(get_free_indices(new)) + + conflicts = dums_new.intersection(exclude) + if len(conflicts) == 0: + return None + + """ + ``exclude_for_gen`` is to be passed to ``_IndexStructure._get_generator_for_dummy_indices()``. + Since the latter does not use the index position for anything, we just + set it as ``None`` here. + """ + exclude.update(dums_new) + exclude.update(free_new) + exclude_for_gen = [(i, None) for i in exclude] + gen = _IndexStructure._get_generator_for_dummy_indices(exclude_for_gen) + repl = {} + for d in conflicts: + if -d in repl.keys(): + continue + newname = gen(d.tensor_index_type) + new_d = d.func(newname, *d.args[1:]) + repl[d] = new_d + repl[-d] = -new_d + + if len(repl) == 0: + return None + + new_renamed = new._replace_indices(repl) + return new_renamed + + def _dedupe_indices_in_rule(self, rule): + """ + rule: dict + + This applies TensMul._dedupe_indices on all values of rule. + + """ + index_rules = {k:v for k,v in rule.items() if isinstance(k, TensorIndex)} + other_rules = {k:v for k,v in rule.items() if k not in index_rules.keys()} + exclude = set(self.get_indices()) + + newrule = {} + newrule.update(index_rules) + exclude.update(index_rules.keys()) + exclude.update(index_rules.values()) + for old, new in other_rules.items(): + new_renamed = TensMul._dedupe_indices(new, exclude) + if old == new or new_renamed is None: + newrule[old] = new + else: + newrule[old] = new_renamed + exclude.update(get_indices(new_renamed)) + return newrule + + def _eval_subs(self, old, new): + """ + If new is an index which is already present in self as a dummy, the dummies in self should be renamed. + """ + + if not isinstance(new, TensorIndex): + return None + + exclude = {new} + self_renamed = self._dedupe_indices(self, exclude) + if self_renamed is None: + return None + else: + return self_renamed._subs(old, new).doit(deep=False) + + def _eval_rewrite_as_Indexed(self, *args, **kwargs): + from sympy.concrete.summations import Sum + index_symbols = [i.args[0] for i in self.get_indices()] + args = [arg.args[0] if isinstance(arg, Sum) else arg for arg in args] + expr = Mul.fromiter(args) + return self._check_add_Sum(expr, index_symbols) + + def _eval_partial_derivative(self, s): + # Evaluation like Mul + terms = [] + for i, arg in enumerate(self.args): + # checking whether some tensor instance is differentiated + # or some other thing is necessary, but ugly + if isinstance(arg, TensExpr): + d = arg._eval_partial_derivative(s) + else: + # do not call diff is s is no symbol + if s._diff_wrt: + d = arg._eval_derivative(s) + else: + d = S.Zero + if d: + terms.append(TensMul.fromiter(self.args[:i] + (d,) + self.args[i + 1:]).doit(deep=False)) + return TensAdd.fromiter(terms).doit(deep=False) + + + def _matches_commutative(self, expr, repl_dict=None, old=False): + """ + Match assuming all tensors commute. But note that we are not assuming anything about their symmetry under index permutations. + """ + #Take care of the various possible types for expr. + if not isinstance(expr, TensMul): + if isinstance(expr, (TensExpr, Expr)): + expr = TensMul(expr) + else: + return None + + #The code that follows assumes expr is a TensMul + + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + #Make sure that none of the dummy indices in self, expr conflict with the values already present in repl_dict. This may happen due to automatic index relabelling when rem_query and rem_expr are formed later on in this function (it calls itself recursively). + indices = [k for k in repl_dict.values() if isinstance(k ,TensorIndex)] + + def dedupe(expr): + renamed = TensMul._dedupe_indices(expr, indices) + if renamed is not None: + return renamed + else: + return expr + + self = dedupe(self) + expr = dedupe(expr) + + #Find the non-tensor part of expr. This need not be the same as expr.coeff when expr.doit() has not been called. + expr_coeff = reduce(lambda a, b: a*b, [arg for arg in expr.args if not isinstance(arg, TensExpr)], S.One) + + # handle simple patterns + if self == expr: + return repl_dict + + if len(_get_wilds(self)) == 0: + return self._matches_simple(expr, repl_dict, old) + + def siftkey(arg): + if isinstance(arg, WildTensor): + return "WildTensor" + elif isinstance(arg, (Tensor, TensExpr)): + return "Tensor" + else: + return "coeff" + + query_sifted = sift(self.args, siftkey) + expr_sifted = sift(expr.args, siftkey) + + #Sanity checks + if "coeff" in query_sifted.keys(): + if TensMul(*query_sifted["coeff"]).doit(deep=False) != self.coeff: + raise NotImplementedError(f"Found something that we do not know to handle: {query_sifted['coeff']}") + if "coeff" in expr_sifted.keys(): + if TensMul(*expr_sifted["coeff"]).doit(deep=False) != expr_coeff: + raise NotImplementedError(f"Found something that we do not know to handle: {expr_sifted['coeff']}") + + query_tens_heads = {tuple(getattr(x, "components", [])) for x in query_sifted["Tensor"]} #We use getattr because, e.g. TensAdd does not have the 'components' attribute. + expr_tens_heads = {tuple(getattr(x, "components", [])) for x in expr_sifted["Tensor"]} + if not query_tens_heads.issubset(expr_tens_heads): + #Some tensorheads in self are not present in the expr + return None + + #Try to match all non-wild tensors of self with tensors that compose expr + if len(query_sifted["Tensor"]) > 0: + q_tensor = query_sifted["Tensor"][0] + """ + We need to iterate over all possible symmetrized forms of q_tensor since the matches given by some of them may map dummy indices to free indices; the information about which indices are dummy/free will only be available later, when we are doing rem_q.matches(rem_e) + """ + for q_tens in q_tensor._get_symmetrized_forms(): + for e in expr_sifted["Tensor"]: + if isinstance(q_tens, TensMul): + #q_tensor got a minus sign due to this permutation. + sign = -1 + else: + sign = 1 + + """ + _matches is used here since we are already iterating over index permutations of q_tensor. Also note that the sign is removed from q_tensor, and will later be put into rem_q. + """ + m = (sign*q_tens)._matches(e) + if m is None: + continue + + rem_query = self.func(sign, *[a for a in self.args if a != q_tensor]).doit(deep=False) + rem_expr = expr.func(*[a for a in expr.args if a != e]).doit(deep=False) + tmp_repl = {} + tmp_repl.update(repl_dict) + tmp_repl.update(m) + rem_m = rem_query.matches(rem_expr, repl_dict=tmp_repl) + if rem_m is not None: + #Check that contracted indices are not mapped to different indices. + internally_consistent = True + for k in rem_m.keys(): + if isinstance(k,TensorIndex): + if -k in rem_m.keys() and rem_m[-k] != -rem_m[k]: + internally_consistent = False + break + if internally_consistent: + repl_dict.update(rem_m) + return repl_dict + + return None + + #Try to match WildTensor instances which have indices + matched_e_tensors = [] + remaining_e_tensors = expr_sifted["Tensor"] + indexless_wilds, wilds = sift(query_sifted["WildTensor"], lambda x: len(x.get_free_indices()) == 0, binary=True) + + for w in wilds: + free_this_wild = set(w.get_free_indices()) + tensors_to_try = [] + for t in remaining_e_tensors: + free = t.get_free_indices() + shares_indices_with_wild = True + for i in free: + if all(j.matches(i) is None for j in free_this_wild): + #The index i matches none of the indices in free_this_wild + shares_indices_with_wild = False + if shares_indices_with_wild: + tensors_to_try.append(t) + + m = w.matches(TensMul(*tensors_to_try).doit(deep=False) ) + if m is None: + return None + else: + for tens in tensors_to_try: + matched_e_tensors.append(tens) + repl_dict.update(m) + + #Try to match indexless WildTensor instances + remaining_e_tensors = [t for t in expr_sifted["Tensor"] if t not in matched_e_tensors] + if len(indexless_wilds) > 0: + #If there are any remaining tensors, match them with the indexless WildTensor + m = indexless_wilds[0].matches( TensMul(1,*remaining_e_tensors).doit(deep=False) ) + if m is None: + return None + else: + repl_dict.update(m) + elif len(remaining_e_tensors) > 0: + return None + + #Try to match the non-tensorial coefficient + m = self.coeff.matches(expr_coeff, old=old) + if m is None: + return None + else: + repl_dict.update(m) + + return repl_dict + + def matches(self, expr, repl_dict=None, old=False): + expr = sympify(expr) + + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + commute = all(arg.component.comm == 0 for arg in expr.args if isinstance(arg, Tensor)) + if commute: + return self._matches_commutative(expr, repl_dict, old) + else: + raise NotImplementedError("Tensor matching not implemented for non-commuting tensors") + +class TensorElement(TensExpr): + """ + Tensor with evaluated components. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, TensorHead, TensorSymmetry + >>> from sympy import symbols + >>> L = TensorIndexType("L") + >>> i, j, k = symbols("i j k") + >>> A = TensorHead("A", [L, L], TensorSymmetry.fully_symmetric(2)) + >>> A(i, j).get_free_indices() + [i, j] + + If we want to set component ``i`` to a specific value, use the + ``TensorElement`` class: + + >>> from sympy.tensor.tensor import TensorElement + >>> te = TensorElement(A(i, j), {i: 2}) + + As index ``i`` has been accessed (``{i: 2}`` is the evaluation of its 3rd + element), the free indices will only contain ``j``: + + >>> te.get_free_indices() + [j] + """ + + def __new__(cls, expr, index_map): + if not isinstance(expr, Tensor): + # remap + if not isinstance(expr, TensExpr): + raise TypeError("%s is not a tensor expression" % expr) + return expr.func(*[TensorElement(arg, index_map) for arg in expr.args]) + expr_free_indices = expr.get_free_indices() + name_translation = {i.args[0]: i for i in expr_free_indices} + index_map = {name_translation.get(index, index): value for index, value in index_map.items()} + index_map = {index: value for index, value in index_map.items() if index in expr_free_indices} + if len(index_map) == 0: + return expr + free_indices = [i for i in expr_free_indices if i not in index_map.keys()] + index_map = Dict(index_map) + obj = TensExpr.__new__(cls, expr, index_map) + obj._free_indices = free_indices + return obj + + @property + def free(self): + return [(index, i) for i, index in enumerate(self.get_free_indices())] + + @property + def dum(self): + # TODO: inherit dummies from expr + return [] + + @property + def expr(self): + return self._args[0] + + @property + def index_map(self): + return self._args[1] + + @property + def coeff(self): + return S.One + + @property + def nocoeff(self): + return self + + def get_free_indices(self): + return self._free_indices + + def _replace_indices(self, repl: dict[TensorIndex, TensorIndex]) -> TensExpr: + # TODO: can be improved: + return self.xreplace(repl) + + def get_indices(self): + return self.get_free_indices() + + def _extract_data(self, replacement_dict): + ret_indices, array = self.expr._extract_data(replacement_dict) + index_map = self.index_map + slice_tuple = tuple(index_map.get(i, slice(None)) for i in ret_indices) + ret_indices = [i for i in ret_indices if i not in index_map] + array = array.__getitem__(slice_tuple) + return ret_indices, array + + +class WildTensorHead(TensorHead): + """ + A wild object that is used to create ``WildTensor`` instances + + Explanation + =========== + + Examples + ======== + >>> from sympy.tensor.tensor import TensorHead, TensorIndex, WildTensorHead, TensorIndexType + >>> R3 = TensorIndexType('R3', dim=3) + >>> p = TensorIndex('p', R3) + >>> q = TensorIndex('q', R3) + + A WildTensorHead can be created without specifying a ``TensorIndexType`` + + >>> W = WildTensorHead("W") + + Calling it with a ``TensorIndex`` creates a ``WildTensor`` instance. + + >>> type(W(p)) + + + The ``TensorIndexType`` is automatically detected from the index that is passed + + >>> W(p).component + W(R3) + + Calling it with no indices returns an object that can match tensors with any number of indices. + + >>> K = TensorHead('K', [R3]) + >>> Q = TensorHead('Q', [R3, R3]) + >>> W().matches(K(p)) + {W: K(p)} + >>> W().matches(Q(p,q)) + {W: Q(p, q)} + + If you want to ignore the order of indices while matching, pass ``unordered_indices=True``. + + >>> U = WildTensorHead("U", unordered_indices=True) + >>> W(p,q).matches(Q(q,p)) + >>> U(p,q).matches(Q(q,p)) + {U(R3,R3): _WildTensExpr(Q(q, p))} + + Parameters + ========== + name : name of the tensor + unordered_indices : whether the order of the indices matters for matching + (default: False) + + See also + ======== + ``WildTensor`` + ``TensorHead`` + + """ + def __new__(cls, name, index_types=None, symmetry=None, comm=0, unordered_indices=False): + if isinstance(name, str): + name_symbol = Symbol(name) + elif isinstance(name, Symbol): + name_symbol = name + else: + raise ValueError("invalid name") + + if index_types is None: + index_types = [] + + if symmetry is None: + symmetry = TensorSymmetry.no_symmetry(len(index_types)) + else: + assert symmetry.rank == len(index_types) + + if symmetry != TensorSymmetry.no_symmetry(len(index_types)): + raise NotImplementedError("Wild matching based on symmetry is not implemented.") + + obj = Basic.__new__(cls, name_symbol, Tuple(*index_types), sympify(symmetry), sympify(comm), sympify(unordered_indices)) + + return obj + + @property + def unordered_indices(self): + return self.args[4] + + def __call__(self, *indices, **kwargs): + tensor = WildTensor(self, indices, **kwargs) + return tensor.doit() + + +class WildTensor(Tensor): + """ + A wild object which matches ``Tensor`` instances + + Explanation + =========== + This is instantiated by attaching indices to a ``WildTensorHead`` instance. + + Examples + ======== + >>> from sympy.tensor.tensor import TensorHead, TensorIndex, WildTensorHead, TensorIndexType + >>> W = WildTensorHead("W") + >>> R3 = TensorIndexType('R3', dim=3) + >>> p = TensorIndex('p', R3) + >>> q = TensorIndex('q', R3) + >>> K = TensorHead('K', [R3]) + >>> Q = TensorHead('Q', [R3, R3]) + + Matching also takes the indices into account + >>> W(p).matches(K(p)) + {W(R3): _WildTensExpr(K(p))} + >>> W(p).matches(K(q)) + >>> W(p).matches(K(-p)) + + If you want to match objects with any number of indices, just use a ``WildTensor`` with no indices. + >>> W().matches(K(p)) + {W: K(p)} + >>> W().matches(Q(p,q)) + {W: Q(p, q)} + + See Also + ======== + ``WildTensorHead`` + ``Tensor`` + + """ + def __new__(cls, tensor_head, indices, **kw_args): + is_canon_bp = kw_args.pop("is_canon_bp", False) + + if tensor_head.func == TensorHead: + """ + If someone tried to call WildTensor by supplying a TensorHead (not a WildTensorHead), return a normal tensor instead. This is helpful when using subs on an expression to replace occurrences of a WildTensorHead with a TensorHead. + """ + return Tensor(tensor_head, indices, is_canon_bp=is_canon_bp, **kw_args) + elif tensor_head.func == _WildTensExpr: + return tensor_head(*indices) + + indices = cls._parse_indices(tensor_head, indices) + index_types = [ind.tensor_index_type for ind in indices] + tensor_head = tensor_head.func( + tensor_head.name, + index_types, + symmetry=None, + comm=tensor_head.comm, + unordered_indices=tensor_head.unordered_indices, + ) + + obj = Basic.__new__(cls, tensor_head, Tuple(*indices)) + obj.name = tensor_head.name + obj._index_structure = _IndexStructure.from_indices(*indices) + obj._free = obj._index_structure.free[:] + obj._dum = obj._index_structure.dum[:] + obj._ext_rank = obj._index_structure._ext_rank + obj._coeff = S.One + obj._nocoeff = obj + obj._component = tensor_head + obj._components = [tensor_head] + if tensor_head.rank != len(indices): + raise ValueError("wrong number of indices") + obj.is_canon_bp = is_canon_bp + obj._index_map = obj._build_index_map(indices, obj._index_structure) + + return obj + + + def matches(self, expr, repl_dict=None, old=False): + if not isinstance(expr, TensExpr) and expr != S(1): + return None + + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + if len(self.indices) > 0: + if not hasattr(expr, "get_free_indices"): + return None + expr_indices = expr.get_free_indices() + if len(expr_indices) != len(self.indices): + return None + if self._component.unordered_indices: + m = self._match_indices_ignoring_order(expr) + if m is None: + return None + else: + repl_dict.update(m) + else: + for i in range(len(expr_indices)): + m = self.indices[i].matches(expr_indices[i]) + if m is None: + return None + else: + repl_dict.update(m) + + repl_dict[self.component] = _WildTensExpr(expr) + else: + #If no indices were passed to the WildTensor, it may match tensors with any number of indices. + repl_dict[self] = expr + + return repl_dict + + def _match_indices_ignoring_order(self, expr, repl_dict=None, old=False): + """ + Helper method for matches. Checks if the indices of self and expr + match disregarding index ordering. + """ + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + def siftkey(ind): + if isinstance(ind, WildTensorIndex): + if ind.ignore_updown: + return "wild, updown" + else: + return "wild" + else: + return "nonwild" + + indices_sifted = sift(self.indices, siftkey) + + matched_indices = [] + expr_indices_remaining = expr.get_indices() + for ind in indices_sifted["nonwild"]: + matched_this_ind = False + for e_ind in expr_indices_remaining: + if e_ind in matched_indices: + continue + m = ind.matches(e_ind) + if m is not None: + matched_this_ind = True + repl_dict.update(m) + matched_indices.append(e_ind) + break + if not matched_this_ind: + return None + + expr_indices_remaining = [i for i in expr_indices_remaining if i not in matched_indices] + for ind in indices_sifted["wild"]: + matched_this_ind = False + for e_ind in expr_indices_remaining: + m = ind.matches(e_ind) + if m is not None: + if -ind in repl_dict.keys() and -repl_dict[-ind] != m[ind]: + return None + matched_this_ind = True + repl_dict.update(m) + matched_indices.append(e_ind) + break + if not matched_this_ind: + return None + + expr_indices_remaining = [i for i in expr_indices_remaining if i not in matched_indices] + for ind in indices_sifted["wild, updown"]: + matched_this_ind = False + for e_ind in expr_indices_remaining: + m = ind.matches(e_ind) + if m is not None: + if -ind in repl_dict.keys() and -repl_dict[-ind] != m[ind]: + return None + matched_this_ind = True + repl_dict.update(m) + matched_indices.append(e_ind) + break + if not matched_this_ind: + return None + + if len(matched_indices) < len(self.indices): + return None + else: + return repl_dict + +class WildTensorIndex(TensorIndex): + """ + A wild object that matches TensorIndex instances. + + Examples + ======== + >>> from sympy.tensor.tensor import TensorIndex, TensorIndexType, WildTensorIndex + >>> R3 = TensorIndexType('R3', dim=3) + >>> p = TensorIndex("p", R3) + + By default, covariant indices only match with covariant indices (and + similarly for contravariant) + + >>> q = WildTensorIndex("q", R3) + >>> (q).matches(p) + {q: p} + >>> (q).matches(-p) + + If you want matching to ignore whether the index is co/contra-variant, set + ignore_updown=True + + >>> r = WildTensorIndex("r", R3, ignore_updown=True) + >>> (r).matches(-p) + {r: -p} + >>> (r).matches(p) + {r: p} + + Parameters + ========== + name : name of the index (string), or ``True`` if you want it to be + automatically assigned + tensor_index_type : ``TensorIndexType`` of the index + is_up : flag for contravariant index (is_up=True by default) + ignore_updown : bool, Whether this should match both co- and contra-variant + indices (default:False) + """ + def __new__(cls, name, tensor_index_type, is_up=True, ignore_updown=False): + if isinstance(name, str): + name_symbol = Symbol(name) + elif isinstance(name, Symbol): + name_symbol = name + elif name is True: + name = "_i{}".format(len(tensor_index_type._autogenerated)) + name_symbol = Symbol(name) + tensor_index_type._autogenerated.append(name_symbol) + else: + raise ValueError("invalid name") + + is_up = sympify(is_up) + ignore_updown = sympify(ignore_updown) + return Basic.__new__(cls, name_symbol, tensor_index_type, is_up, ignore_updown) + + @property + def ignore_updown(self): + return self.args[3] + + def __neg__(self): + t1 = WildTensorIndex(self.name, self.tensor_index_type, + (not self.is_up), self.ignore_updown) + return t1 + + def matches(self, expr, repl_dict=None, old=False): + if not isinstance(expr, TensorIndex): + return None + if self.tensor_index_type != expr.tensor_index_type: + return None + if not self.ignore_updown: + if self.is_up != expr.is_up: + return None + + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + repl_dict[self] = expr + return repl_dict + + +class _WildTensExpr(Basic): + """ + INTERNAL USE ONLY + + This is an object that helps with replacement of WildTensors in expressions. + When this object is set as the tensor_head of a WildTensor, it replaces the + WildTensor by a TensExpr (passed when initializing this object). + + Examples + ======== + >>> from sympy.tensor.tensor import WildTensorHead, TensorIndex, TensorHead, TensorIndexType + >>> W = WildTensorHead("W") + >>> R3 = TensorIndexType('R3', dim=3) + >>> p = TensorIndex('p', R3) + >>> q = TensorIndex('q', R3) + >>> K = TensorHead('K', [R3]) + >>> print( ( K(p) ).replace( W(p), W(q)*W(-q)*W(p) ) ) + K(R_0)*K(-R_0)*K(p) + + """ + def __init__(self, expr): + if not isinstance(expr, TensExpr): + raise TypeError("_WildTensExpr expects a TensExpr as argument") + self.expr = expr + + def __call__(self, *indices): + return self.expr._replace_indices(dict(zip(self.expr.get_free_indices(), indices))) + + def __neg__(self): + return self.func(self.expr*S.NegativeOne) + + def __abs__(self): + raise NotImplementedError + + def __add__(self, other): + if other.func != self.func: + raise TypeError(f"Cannot add {self.func} to {other.func}") + return self.func(self.expr+other.expr) + + def __radd__(self, other): + if other.func != self.func: + raise TypeError(f"Cannot add {self.func} to {other.func}") + return self.func(other.expr+self.expr) + + def __sub__(self, other): + return self + (-other) + + def __rsub__(self, other): + return other + (-self) + + def __mul__(self, other): + raise NotImplementedError + + def __rmul__(self, other): + raise NotImplementedError + + def __truediv__(self, other): + raise NotImplementedError + + def __rtruediv__(self, other): + raise NotImplementedError + + def __pow__(self, other): + raise NotImplementedError + + def __rpow__(self, other): + raise NotImplementedError + + +def canon_bp(p): + """ + Butler-Portugal canonicalization. See ``tensor_can.py`` from the + combinatorics module for the details. + """ + if isinstance(p, TensExpr): + return p.canon_bp() + return p + + +def tensor_mul(*a): + """ + product of tensors + """ + if not a: + return TensMul.from_data(S.One, [], [], []) + t = a[0] + for tx in a[1:]: + t = t*tx + return t + + +def riemann_cyclic_replace(t_r): + """ + replace Riemann tensor with an equivalent expression + + ``R(m,n,p,q) -> 2/3*R(m,n,p,q) - 1/3*R(m,q,n,p) + 1/3*R(m,p,n,q)`` + + """ + free = sorted(t_r.free, key=lambda x: x[1]) + m, n, p, q = [x[0] for x in free] + t0 = t_r*Rational(2, 3) + t1 = -t_r.substitute_indices((m,m),(n,q),(p,n),(q,p))*Rational(1, 3) + t2 = t_r.substitute_indices((m,m),(n,p),(p,n),(q,q))*Rational(1, 3) + t3 = t0 + t1 + t2 + return t3 + +def riemann_cyclic(t2): + """ + Replace each Riemann tensor with an equivalent expression + satisfying the cyclic identity. + + This trick is discussed in the reference guide to Cadabra. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorHead, riemann_cyclic, TensorSymmetry + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> i, j, k, l = tensor_indices('i,j,k,l', Lorentz) + >>> R = TensorHead('R', [Lorentz]*4, TensorSymmetry.riemann()) + >>> t = R(i,j,k,l)*(R(-i,-j,-k,-l) - 2*R(-i,-k,-j,-l)) + >>> riemann_cyclic(t) + 0 + """ + t2 = t2.expand() + if isinstance(t2, (TensMul, Tensor)): + args = [t2] + else: + args = t2.args + a1 = [x.split() for x in args] + a2 = [[riemann_cyclic_replace(tx) for tx in y] for y in a1] + a3 = [tensor_mul(*v) for v in a2] + t3 = TensAdd(*a3).doit(deep=False) + if not t3: + return t3 + else: + return canon_bp(t3) + + +def get_lines(ex, index_type): + """ + Returns ``(lines, traces, rest)`` for an index type, + where ``lines`` is the list of list of positions of a matrix line, + ``traces`` is the list of list of traced matrix lines, + ``rest`` is the rest of the elements of the tensor. + """ + def _join_lines(a): + i = 0 + while i < len(a): + x = a[i] + xend = x[-1] + xstart = x[0] + hit = True + while hit: + hit = False + for j in range(i + 1, len(a)): + if j >= len(a): + break + if a[j][0] == xend: + hit = True + x.extend(a[j][1:]) + xend = x[-1] + a.pop(j) + continue + if a[j][0] == xstart: + hit = True + a[i] = reversed(a[j][1:]) + x + x = a[i] + xstart = a[i][0] + a.pop(j) + continue + if a[j][-1] == xend: + hit = True + x.extend(reversed(a[j][:-1])) + xend = x[-1] + a.pop(j) + continue + if a[j][-1] == xstart: + hit = True + a[i] = a[j][:-1] + x + x = a[i] + xstart = x[0] + a.pop(j) + continue + i += 1 + return a + + arguments = ex.args + dt = {} + for c in ex.args: + if not isinstance(c, TensExpr): + continue + if c in dt: + continue + index_types = c.index_types + a = [] + for i in range(len(index_types)): + if index_types[i] is index_type: + a.append(i) + if len(a) > 2: + raise ValueError('at most two indices of type %s allowed' % index_type) + if len(a) == 2: + dt[c] = a + #dum = ex.dum + lines = [] + traces = [] + traces1 = [] + #indices_to_args_pos = ex._get_indices_to_args_pos() + # TODO: add a dum_to_components_map ? + for p0, p1, c0, c1 in ex.dum_in_args: + if arguments[c0] not in dt: + continue + if c0 == c1: + traces.append([c0]) + continue + ta0 = dt[arguments[c0]] + ta1 = dt[arguments[c1]] + if p0 not in ta0: + continue + if ta0.index(p0) == ta1.index(p1): + # case gamma(i,s0,-s1) in c0, gamma(j,-s0,s2) in c1; + # to deal with this case one could add to the position + # a flag for transposition; + # one could write [(c0, False), (c1, True)] + raise NotImplementedError + # if p0 == ta0[1] then G in pos c0 is mult on the right by G in c1 + # if p0 == ta0[0] then G in pos c1 is mult on the right by G in c0 + ta0 = dt[arguments[c0]] + b0, b1 = (c0, c1) if p0 == ta0[1] else (c1, c0) + lines1 = lines.copy() + for line in lines: + if line[-1] == b0: + if line[0] == b1: + n = line.index(min(line)) + traces1.append(line) + traces.append(line[n:] + line[:n]) + else: + line.append(b1) + break + elif line[0] == b1: + line.insert(0, b0) + break + else: + lines1.append([b0, b1]) + + lines = [x for x in lines1 if x not in traces1] + lines = _join_lines(lines) + rest = [] + for line in lines: + for y in line: + rest.append(y) + for line in traces: + for y in line: + rest.append(y) + rest = [x for x in range(len(arguments)) if x not in rest] + + return lines, traces, rest + + +def get_free_indices(t): + if not isinstance(t, TensExpr): + return () + return t.get_free_indices() + + +def get_indices(t): + if not isinstance(t, TensExpr): + return () + return t.get_indices() + +def get_dummy_indices(t): + if not isinstance(t, TensExpr): + return () + inds = t.get_indices() + free = t.get_free_indices() + return [i for i in inds if i not in free] + +def get_index_structure(t): + if isinstance(t, TensExpr): + return t._index_structure + return _IndexStructure([], [], [], []) + + +def get_coeff(t): + if isinstance(t, Tensor): + return S.One + if isinstance(t, TensMul): + return t.coeff + if isinstance(t, TensExpr): + raise ValueError("no coefficient associated to this tensor expression") + return t + +def contract_metric(t, g): + if isinstance(t, TensExpr): + return t.contract_metric(g) + return t + +def perm2tensor(t, g, is_canon_bp=False): + """ + Returns the tensor corresponding to the permutation ``g`` + + For further details, see the method in ``TIDS`` with the same name. + """ + if not isinstance(t, TensExpr): + return t + elif isinstance(t, (Tensor, TensMul)): + nim = get_index_structure(t).perm2tensor(g, is_canon_bp=is_canon_bp) + res = t._set_new_index_structure(nim, is_canon_bp=is_canon_bp) + if g[-1] != len(g) - 1: + return -res + + return res + raise NotImplementedError() + + +def substitute_indices(t, *index_tuples): + if not isinstance(t, TensExpr): + return t + return t.substitute_indices(*index_tuples) + + +def _get_wilds(expr): + return list(expr.atoms(Wild, WildFunction, WildTensor, WildTensorIndex, WildTensorHead)) + + +def get_postprocessor(cls): + def _postprocessor(expr): + tens_class = {Mul: TensMul, Add: TensAdd}[cls] + if any(isinstance(a, TensExpr) for a in expr.args): + return tens_class(*expr.args) + else: + return expr + + return _postprocessor + +Basic._constructor_postprocessor_mapping[TensExpr] = { + "Mul": [get_postprocessor(Mul)], +} diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/toperators.py b/.venv/lib/python3.13/site-packages/sympy/tensor/toperators.py new file mode 100644 index 0000000000000000000000000000000000000000..1bdd67c4f4a7e86b9821ee55b1d2f9bde29c96a8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/toperators.py @@ -0,0 +1,256 @@ +from sympy import permutedims +from sympy.core.numbers import Number +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.tensor.tensor import Tensor, TensExpr, TensAdd, TensMul + + +class PartialDerivative(TensExpr): + """ + Partial derivative for tensor expressions. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, TensorHead + >>> from sympy.tensor.toperators import PartialDerivative + >>> from sympy import symbols + >>> L = TensorIndexType("L") + >>> A = TensorHead("A", [L]) + >>> B = TensorHead("B", [L]) + >>> i, j, k = symbols("i j k") + + >>> expr = PartialDerivative(A(i), A(j)) + >>> expr + PartialDerivative(A(i), A(j)) + + The ``PartialDerivative`` object behaves like a tensorial expression: + + >>> expr.get_indices() + [i, -j] + + Notice that the deriving variables have opposite valence than the + printed one: ``A(j)`` is printed as covariant, but the index of the + derivative is actually contravariant, i.e. ``-j``. + + Indices can be contracted: + + >>> expr = PartialDerivative(A(i), A(i)) + >>> expr + PartialDerivative(A(L_0), A(L_0)) + >>> expr.get_indices() + [L_0, -L_0] + + The method ``.get_indices()`` always returns all indices (even the + contracted ones). If only uncontracted indices are needed, call + ``.get_free_indices()``: + + >>> expr.get_free_indices() + [] + + Nested partial derivatives are flattened: + + >>> expr = PartialDerivative(PartialDerivative(A(i), A(j)), A(k)) + >>> expr + PartialDerivative(A(i), A(j), A(k)) + >>> expr.get_indices() + [i, -j, -k] + + Replace a derivative with array values: + + >>> from sympy.abc import x, y + >>> from sympy import sin, log + >>> compA = [sin(x), log(x)*y**3] + >>> compB = [x, y] + >>> expr = PartialDerivative(A(i), B(j)) + >>> expr.replace_with_arrays({A(i): compA, B(i): compB}) + [[cos(x), 0], [y**3/x, 3*y**2*log(x)]] + + The returned array is indexed by `(i, -j)`. + + Be careful that other SymPy modules put the indices of the deriving + variables before the indices of the derivand in the derivative result. + For example: + + >>> expr.get_free_indices() + [i, -j] + + >>> from sympy import Matrix, Array + >>> Matrix(compA).diff(Matrix(compB)).reshape(2, 2) + [[cos(x), y**3/x], [0, 3*y**2*log(x)]] + >>> Array(compA).diff(Array(compB)) + [[cos(x), y**3/x], [0, 3*y**2*log(x)]] + + These are the transpose of the result of ``PartialDerivative``, + as the matrix and the array modules put the index `-j` before `i` in the + derivative result. An array read with index order `(-j, i)` is indeed the + transpose of the same array read with index order `(i, -j)`. By specifying + the index order to ``.replace_with_arrays`` one can get a compatible + expression: + + >>> expr.replace_with_arrays({A(i): compA, B(i): compB}, [-j, i]) + [[cos(x), y**3/x], [0, 3*y**2*log(x)]] + """ + + def __new__(cls, expr, *variables): + + # Flatten: + if isinstance(expr, PartialDerivative): + variables = expr.variables + variables + expr = expr.expr + + args, indices, free, dum = cls._contract_indices_for_derivative( + S(expr), variables) + + obj = TensExpr.__new__(cls, *args) + + obj._indices = indices + obj._free = free + obj._dum = dum + return obj + + @property + def coeff(self): + return S.One + + @property + def nocoeff(self): + return self + + @classmethod + def _contract_indices_for_derivative(cls, expr, variables): + variables_opposite_valence = [] + + for i in variables: + if isinstance(i, Tensor): + i_free_indices = i.get_free_indices() + variables_opposite_valence.append( + i.xreplace({k: -k for k in i_free_indices})) + elif isinstance(i, Symbol): + variables_opposite_valence.append(i) + + args, indices, free, dum = TensMul._tensMul_contract_indices( + [expr] + variables_opposite_valence, replace_indices=True) + + for i in range(1, len(args)): + args_i = args[i] + if isinstance(args_i, Tensor): + i_indices = args[i].get_free_indices() + args[i] = args[i].xreplace({k: -k for k in i_indices}) + + return args, indices, free, dum + + def doit(self, **hints): + args, indices, free, dum = self._contract_indices_for_derivative(self.expr, self.variables) + + obj = self.func(*args) + obj._indices = indices + obj._free = free + obj._dum = dum + + return obj + + def _expand_partial_derivative(self): + args, indices, free, dum = self._contract_indices_for_derivative(self.expr, self.variables) + + obj = self.func(*args) + obj._indices = indices + obj._free = free + obj._dum = dum + + result = obj + + if not args[0].free_symbols: + return S.Zero + elif isinstance(obj.expr, TensAdd): + # take care of sums of multi PDs + result = obj.expr.func(*[ + self.func(a, *obj.variables)._expand_partial_derivative() + for a in result.expr.args]) + elif isinstance(obj.expr, TensMul): + # take care of products of multi PDs + if len(obj.variables) == 1: + # derivative with respect to single variable + terms = [] + mulargs = list(obj.expr.args) + for ind in range(len(mulargs)): + if not isinstance(sympify(mulargs[ind]), Number): + # a number coefficient is not considered for + # expansion of PartialDerivative + d = self.func(mulargs[ind], *obj.variables)._expand_partial_derivative() + terms.append(TensMul(*(mulargs[:ind] + + [d] + + mulargs[(ind + 1):]))) + result = TensAdd.fromiter(terms) + else: + # derivative with respect to multiple variables + # decompose: + # partial(expr, (u, v)) + # = partial(partial(expr, u).doit(), v).doit() + result = obj.expr # init with expr + for v in obj.variables: + result = self.func(result, v)._expand_partial_derivative() + # then throw PD on it + + return result + + def _perform_derivative(self): + result = self.expr + for v in self.variables: + if isinstance(result, TensExpr): + result = result._eval_partial_derivative(v) + else: + if v._diff_wrt: + result = result._eval_derivative(v) + else: + result = S.Zero + return result + + def get_indices(self): + return self._indices + + def get_free_indices(self): + free = sorted(self._free, key=lambda x: x[1]) + return [i[0] for i in free] + + def _replace_indices(self, repl): + expr = self.expr.xreplace(repl) + mirrored = {-k: -v for k, v in repl.items()} + variables = [i.xreplace(mirrored) for i in self.variables] + return self.func(expr, *variables) + + @property + def expr(self): + return self.args[0] + + @property + def variables(self): + return self.args[1:] + + def _extract_data(self, replacement_dict): + from .array import derive_by_array, tensorcontraction + indices, array = self.expr._extract_data(replacement_dict) + for variable in self.variables: + var_indices, var_array = variable._extract_data(replacement_dict) + var_indices = [-i for i in var_indices] + coeff_array, var_array = zip(*[i.as_coeff_Mul() for i in var_array]) + dim_before = len(array.shape) + array = derive_by_array(array, var_array) + dim_after = len(array.shape) + dim_increase = dim_after - dim_before + array = permutedims(array, [i + dim_increase for i in range(dim_before)] + list(range(dim_increase))) + array = array.as_mutable() + varindex = var_indices[0] + # Remove coefficients of base vector: + coeff_index = [0] + [slice(None) for i in range(len(indices))] + for i, coeff in enumerate(coeff_array): + coeff_index[0] = i + array[tuple(coeff_index)] /= coeff + if -varindex in indices: + pos = indices.index(-varindex) + array = tensorcontraction(array, (0, pos+1)) + indices.pop(pos) + else: + indices.append(varindex) + return indices, array diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/__init__.py b/.venv/lib/python3.13/site-packages/sympy/vector/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f6757bbeb35022481b1cf183373ecccd19779faa --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/__init__.py @@ -0,0 +1,50 @@ +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.vector import (Vector, VectorAdd, VectorMul, + BaseVector, VectorZero, Cross, Dot, cross, dot) +from sympy.vector.dyadic import (Dyadic, DyadicAdd, DyadicMul, + BaseDyadic, DyadicZero) +from sympy.vector.scalar import BaseScalar +from sympy.vector.deloperator import Del +from sympy.vector.functions import (express, matrix_to_vector, + laplacian, is_conservative, + is_solenoidal, scalar_potential, + directional_derivative, + scalar_potential_difference) +from sympy.vector.point import Point +from sympy.vector.orienters import (AxisOrienter, BodyOrienter, + SpaceOrienter, QuaternionOrienter) +from sympy.vector.operators import Gradient, Divergence, Curl, Laplacian, gradient, curl, divergence +from sympy.vector.implicitregion import ImplicitRegion +from sympy.vector.parametricregion import (ParametricRegion, parametric_region_list) +from sympy.vector.integrals import (ParametricIntegral, vector_integrate) +from sympy.vector.kind import VectorKind + +__all__ = [ + 'Vector', 'VectorAdd', 'VectorMul', 'BaseVector', 'VectorZero', 'Cross', + 'Dot', 'cross', 'dot', + + 'VectorKind', + + 'Dyadic', 'DyadicAdd', 'DyadicMul', 'BaseDyadic', 'DyadicZero', + + 'BaseScalar', + + 'Del', + + 'CoordSys3D', + + 'express', 'matrix_to_vector', 'laplacian', 'is_conservative', + 'is_solenoidal', 'scalar_potential', 'directional_derivative', + 'scalar_potential_difference', + + 'Point', + + 'AxisOrienter', 'BodyOrienter', 'SpaceOrienter', 'QuaternionOrienter', + + 'Gradient', 'Divergence', 'Curl', 'Laplacian', 'gradient', 'curl', + 'divergence', + + 'ParametricRegion', 'parametric_region_list', 'ImplicitRegion', + + 'ParametricIntegral', 'vector_integrate', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/basisdependent.py b/.venv/lib/python3.13/site-packages/sympy/vector/basisdependent.py new file mode 100644 index 0000000000000000000000000000000000000000..53e4efc0bf839fb5a5de2d1af1487683fabd8cf1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/basisdependent.py @@ -0,0 +1,374 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +from sympy.simplify import simplify as simp, trigsimp as tsimp # type: ignore +from sympy.core.decorators import call_highest_priority, _sympifyit +from sympy.core.assumptions import StdFactKB +from sympy.core.function import diff as df +from sympy.integrals.integrals import Integral +from sympy.polys.polytools import factor as fctr +from sympy.core import S, Add, Mul +from sympy.core.expr import Expr + +if TYPE_CHECKING: + from sympy.vector.vector import BaseVector + + +class BasisDependent(Expr): + """ + Super class containing functionality common to vectors and + dyadics. + Named so because the representation of these quantities in + sympy.vector is dependent on the basis they are expressed in. + """ + + zero: BasisDependentZero + + @call_highest_priority('__radd__') + def __add__(self, other): + return self._add_func(self, other) + + @call_highest_priority('__add__') + def __radd__(self, other): + return self._add_func(other, self) + + @call_highest_priority('__rsub__') + def __sub__(self, other): + return self._add_func(self, -other) + + @call_highest_priority('__sub__') + def __rsub__(self, other): + return self._add_func(other, -self) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rmul__') + def __mul__(self, other): + return self._mul_func(self, other) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__mul__') + def __rmul__(self, other): + return self._mul_func(other, self) + + def __neg__(self): + return self._mul_func(S.NegativeOne, self) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rtruediv__') + def __truediv__(self, other): + return self._div_helper(other) + + @call_highest_priority('__truediv__') + def __rtruediv__(self, other): + return TypeError("Invalid divisor for division") + + def evalf(self, n=15, subs=None, maxn=100, chop=False, strict=False, quad=None, verbose=False): + """ + Implements the SymPy evalf routine for this quantity. + + evalf's documentation + ===================== + + """ + options = {'subs':subs, 'maxn':maxn, 'chop':chop, 'strict':strict, + 'quad':quad, 'verbose':verbose} + vec = self.zero + for k, v in self.components.items(): + vec += v.evalf(n, **options) * k + return vec + + evalf.__doc__ += Expr.evalf.__doc__ # type: ignore + + n = evalf # type: ignore + + def simplify(self, **kwargs): + """ + Implements the SymPy simplify routine for this quantity. + + simplify's documentation + ======================== + + """ + simp_components = [simp(v, **kwargs) * k for + k, v in self.components.items()] + return self._add_func(*simp_components) + + simplify.__doc__ += simp.__doc__ # type: ignore + + def trigsimp(self, **opts): + """ + Implements the SymPy trigsimp routine, for this quantity. + + trigsimp's documentation + ======================== + + """ + trig_components = [tsimp(v, **opts) * k for + k, v in self.components.items()] + return self._add_func(*trig_components) + + trigsimp.__doc__ += tsimp.__doc__ # type: ignore + + def _eval_simplify(self, **kwargs): + return self.simplify(**kwargs) + + def _eval_trigsimp(self, **opts): + return self.trigsimp(**opts) + + def _eval_derivative(self, wrt): + return self.diff(wrt) + + def _eval_Integral(self, *symbols, **assumptions): + integral_components = [Integral(v, *symbols, **assumptions) * k + for k, v in self.components.items()] + return self._add_func(*integral_components) + + def as_numer_denom(self): + """ + Returns the expression as a tuple wrt the following + transformation - + + expression -> a/b -> a, b + + """ + return self, S.One + + def factor(self, *args, **kwargs): + """ + Implements the SymPy factor routine, on the scalar parts + of a basis-dependent expression. + + factor's documentation + ======================== + + """ + fctr_components = [fctr(v, *args, **kwargs) * k for + k, v in self.components.items()] + return self._add_func(*fctr_components) + + factor.__doc__ += fctr.__doc__ # type: ignore + + def as_coeff_Mul(self, rational=False): + """Efficiently extract the coefficient of a product.""" + return (S.One, self) + + def as_coeff_add(self, *deps): + """Efficiently extract the coefficient of a summation.""" + return 0, tuple(x * self.components[x] for x in self.components) + + def diff(self, *args, **kwargs): + """ + Implements the SymPy diff routine, for vectors. + + diff's documentation + ======================== + + """ + for x in args: + if isinstance(x, BasisDependent): + raise TypeError("Invalid arg for differentiation") + diff_components = [df(v, *args, **kwargs) * k for + k, v in self.components.items()] + return self._add_func(*diff_components) + + diff.__doc__ += df.__doc__ # type: ignore + + def doit(self, **hints): + """Calls .doit() on each term in the Dyadic""" + doit_components = [self.components[x].doit(**hints) * x + for x in self.components] + return self._add_func(*doit_components) + + +class BasisDependentAdd(BasisDependent, Add): + """ + Denotes sum of basis dependent quantities such that they cannot + be expressed as base or Mul instances. + """ + + def __new__(cls, *args, **options): + components = {} + + # Check each arg and simultaneously learn the components + for arg in args: + if not isinstance(arg, cls._expr_type): + if isinstance(arg, Mul): + arg = cls._mul_func(*(arg.args)) + elif isinstance(arg, Add): + arg = cls._add_func(*(arg.args)) + else: + raise TypeError(str(arg) + + " cannot be interpreted correctly") + # If argument is zero, ignore + if arg == cls.zero: + continue + # Else, update components accordingly + for x in arg.components: + components[x] = components.get(x, 0) + arg.components[x] + + temp = list(components.keys()) + for x in temp: + if components[x] == 0: + del components[x] + + # Handle case of zero vector + if len(components) == 0: + return cls.zero + + # Build object + newargs = [x * components[x] for x in components] + obj = super().__new__(cls, *newargs, **options) + if isinstance(obj, Mul): + return cls._mul_func(*obj.args) + assumptions = {'commutative': True} + obj._assumptions = StdFactKB(assumptions) + obj._components = components + obj._sys = (list(components.keys()))[0]._sys + + return obj + + +class BasisDependentMul(BasisDependent, Mul): + """ + Denotes product of base- basis dependent quantity with a scalar. + """ + + def __new__(cls, *args, **options): + obj = cls._new(*args, **options) + return obj + + def _new_rawargs(self, *args): + # XXX: This is needed because Add.flatten() uses it but the default + # implementation does not work for Vectors because they assign + # attributes outside of .args. + return type(self)(*args) + + @classmethod + def _new(cls, *args, **options): + from sympy.vector import Cross, Dot, Curl, Gradient + count = 0 + measure_number = S.One + zeroflag = False + extra_args = [] + + # Determine the component and check arguments + # Also keep a count to ensure two vectors aren't + # being multiplied + for arg in args: + if isinstance(arg, cls._zero_func): + count += 1 + zeroflag = True + elif arg == S.Zero: + zeroflag = True + elif isinstance(arg, (cls._base_func, cls._mul_func)): + count += 1 + expr = arg._base_instance + measure_number *= arg._measure_number + elif isinstance(arg, cls._add_func): + count += 1 + expr = arg + elif isinstance(arg, (Cross, Dot, Curl, Gradient)): + extra_args.append(arg) + else: + measure_number *= arg + # Make sure incompatible types weren't multiplied + if count > 1: + raise ValueError("Invalid multiplication") + elif count == 0: + return Mul(*args, **options) + # Handle zero vector case + if zeroflag: + return cls.zero + + # If one of the args was a VectorAdd, return an + # appropriate VectorAdd instance + if isinstance(expr, cls._add_func): + newargs = [cls._mul_func(measure_number, x) for + x in expr.args] + return cls._add_func(*newargs) + + obj = super().__new__(cls, measure_number, + expr._base_instance, + *extra_args, + **options) + if isinstance(obj, Add): + return cls._add_func(*obj.args) + obj._base_instance = expr._base_instance + obj._measure_number = measure_number + assumptions = {'commutative': True} + obj._assumptions = StdFactKB(assumptions) + obj._components = {expr._base_instance: measure_number} + obj._sys = expr._base_instance._sys + + return obj + + def _sympystr(self, printer): + measure_str = printer._print(self._measure_number) + if ('(' in measure_str or '-' in measure_str or + '+' in measure_str): + measure_str = '(' + measure_str + ')' + return measure_str + '*' + printer._print(self._base_instance) + + +class BasisDependentZero(BasisDependent): + """ + Class to denote a zero basis dependent instance. + """ + components: dict['BaseVector', Expr] = {} + _latex_form: str + + def __new__(cls): + obj = super().__new__(cls) + # Pre-compute a specific hash value for the zero vector + # Use the same one always + obj._hash = (S.Zero, cls).__hash__() + return obj + + def __hash__(self): + return self._hash + + @call_highest_priority('__req__') + def __eq__(self, other): + return isinstance(other, self._zero_func) + + __req__ = __eq__ + + @call_highest_priority('__radd__') + def __add__(self, other): + if isinstance(other, self._expr_type): + return other + else: + raise TypeError("Invalid argument types for addition") + + @call_highest_priority('__add__') + def __radd__(self, other): + if isinstance(other, self._expr_type): + return other + else: + raise TypeError("Invalid argument types for addition") + + @call_highest_priority('__rsub__') + def __sub__(self, other): + if isinstance(other, self._expr_type): + return -other + else: + raise TypeError("Invalid argument types for subtraction") + + @call_highest_priority('__sub__') + def __rsub__(self, other): + if isinstance(other, self._expr_type): + return other + else: + raise TypeError("Invalid argument types for subtraction") + + def __neg__(self): + return self + + def normalize(self): + """ + Returns the normalized version of this vector. + """ + return self + + def _sympystr(self, printer): + return '0' diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/coordsysrect.py b/.venv/lib/python3.13/site-packages/sympy/vector/coordsysrect.py new file mode 100644 index 0000000000000000000000000000000000000000..55539fb19dc4221de69437111f44d6a6cc70b3e4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/coordsysrect.py @@ -0,0 +1,1031 @@ +from collections.abc import Callable + +from sympy.core.basic import Basic +from sympy.core.cache import cacheit +from sympy.core import S, Dummy, Lambda +from sympy.core.symbol import Str +from sympy.core.symbol import symbols +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.matrices.matrixbase import MatrixBase +from sympy.solvers import solve +from sympy.vector.scalar import BaseScalar +from sympy.core.containers import Tuple +from sympy.core.function import diff +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, atan2, cos, sin) +from sympy.matrices.dense import eye +from sympy.matrices.immutable import ImmutableDenseMatrix +from sympy.simplify.simplify import simplify +from sympy.simplify.trigsimp import trigsimp +import sympy.vector +from sympy.vector.orienters import (Orienter, AxisOrienter, BodyOrienter, + SpaceOrienter, QuaternionOrienter) + + +class CoordSys3D(Basic): + """ + Represents a coordinate system in 3-D space. + """ + + def __new__(cls, name, transformation=None, parent=None, location=None, + rotation_matrix=None, vector_names=None, variable_names=None): + """ + The orientation/location parameters are necessary if this system + is being defined at a certain orientation or location wrt another. + + Parameters + ========== + + name : str + The name of the new CoordSys3D instance. + + transformation : Lambda, Tuple, str + Transformation defined by transformation equations or chosen + from predefined ones. + + location : Vector + The position vector of the new system's origin wrt the parent + instance. + + rotation_matrix : SymPy ImmutableMatrix + The rotation matrix of the new coordinate system with respect + to the parent. In other words, the output of + new_system.rotation_matrix(parent). + + parent : CoordSys3D + The coordinate system wrt which the orientation/location + (or both) is being defined. + + vector_names, variable_names : iterable(optional) + Iterables of 3 strings each, with custom names for base + vectors and base scalars of the new system respectively. + Used for simple str printing. + + """ + + name = str(name) + Vector = sympy.vector.Vector + Point = sympy.vector.Point + + if not isinstance(name, str): + raise TypeError("name should be a string") + + if transformation is not None: + if (location is not None) or (rotation_matrix is not None): + raise ValueError("specify either `transformation` or " + "`location`/`rotation_matrix`") + if isinstance(transformation, (Tuple, tuple, list)): + if isinstance(transformation[0], MatrixBase): + rotation_matrix = transformation[0] + location = transformation[1] + else: + transformation = Lambda(transformation[0], + transformation[1]) + elif isinstance(transformation, Callable): + x1, x2, x3 = symbols('x1 x2 x3', cls=Dummy) + transformation = Lambda((x1, x2, x3), + transformation(x1, x2, x3)) + elif isinstance(transformation, str): + transformation = Str(transformation) + elif isinstance(transformation, (Str, Lambda)): + pass + else: + raise TypeError("transformation: " + "wrong type {}".format(type(transformation))) + + # If orientation information has been provided, store + # the rotation matrix accordingly + if rotation_matrix is None: + rotation_matrix = ImmutableDenseMatrix(eye(3)) + else: + if not isinstance(rotation_matrix, MatrixBase): + raise TypeError("rotation_matrix should be an Immutable" + + "Matrix instance") + rotation_matrix = rotation_matrix.as_immutable() + + # If location information is not given, adjust the default + # location as Vector.zero + if parent is not None: + if not isinstance(parent, CoordSys3D): + raise TypeError("parent should be a " + + "CoordSys3D/None") + if location is None: + location = Vector.zero + else: + if not isinstance(location, Vector): + raise TypeError("location should be a Vector") + # Check that location does not contain base + # scalars + for x in location.free_symbols: + if isinstance(x, BaseScalar): + raise ValueError("location should not contain" + + " BaseScalars") + origin = parent.origin.locate_new(name + '.origin', + location) + else: + location = Vector.zero + origin = Point(name + '.origin') + + if transformation is None: + transformation = Tuple(rotation_matrix, location) + + if isinstance(transformation, Tuple): + lambda_transformation = CoordSys3D._compose_rotation_and_translation( + transformation[0], + transformation[1], + parent + ) + r, l = transformation + l = l._projections + lambda_lame = CoordSys3D._get_lame_coeff('cartesian') + lambda_inverse = lambda x, y, z: r.inv()*Matrix( + [x-l[0], y-l[1], z-l[2]]) + elif isinstance(transformation, Str): + trname = transformation.name + lambda_transformation = CoordSys3D._get_transformation_lambdas(trname) + if parent is not None: + if parent.lame_coefficients() != (S.One, S.One, S.One): + raise ValueError('Parent for pre-defined coordinate ' + 'system should be Cartesian.') + lambda_lame = CoordSys3D._get_lame_coeff(trname) + lambda_inverse = CoordSys3D._set_inv_trans_equations(trname) + elif isinstance(transformation, Lambda): + if not CoordSys3D._check_orthogonality(transformation): + raise ValueError("The transformation equation does not " + "create orthogonal coordinate system") + lambda_transformation = transformation + lambda_lame = CoordSys3D._calculate_lame_coeff(lambda_transformation) + lambda_inverse = None + else: + lambda_transformation = lambda x, y, z: transformation(x, y, z) + lambda_lame = CoordSys3D._get_lame_coeff(transformation) + lambda_inverse = None + + if variable_names is None: + if isinstance(transformation, Lambda): + variable_names = ["x1", "x2", "x3"] + elif isinstance(transformation, Str): + if transformation.name == 'spherical': + variable_names = ["r", "theta", "phi"] + elif transformation.name == 'cylindrical': + variable_names = ["r", "theta", "z"] + else: + variable_names = ["x", "y", "z"] + else: + variable_names = ["x", "y", "z"] + if vector_names is None: + vector_names = ["i", "j", "k"] + + # All systems that are defined as 'roots' are unequal, unless + # they have the same name. + # Systems defined at same orientation/position wrt the same + # 'parent' are equal, irrespective of the name. + # This is true even if the same orientation is provided via + # different methods like Axis/Body/Space/Quaternion. + # However, coincident systems may be seen as unequal if + # positioned/oriented wrt different parents, even though + # they may actually be 'coincident' wrt the root system. + if parent is not None: + obj = super().__new__( + cls, Str(name), transformation, parent) + else: + obj = super().__new__( + cls, Str(name), transformation) + obj._name = name + # Initialize the base vectors + + _check_strings('vector_names', vector_names) + vector_names = list(vector_names) + latex_vects = [(r'\mathbf{\hat{%s}_{%s}}' % (x, name)) for + x in vector_names] + pretty_vects = ['%s_%s' % (x, name) for x in vector_names] + + obj._vector_names = vector_names + + v1 = BaseVector(0, obj, pretty_vects[0], latex_vects[0]) + v2 = BaseVector(1, obj, pretty_vects[1], latex_vects[1]) + v3 = BaseVector(2, obj, pretty_vects[2], latex_vects[2]) + + obj._base_vectors = (v1, v2, v3) + + # Initialize the base scalars + + _check_strings('variable_names', vector_names) + variable_names = list(variable_names) + latex_scalars = [(r"\mathbf{{%s}_{%s}}" % (x, name)) for + x in variable_names] + pretty_scalars = ['%s_%s' % (x, name) for x in variable_names] + + obj._variable_names = variable_names + obj._vector_names = vector_names + + x1 = BaseScalar(0, obj, pretty_scalars[0], latex_scalars[0]) + x2 = BaseScalar(1, obj, pretty_scalars[1], latex_scalars[1]) + x3 = BaseScalar(2, obj, pretty_scalars[2], latex_scalars[2]) + + obj._base_scalars = (x1, x2, x3) + + obj._transformation = transformation + obj._transformation_lambda = lambda_transformation + obj._lame_coefficients = lambda_lame(x1, x2, x3) + obj._transformation_from_parent_lambda = lambda_inverse + + setattr(obj, variable_names[0], x1) + setattr(obj, variable_names[1], x2) + setattr(obj, variable_names[2], x3) + + setattr(obj, vector_names[0], v1) + setattr(obj, vector_names[1], v2) + setattr(obj, vector_names[2], v3) + + # Assign params + obj._parent = parent + if obj._parent is not None: + obj._root = obj._parent._root + else: + obj._root = obj + + obj._parent_rotation_matrix = rotation_matrix + obj._origin = origin + + # Return the instance + return obj + + def _sympystr(self, printer): + return self._name + + def __iter__(self): + return iter(self.base_vectors()) + + @staticmethod + def _check_orthogonality(equations): + """ + Helper method for _connect_to_cartesian. It checks if + set of transformation equations create orthogonal curvilinear + coordinate system + + Parameters + ========== + + equations : Lambda + Lambda of transformation equations + + """ + + x1, x2, x3 = symbols("x1, x2, x3", cls=Dummy) + equations = equations(x1, x2, x3) + v1 = Matrix([diff(equations[0], x1), + diff(equations[1], x1), diff(equations[2], x1)]) + + v2 = Matrix([diff(equations[0], x2), + diff(equations[1], x2), diff(equations[2], x2)]) + + v3 = Matrix([diff(equations[0], x3), + diff(equations[1], x3), diff(equations[2], x3)]) + + if any(simplify(i[0] + i[1] + i[2]) == 0 for i in (v1, v2, v3)): + return False + else: + if simplify(v1.dot(v2)) == 0 and simplify(v2.dot(v3)) == 0 \ + and simplify(v3.dot(v1)) == 0: + return True + else: + return False + + @staticmethod + def _set_inv_trans_equations(curv_coord_name): + """ + Store information about inverse transformation equations for + pre-defined coordinate systems. + + Parameters + ========== + + curv_coord_name : str + Name of coordinate system + + """ + if curv_coord_name == 'cartesian': + return lambda x, y, z: (x, y, z) + + if curv_coord_name == 'spherical': + return lambda x, y, z: ( + sqrt(x**2 + y**2 + z**2), + acos(z/sqrt(x**2 + y**2 + z**2)), + atan2(y, x) + ) + if curv_coord_name == 'cylindrical': + return lambda x, y, z: ( + sqrt(x**2 + y**2), + atan2(y, x), + z + ) + raise ValueError('Wrong set of parameters.' + 'Type of coordinate system is defined') + + def _calculate_inv_trans_equations(self): + """ + Helper method for set_coordinate_type. It calculates inverse + transformation equations for given transformations equations. + + """ + x1, x2, x3 = symbols("x1, x2, x3", cls=Dummy, reals=True) + x, y, z = symbols("x, y, z", cls=Dummy) + + equations = self._transformation(x1, x2, x3) + + solved = solve([equations[0] - x, + equations[1] - y, + equations[2] - z], (x1, x2, x3), dict=True)[0] + solved = solved[x1], solved[x2], solved[x3] + self._transformation_from_parent_lambda = \ + lambda x1, x2, x3: tuple(i.subs(list(zip((x, y, z), (x1, x2, x3)))) for i in solved) + + @staticmethod + def _get_lame_coeff(curv_coord_name): + """ + Store information about Lame coefficients for pre-defined + coordinate systems. + + Parameters + ========== + + curv_coord_name : str + Name of coordinate system + + """ + if isinstance(curv_coord_name, str): + if curv_coord_name == 'cartesian': + return lambda x, y, z: (S.One, S.One, S.One) + if curv_coord_name == 'spherical': + return lambda r, theta, phi: (S.One, r, r*sin(theta)) + if curv_coord_name == 'cylindrical': + return lambda r, theta, h: (S.One, r, S.One) + raise ValueError('Wrong set of parameters.' + ' Type of coordinate system is not defined') + return CoordSys3D._calculate_lame_coefficients(curv_coord_name) + + @staticmethod + def _calculate_lame_coeff(equations): + """ + It calculates Lame coefficients + for given transformations equations. + + Parameters + ========== + + equations : Lambda + Lambda of transformation equations. + + """ + return lambda x1, x2, x3: ( + sqrt(diff(equations(x1, x2, x3)[0], x1)**2 + + diff(equations(x1, x2, x3)[1], x1)**2 + + diff(equations(x1, x2, x3)[2], x1)**2), + sqrt(diff(equations(x1, x2, x3)[0], x2)**2 + + diff(equations(x1, x2, x3)[1], x2)**2 + + diff(equations(x1, x2, x3)[2], x2)**2), + sqrt(diff(equations(x1, x2, x3)[0], x3)**2 + + diff(equations(x1, x2, x3)[1], x3)**2 + + diff(equations(x1, x2, x3)[2], x3)**2) + ) + + def _inverse_rotation_matrix(self): + """ + Returns inverse rotation matrix. + """ + return simplify(self._parent_rotation_matrix**-1) + + @staticmethod + def _get_transformation_lambdas(curv_coord_name): + """ + Store information about transformation equations for pre-defined + coordinate systems. + + Parameters + ========== + + curv_coord_name : str + Name of coordinate system + + """ + if isinstance(curv_coord_name, str): + if curv_coord_name == 'cartesian': + return lambda x, y, z: (x, y, z) + if curv_coord_name == 'spherical': + return lambda r, theta, phi: ( + r*sin(theta)*cos(phi), + r*sin(theta)*sin(phi), + r*cos(theta) + ) + if curv_coord_name == 'cylindrical': + return lambda r, theta, h: ( + r*cos(theta), + r*sin(theta), + h + ) + raise ValueError('Wrong set of parameters.' + 'Type of coordinate system is defined') + + @classmethod + def _rotation_trans_equations(cls, matrix, equations): + """ + Returns the transformation equations obtained from rotation matrix. + + Parameters + ========== + + matrix : Matrix + Rotation matrix + + equations : tuple + Transformation equations + + """ + return tuple(matrix * Matrix(equations)) + + @property + def origin(self): + return self._origin + + def base_vectors(self): + return self._base_vectors + + def base_scalars(self): + return self._base_scalars + + def lame_coefficients(self): + return self._lame_coefficients + + def transformation_to_parent(self): + return self._transformation_lambda(*self.base_scalars()) + + def transformation_from_parent(self): + if self._parent is None: + raise ValueError("no parent coordinate system, use " + "`transformation_from_parent_function()`") + return self._transformation_from_parent_lambda( + *self._parent.base_scalars()) + + def transformation_from_parent_function(self): + return self._transformation_from_parent_lambda + + def rotation_matrix(self, other): + """ + Returns the direction cosine matrix(DCM), also known as the + 'rotation matrix' of this coordinate system with respect to + another system. + + If v_a is a vector defined in system 'A' (in matrix format) + and v_b is the same vector defined in system 'B', then + v_a = A.rotation_matrix(B) * v_b. + + A SymPy Matrix is returned. + + Parameters + ========== + + other : CoordSys3D + The system which the DCM is generated to. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy import symbols + >>> q1 = symbols('q1') + >>> N = CoordSys3D('N') + >>> A = N.orient_new_axis('A', q1, N.i) + >>> N.rotation_matrix(A) + Matrix([ + [1, 0, 0], + [0, cos(q1), -sin(q1)], + [0, sin(q1), cos(q1)]]) + + """ + from sympy.vector.functions import _path + if not isinstance(other, CoordSys3D): + raise TypeError(str(other) + + " is not a CoordSys3D") + # Handle special cases + if other == self: + return eye(3) + elif other == self._parent: + return self._parent_rotation_matrix + elif other._parent == self: + return other._parent_rotation_matrix.T + # Else, use tree to calculate position + rootindex, path = _path(self, other) + result = eye(3) + for i in range(rootindex): + result *= path[i]._parent_rotation_matrix + for i in range(rootindex + 1, len(path)): + result *= path[i]._parent_rotation_matrix.T + return result + + @cacheit + def position_wrt(self, other): + """ + Returns the position vector of the origin of this coordinate + system with respect to another Point/CoordSys3D. + + Parameters + ========== + + other : Point/CoordSys3D + If other is a Point, the position of this system's origin + wrt it is returned. If its an instance of CoordSyRect, + the position wrt its origin is returned. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> N = CoordSys3D('N') + >>> N1 = N.locate_new('N1', 10 * N.i) + >>> N.position_wrt(N1) + (-10)*N.i + + """ + return self.origin.position_wrt(other) + + def scalar_map(self, other): + """ + Returns a dictionary which expresses the coordinate variables + (base scalars) of this frame in terms of the variables of + otherframe. + + Parameters + ========== + + otherframe : CoordSys3D + The other system to map the variables to. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy import Symbol + >>> A = CoordSys3D('A') + >>> q = Symbol('q') + >>> B = A.orient_new_axis('B', q, A.k) + >>> A.scalar_map(B) + {A.x: B.x*cos(q) - B.y*sin(q), A.y: B.x*sin(q) + B.y*cos(q), A.z: B.z} + + """ + + origin_coords = tuple(self.position_wrt(other).to_matrix(other)) + relocated_scalars = [x - origin_coords[i] + for i, x in enumerate(other.base_scalars())] + + vars_matrix = (self.rotation_matrix(other) * + Matrix(relocated_scalars)) + return {x: trigsimp(vars_matrix[i]) + for i, x in enumerate(self.base_scalars())} + + def locate_new(self, name, position, vector_names=None, + variable_names=None): + """ + Returns a CoordSys3D with its origin located at the given + position wrt this coordinate system's origin. + + Parameters + ========== + + name : str + The name of the new CoordSys3D instance. + + position : Vector + The position vector of the new system's origin wrt this + one. + + vector_names, variable_names : iterable(optional) + Iterables of 3 strings each, with custom names for base + vectors and base scalars of the new system respectively. + Used for simple str printing. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> A = CoordSys3D('A') + >>> B = A.locate_new('B', 10 * A.i) + >>> B.origin.position_wrt(A.origin) + 10*A.i + + """ + if variable_names is None: + variable_names = self._variable_names + if vector_names is None: + vector_names = self._vector_names + + return CoordSys3D(name, location=position, + vector_names=vector_names, + variable_names=variable_names, + parent=self) + + def orient_new(self, name, orienters, location=None, + vector_names=None, variable_names=None): + """ + Creates a new CoordSys3D oriented in the user-specified way + with respect to this system. + + Please refer to the documentation of the orienter classes + for more information about the orientation procedure. + + Parameters + ========== + + name : str + The name of the new CoordSys3D instance. + + orienters : iterable/Orienter + An Orienter or an iterable of Orienters for orienting the + new coordinate system. + If an Orienter is provided, it is applied to get the new + system. + If an iterable is provided, the orienters will be applied + in the order in which they appear in the iterable. + + location : Vector(optional) + The location of the new coordinate system's origin wrt this + system's origin. If not specified, the origins are taken to + be coincident. + + vector_names, variable_names : iterable(optional) + Iterables of 3 strings each, with custom names for base + vectors and base scalars of the new system respectively. + Used for simple str printing. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy import symbols + >>> q0, q1, q2, q3 = symbols('q0 q1 q2 q3') + >>> N = CoordSys3D('N') + + Using an AxisOrienter + + >>> from sympy.vector import AxisOrienter + >>> axis_orienter = AxisOrienter(q1, N.i + 2 * N.j) + >>> A = N.orient_new('A', (axis_orienter, )) + + Using a BodyOrienter + + >>> from sympy.vector import BodyOrienter + >>> body_orienter = BodyOrienter(q1, q2, q3, '123') + >>> B = N.orient_new('B', (body_orienter, )) + + Using a SpaceOrienter + + >>> from sympy.vector import SpaceOrienter + >>> space_orienter = SpaceOrienter(q1, q2, q3, '312') + >>> C = N.orient_new('C', (space_orienter, )) + + Using a QuaternionOrienter + + >>> from sympy.vector import QuaternionOrienter + >>> q_orienter = QuaternionOrienter(q0, q1, q2, q3) + >>> D = N.orient_new('D', (q_orienter, )) + """ + if variable_names is None: + variable_names = self._variable_names + if vector_names is None: + vector_names = self._vector_names + + if isinstance(orienters, Orienter): + if isinstance(orienters, AxisOrienter): + final_matrix = orienters.rotation_matrix(self) + else: + final_matrix = orienters.rotation_matrix() + # TODO: trigsimp is needed here so that the matrix becomes + # canonical (scalar_map also calls trigsimp; without this, you can + # end up with the same CoordinateSystem that compares differently + # due to a differently formatted matrix). However, this is + # probably not so good for performance. + final_matrix = trigsimp(final_matrix) + else: + final_matrix = Matrix(eye(3)) + for orienter in orienters: + if isinstance(orienter, AxisOrienter): + final_matrix *= orienter.rotation_matrix(self) + else: + final_matrix *= orienter.rotation_matrix() + + return CoordSys3D(name, rotation_matrix=final_matrix, + vector_names=vector_names, + variable_names=variable_names, + location=location, + parent=self) + + def orient_new_axis(self, name, angle, axis, location=None, + vector_names=None, variable_names=None): + """ + Axis rotation is a rotation about an arbitrary axis by + some angle. The angle is supplied as a SymPy expr scalar, and + the axis is supplied as a Vector. + + Parameters + ========== + + name : string + The name of the new coordinate system + + angle : Expr + The angle by which the new system is to be rotated + + axis : Vector + The axis around which the rotation has to be performed + + location : Vector(optional) + The location of the new coordinate system's origin wrt this + system's origin. If not specified, the origins are taken to + be coincident. + + vector_names, variable_names : iterable(optional) + Iterables of 3 strings each, with custom names for base + vectors and base scalars of the new system respectively. + Used for simple str printing. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy import symbols + >>> q1 = symbols('q1') + >>> N = CoordSys3D('N') + >>> B = N.orient_new_axis('B', q1, N.i + 2 * N.j) + + """ + if variable_names is None: + variable_names = self._variable_names + if vector_names is None: + vector_names = self._vector_names + + orienter = AxisOrienter(angle, axis) + return self.orient_new(name, orienter, + location=location, + vector_names=vector_names, + variable_names=variable_names) + + def orient_new_body(self, name, angle1, angle2, angle3, + rotation_order, location=None, + vector_names=None, variable_names=None): + """ + Body orientation takes this coordinate system through three + successive simple rotations. + + Body fixed rotations include both Euler Angles and + Tait-Bryan Angles, see https://en.wikipedia.org/wiki/Euler_angles. + + Parameters + ========== + + name : string + The name of the new coordinate system + + angle1, angle2, angle3 : Expr + Three successive angles to rotate the coordinate system by + + rotation_order : string + String defining the order of axes for rotation + + location : Vector(optional) + The location of the new coordinate system's origin wrt this + system's origin. If not specified, the origins are taken to + be coincident. + + vector_names, variable_names : iterable(optional) + Iterables of 3 strings each, with custom names for base + vectors and base scalars of the new system respectively. + Used for simple str printing. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy import symbols + >>> q1, q2, q3 = symbols('q1 q2 q3') + >>> N = CoordSys3D('N') + + A 'Body' fixed rotation is described by three angles and + three body-fixed rotation axes. To orient a coordinate system D + with respect to N, each sequential rotation is always about + the orthogonal unit vectors fixed to D. For example, a '123' + rotation will specify rotations about N.i, then D.j, then + D.k. (Initially, D.i is same as N.i) + Therefore, + + >>> D = N.orient_new_body('D', q1, q2, q3, '123') + + is same as + + >>> D = N.orient_new_axis('D', q1, N.i) + >>> D = D.orient_new_axis('D', q2, D.j) + >>> D = D.orient_new_axis('D', q3, D.k) + + Acceptable rotation orders are of length 3, expressed in XYZ or + 123, and cannot have a rotation about about an axis twice in a row. + + >>> B = N.orient_new_body('B', q1, q2, q3, '123') + >>> B = N.orient_new_body('B', q1, q2, 0, 'ZXZ') + >>> B = N.orient_new_body('B', 0, 0, 0, 'XYX') + + """ + + orienter = BodyOrienter(angle1, angle2, angle3, rotation_order) + return self.orient_new(name, orienter, + location=location, + vector_names=vector_names, + variable_names=variable_names) + + def orient_new_space(self, name, angle1, angle2, angle3, + rotation_order, location=None, + vector_names=None, variable_names=None): + """ + Space rotation is similar to Body rotation, but the rotations + are applied in the opposite order. + + Parameters + ========== + + name : string + The name of the new coordinate system + + angle1, angle2, angle3 : Expr + Three successive angles to rotate the coordinate system by + + rotation_order : string + String defining the order of axes for rotation + + location : Vector(optional) + The location of the new coordinate system's origin wrt this + system's origin. If not specified, the origins are taken to + be coincident. + + vector_names, variable_names : iterable(optional) + Iterables of 3 strings each, with custom names for base + vectors and base scalars of the new system respectively. + Used for simple str printing. + + See Also + ======== + + CoordSys3D.orient_new_body : method to orient via Euler + angles + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy import symbols + >>> q1, q2, q3 = symbols('q1 q2 q3') + >>> N = CoordSys3D('N') + + To orient a coordinate system D with respect to N, each + sequential rotation is always about N's orthogonal unit vectors. + For example, a '123' rotation will specify rotations about + N.i, then N.j, then N.k. + Therefore, + + >>> D = N.orient_new_space('D', q1, q2, q3, '312') + + is same as + + >>> B = N.orient_new_axis('B', q1, N.i) + >>> C = B.orient_new_axis('C', q2, N.j) + >>> D = C.orient_new_axis('D', q3, N.k) + + """ + + orienter = SpaceOrienter(angle1, angle2, angle3, rotation_order) + return self.orient_new(name, orienter, + location=location, + vector_names=vector_names, + variable_names=variable_names) + + def orient_new_quaternion(self, name, q0, q1, q2, q3, location=None, + vector_names=None, variable_names=None): + """ + Quaternion orientation orients the new CoordSys3D with + Quaternions, defined as a finite rotation about lambda, a unit + vector, by some amount theta. + + This orientation is described by four parameters: + + q0 = cos(theta/2) + + q1 = lambda_x sin(theta/2) + + q2 = lambda_y sin(theta/2) + + q3 = lambda_z sin(theta/2) + + Quaternion does not take in a rotation order. + + Parameters + ========== + + name : string + The name of the new coordinate system + + q0, q1, q2, q3 : Expr + The quaternions to rotate the coordinate system by + + location : Vector(optional) + The location of the new coordinate system's origin wrt this + system's origin. If not specified, the origins are taken to + be coincident. + + vector_names, variable_names : iterable(optional) + Iterables of 3 strings each, with custom names for base + vectors and base scalars of the new system respectively. + Used for simple str printing. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy import symbols + >>> q0, q1, q2, q3 = symbols('q0 q1 q2 q3') + >>> N = CoordSys3D('N') + >>> B = N.orient_new_quaternion('B', q0, q1, q2, q3) + + """ + + orienter = QuaternionOrienter(q0, q1, q2, q3) + return self.orient_new(name, orienter, + location=location, + vector_names=vector_names, + variable_names=variable_names) + + def create_new(self, name, transformation, variable_names=None, vector_names=None): + """ + Returns a CoordSys3D which is connected to self by transformation. + + Parameters + ========== + + name : str + The name of the new CoordSys3D instance. + + transformation : Lambda, Tuple, str + Transformation defined by transformation equations or chosen + from predefined ones. + + vector_names, variable_names : iterable(optional) + Iterables of 3 strings each, with custom names for base + vectors and base scalars of the new system respectively. + Used for simple str printing. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> a = CoordSys3D('a') + >>> b = a.create_new('b', transformation='spherical') + >>> b.transformation_to_parent() + (b.r*sin(b.theta)*cos(b.phi), b.r*sin(b.phi)*sin(b.theta), b.r*cos(b.theta)) + >>> b.transformation_from_parent() + (sqrt(a.x**2 + a.y**2 + a.z**2), acos(a.z/sqrt(a.x**2 + a.y**2 + a.z**2)), atan2(a.y, a.x)) + + """ + return CoordSys3D(name, parent=self, transformation=transformation, + variable_names=variable_names, vector_names=vector_names) + + def __init__(self, name, location=None, rotation_matrix=None, + parent=None, vector_names=None, variable_names=None, + latex_vects=None, pretty_vects=None, latex_scalars=None, + pretty_scalars=None, transformation=None): + # Dummy initializer for setting docstring + pass + + __init__.__doc__ = __new__.__doc__ + + @staticmethod + def _compose_rotation_and_translation(rot, translation, parent): + r = lambda x, y, z: CoordSys3D._rotation_trans_equations(rot, (x, y, z)) + if parent is None: + return r + + dx, dy, dz = [translation.dot(i) for i in parent.base_vectors()] + t = lambda x, y, z: ( + x + dx, + y + dy, + z + dz, + ) + return lambda x, y, z: t(*r(x, y, z)) + + +def _check_strings(arg_name, arg): + errorstr = arg_name + " must be an iterable of 3 string-types" + if len(arg) != 3: + raise ValueError(errorstr) + for s in arg: + if not isinstance(s, str): + raise TypeError(errorstr) + + +# Delayed import to avoid cyclic import problems: +from sympy.vector.vector import BaseVector diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/deloperator.py b/.venv/lib/python3.13/site-packages/sympy/vector/deloperator.py new file mode 100644 index 0000000000000000000000000000000000000000..51c3c0caf42b5e5d372bd65907d8bae2bd563562 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/deloperator.py @@ -0,0 +1,121 @@ +from sympy.core import Basic +from sympy.vector.operators import gradient, divergence, curl + + +class Del(Basic): + """ + Represents the vector differential operator, usually represented in + mathematical expressions as the 'nabla' symbol. + """ + + def __new__(cls): + obj = super().__new__(cls) + obj._name = "delop" + return obj + + def gradient(self, scalar_field, doit=False): + """ + Returns the gradient of the given scalar field, as a + Vector instance. + + Parameters + ========== + + scalar_field : SymPy expression + The scalar field to calculate the gradient of. + + doit : bool + If True, the result is returned after calling .doit() on + each component. Else, the returned expression contains + Derivative instances + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, Del + >>> C = CoordSys3D('C') + >>> delop = Del() + >>> delop.gradient(9) + 0 + >>> delop(C.x*C.y*C.z).doit() + C.y*C.z*C.i + C.x*C.z*C.j + C.x*C.y*C.k + + """ + + return gradient(scalar_field, doit=doit) + + __call__ = gradient + __call__.__doc__ = gradient.__doc__ + + def dot(self, vect, doit=False): + """ + Represents the dot product between this operator and a given + vector - equal to the divergence of the vector field. + + Parameters + ========== + + vect : Vector + The vector whose divergence is to be calculated. + + doit : bool + If True, the result is returned after calling .doit() on + each component. Else, the returned expression contains + Derivative instances + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, Del + >>> delop = Del() + >>> C = CoordSys3D('C') + >>> delop.dot(C.x*C.i) + Derivative(C.x, C.x) + >>> v = C.x*C.y*C.z * (C.i + C.j + C.k) + >>> (delop & v).doit() + C.x*C.y + C.x*C.z + C.y*C.z + + """ + return divergence(vect, doit=doit) + + __and__ = dot + __and__.__doc__ = dot.__doc__ + + def cross(self, vect, doit=False): + """ + Represents the cross product between this operator and a given + vector - equal to the curl of the vector field. + + Parameters + ========== + + vect : Vector + The vector whose curl is to be calculated. + + doit : bool + If True, the result is returned after calling .doit() on + each component. Else, the returned expression contains + Derivative instances + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, Del + >>> C = CoordSys3D('C') + >>> delop = Del() + >>> v = C.x*C.y*C.z * (C.i + C.j + C.k) + >>> delop.cross(v, doit = True) + (-C.x*C.y + C.x*C.z)*C.i + (C.x*C.y - C.y*C.z)*C.j + + (-C.x*C.z + C.y*C.z)*C.k + >>> (delop ^ C.i).doit() + 0 + + """ + + return curl(vect, doit=doit) + + __xor__ = cross + __xor__.__doc__ = cross.__doc__ + + def _sympystr(self, printer): + return self._name diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/dyadic.py b/.venv/lib/python3.13/site-packages/sympy/vector/dyadic.py new file mode 100644 index 0000000000000000000000000000000000000000..980c6e6dad90ac095b7bd6d4228f507a7831b39f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/dyadic.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +from sympy.vector.basisdependent import (BasisDependent, BasisDependentAdd, + BasisDependentMul, BasisDependentZero) +from sympy.core import S, Pow +from sympy.core.expr import AtomicExpr +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +import sympy.vector + + +class Dyadic(BasisDependent): + """ + Super class for all Dyadic-classes. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Dyadic_tensor + .. [2] Kane, T., Levinson, D. Dynamics Theory and Applications. 1985 + McGraw-Hill + + """ + + _op_priority = 13.0 + + _expr_type: type[Dyadic] + _mul_func: type[Dyadic] + _add_func: type[Dyadic] + _zero_func: type[Dyadic] + _base_func: type[Dyadic] + zero: DyadicZero + + @property + def components(self): + """ + Returns the components of this dyadic in the form of a + Python dictionary mapping BaseDyadic instances to the + corresponding measure numbers. + + """ + # The '_components' attribute is defined according to the + # subclass of Dyadic the instance belongs to. + return self._components + + def dot(self, other): + """ + Returns the dot product(also called inner product) of this + Dyadic, with another Dyadic or Vector. + If 'other' is a Dyadic, this returns a Dyadic. Else, it returns + a Vector (unless an error is encountered). + + Parameters + ========== + + other : Dyadic/Vector + The other Dyadic or Vector to take the inner product with + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> N = CoordSys3D('N') + >>> D1 = N.i.outer(N.j) + >>> D2 = N.j.outer(N.j) + >>> D1.dot(D2) + (N.i|N.j) + >>> D1.dot(N.j) + N.i + + """ + + Vector = sympy.vector.Vector + if isinstance(other, BasisDependentZero): + return Vector.zero + elif isinstance(other, Vector): + outvec = Vector.zero + for k, v in self.components.items(): + vect_dot = k.args[1].dot(other) + outvec += vect_dot * v * k.args[0] + return outvec + elif isinstance(other, Dyadic): + outdyad = Dyadic.zero + for k1, v1 in self.components.items(): + for k2, v2 in other.components.items(): + vect_dot = k1.args[1].dot(k2.args[0]) + outer_product = k1.args[0].outer(k2.args[1]) + outdyad += vect_dot * v1 * v2 * outer_product + return outdyad + else: + raise TypeError("Inner product is not defined for " + + str(type(other)) + " and Dyadics.") + + def __and__(self, other): + return self.dot(other) + + __and__.__doc__ = dot.__doc__ + + def cross(self, other): + """ + Returns the cross product between this Dyadic, and a Vector, as a + Vector instance. + + Parameters + ========== + + other : Vector + The Vector that we are crossing this Dyadic with + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> N = CoordSys3D('N') + >>> d = N.i.outer(N.i) + >>> d.cross(N.j) + (N.i|N.k) + + """ + + Vector = sympy.vector.Vector + if other == Vector.zero: + return Dyadic.zero + elif isinstance(other, Vector): + outdyad = Dyadic.zero + for k, v in self.components.items(): + cross_product = k.args[1].cross(other) + outer = k.args[0].outer(cross_product) + outdyad += v * outer + return outdyad + else: + raise TypeError(str(type(other)) + " not supported for " + + "cross with dyadics") + + def __xor__(self, other): + return self.cross(other) + + __xor__.__doc__ = cross.__doc__ + + def to_matrix(self, system, second_system=None): + """ + Returns the matrix form of the dyadic with respect to one or two + coordinate systems. + + Parameters + ========== + + system : CoordSys3D + The coordinate system that the rows and columns of the matrix + correspond to. If a second system is provided, this + only corresponds to the rows of the matrix. + second_system : CoordSys3D, optional, default=None + The coordinate system that the columns of the matrix correspond + to. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> N = CoordSys3D('N') + >>> v = N.i + 2*N.j + >>> d = v.outer(N.i) + >>> d.to_matrix(N) + Matrix([ + [1, 0, 0], + [2, 0, 0], + [0, 0, 0]]) + >>> from sympy import Symbol + >>> q = Symbol('q') + >>> P = N.orient_new_axis('P', q, N.k) + >>> d.to_matrix(N, P) + Matrix([ + [ cos(q), -sin(q), 0], + [2*cos(q), -2*sin(q), 0], + [ 0, 0, 0]]) + + """ + + if second_system is None: + second_system = system + + return Matrix([i.dot(self).dot(j) for i in system for j in + second_system]).reshape(3, 3) + + def _div_helper(one, other): + """ Helper for division involving dyadics """ + if isinstance(one, Dyadic) and isinstance(other, Dyadic): + raise TypeError("Cannot divide two dyadics") + elif isinstance(one, Dyadic): + return DyadicMul(one, Pow(other, S.NegativeOne)) + else: + raise TypeError("Cannot divide by a dyadic") + + +class BaseDyadic(Dyadic, AtomicExpr): + """ + Class to denote a base dyadic tensor component. + """ + + def __new__(cls, vector1, vector2): + Vector = sympy.vector.Vector + BaseVector = sympy.vector.BaseVector + VectorZero = sympy.vector.VectorZero + # Verify arguments + if not isinstance(vector1, (BaseVector, VectorZero)) or \ + not isinstance(vector2, (BaseVector, VectorZero)): + raise TypeError("BaseDyadic cannot be composed of non-base " + + "vectors") + # Handle special case of zero vector + elif vector1 == Vector.zero or vector2 == Vector.zero: + return Dyadic.zero + # Initialize instance + obj = super().__new__(cls, vector1, vector2) + obj._base_instance = obj + obj._measure_number = 1 + obj._components = {obj: S.One} + obj._sys = vector1._sys + obj._pretty_form = ('(' + vector1._pretty_form + '|' + + vector2._pretty_form + ')') + obj._latex_form = (r'\left(' + vector1._latex_form + r"{\middle|}" + + vector2._latex_form + r'\right)') + + return obj + + def _sympystr(self, printer): + return "({}|{})".format( + printer._print(self.args[0]), printer._print(self.args[1])) + + def _sympyrepr(self, printer): + return "BaseDyadic({}, {})".format( + printer._print(self.args[0]), printer._print(self.args[1])) + + +class DyadicMul(BasisDependentMul, Dyadic): + """ Products of scalars and BaseDyadics """ + + def __new__(cls, *args, **options): + obj = BasisDependentMul.__new__(cls, *args, **options) + return obj + + @property + def base_dyadic(self): + """ The BaseDyadic involved in the product. """ + return self._base_instance + + @property + def measure_number(self): + """ The scalar expression involved in the definition of + this DyadicMul. + """ + return self._measure_number + + +class DyadicAdd(BasisDependentAdd, Dyadic): + """ Class to hold dyadic sums """ + + def __new__(cls, *args, **options): + obj = BasisDependentAdd.__new__(cls, *args, **options) + return obj + + def _sympystr(self, printer): + items = list(self.components.items()) + items.sort(key=lambda x: x[0].__str__()) + return " + ".join(printer._print(k * v) for k, v in items) + + +class DyadicZero(BasisDependentZero, Dyadic): + """ + Class to denote a zero dyadic + """ + + _op_priority = 13.1 + _pretty_form = '(0|0)' + _latex_form = r'(\mathbf{\hat{0}}|\mathbf{\hat{0}})' + + def __new__(cls): + obj = BasisDependentZero.__new__(cls) + return obj + + +Dyadic._expr_type = Dyadic +Dyadic._mul_func = DyadicMul +Dyadic._add_func = DyadicAdd +Dyadic._zero_func = DyadicZero +Dyadic._base_func = BaseDyadic +Dyadic.zero = DyadicZero() diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/functions.py b/.venv/lib/python3.13/site-packages/sympy/vector/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..b78df8ae2e182f3e571ca7fa8bfabd39bf99d26e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/functions.py @@ -0,0 +1,513 @@ +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.deloperator import Del +from sympy.vector.scalar import BaseScalar +from sympy.vector.vector import Vector, BaseVector +from sympy.vector.operators import gradient, curl, divergence +from sympy.core.function import diff +from sympy.core.singleton import S +from sympy.integrals.integrals import integrate +from sympy.core import sympify +from sympy.vector.dyadic import Dyadic + + +def express(expr, system, system2=None, variables=False): + """ + Global function for 'express' functionality. + + Re-expresses a Vector, Dyadic or scalar(sympyfiable) in the given + coordinate system. + + If 'variables' is True, then the coordinate variables (base scalars) + of other coordinate systems present in the vector/scalar field or + dyadic are also substituted in terms of the base scalars of the + given system. + + Parameters + ========== + + expr : Vector/Dyadic/scalar(sympyfiable) + The expression to re-express in CoordSys3D 'system' + + system: CoordSys3D + The coordinate system the expr is to be expressed in + + system2: CoordSys3D + The other coordinate system required for re-expression + (only for a Dyadic Expr) + + variables : boolean + Specifies whether to substitute the coordinate variables present + in expr, in terms of those of parameter system + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy import Symbol, cos, sin + >>> N = CoordSys3D('N') + >>> q = Symbol('q') + >>> B = N.orient_new_axis('B', q, N.k) + >>> from sympy.vector import express + >>> express(B.i, N) + (cos(q))*N.i + (sin(q))*N.j + >>> express(N.x, B, variables=True) + B.x*cos(q) - B.y*sin(q) + >>> d = N.i.outer(N.i) + >>> express(d, B, N) == (cos(q))*(B.i|N.i) + (-sin(q))*(B.j|N.i) + True + + """ + + if expr in (0, Vector.zero): + return expr + + if not isinstance(system, CoordSys3D): + raise TypeError("system should be a CoordSys3D \ + instance") + + if isinstance(expr, Vector): + if system2 is not None: + raise ValueError("system2 should not be provided for \ + Vectors") + # Given expr is a Vector + if variables: + # If variables attribute is True, substitute + # the coordinate variables in the Vector + system_list = {x.system for x in expr.atoms(BaseScalar, BaseVector)} - {system} + subs_dict = {} + for f in system_list: + subs_dict.update(f.scalar_map(system)) + expr = expr.subs(subs_dict) + # Re-express in this coordinate system + outvec = Vector.zero + parts = expr.separate() + for x in parts: + if x != system: + temp = system.rotation_matrix(x) * parts[x].to_matrix(x) + outvec += matrix_to_vector(temp, system) + else: + outvec += parts[x] + return outvec + + elif isinstance(expr, Dyadic): + if system2 is None: + system2 = system + if not isinstance(system2, CoordSys3D): + raise TypeError("system2 should be a CoordSys3D \ + instance") + outdyad = Dyadic.zero + var = variables + for k, v in expr.components.items(): + outdyad += (express(v, system, variables=var) * + (express(k.args[0], system, variables=var) | + express(k.args[1], system2, variables=var))) + + return outdyad + + else: + if system2 is not None: + raise ValueError("system2 should not be provided for \ + Vectors") + if variables: + # Given expr is a scalar field + system_set = set() + expr = sympify(expr) + # Substitute all the coordinate variables + for x in expr.atoms(BaseScalar): + if x.system != system: + system_set.add(x.system) + subs_dict = {} + for f in system_set: + subs_dict.update(f.scalar_map(system)) + return expr.subs(subs_dict) + return expr + + +def directional_derivative(field, direction_vector): + """ + Returns the directional derivative of a scalar or vector field computed + along a given vector in coordinate system which parameters are expressed. + + Parameters + ========== + + field : Vector or Scalar + The scalar or vector field to compute the directional derivative of + + direction_vector : Vector + The vector to calculated directional derivative along them. + + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, directional_derivative + >>> R = CoordSys3D('R') + >>> f1 = R.x*R.y*R.z + >>> v1 = 3*R.i + 4*R.j + R.k + >>> directional_derivative(f1, v1) + R.x*R.y + 4*R.x*R.z + 3*R.y*R.z + >>> f2 = 5*R.x**2*R.z + >>> directional_derivative(f2, v1) + 5*R.x**2 + 30*R.x*R.z + + """ + from sympy.vector.operators import _get_coord_systems + coord_sys = _get_coord_systems(field) + if len(coord_sys) > 0: + # TODO: This gets a random coordinate system in case of multiple ones: + coord_sys = next(iter(coord_sys)) + field = express(field, coord_sys, variables=True) + i, j, k = coord_sys.base_vectors() + x, y, z = coord_sys.base_scalars() + out = Vector.dot(direction_vector, i) * diff(field, x) + out += Vector.dot(direction_vector, j) * diff(field, y) + out += Vector.dot(direction_vector, k) * diff(field, z) + if out == 0 and isinstance(field, Vector): + out = Vector.zero + return out + elif isinstance(field, Vector): + return Vector.zero + else: + return S.Zero + + +def laplacian(expr): + """ + Return the laplacian of the given field computed in terms of + the base scalars of the given coordinate system. + + Parameters + ========== + + expr : SymPy Expr or Vector + expr denotes a scalar or vector field. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, laplacian + >>> R = CoordSys3D('R') + >>> f = R.x**2*R.y**5*R.z + >>> laplacian(f) + 20*R.x**2*R.y**3*R.z + 2*R.y**5*R.z + >>> f = R.x**2*R.i + R.y**3*R.j + R.z**4*R.k + >>> laplacian(f) + 2*R.i + 6*R.y*R.j + 12*R.z**2*R.k + + """ + + delop = Del() + if expr.is_Vector: + return (gradient(divergence(expr)) - curl(curl(expr))).doit() + return delop.dot(delop(expr)).doit() + + +def is_conservative(field): + """ + Checks if a field is conservative. + + Parameters + ========== + + field : Vector + The field to check for conservative property + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy.vector import is_conservative + >>> R = CoordSys3D('R') + >>> is_conservative(R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k) + True + >>> is_conservative(R.z*R.j) + False + + """ + + # Field is conservative irrespective of system + # Take the first coordinate system in the result of the + # separate method of Vector + if not isinstance(field, Vector): + raise TypeError("field should be a Vector") + if field == Vector.zero: + return True + return curl(field).simplify() == Vector.zero + + +def is_solenoidal(field): + """ + Checks if a field is solenoidal. + + Parameters + ========== + + field : Vector + The field to check for solenoidal property + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy.vector import is_solenoidal + >>> R = CoordSys3D('R') + >>> is_solenoidal(R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k) + True + >>> is_solenoidal(R.y * R.j) + False + + """ + + # Field is solenoidal irrespective of system + # Take the first coordinate system in the result of the + # separate method in Vector + if not isinstance(field, Vector): + raise TypeError("field should be a Vector") + if field == Vector.zero: + return True + return divergence(field).simplify() is S.Zero + + +def scalar_potential(field, coord_sys): + """ + Returns the scalar potential function of a field in a given + coordinate system (without the added integration constant). + + Parameters + ========== + + field : Vector + The vector field whose scalar potential function is to be + calculated + + coord_sys : CoordSys3D + The coordinate system to do the calculation in + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy.vector import scalar_potential, gradient + >>> R = CoordSys3D('R') + >>> scalar_potential(R.k, R) == R.z + True + >>> scalar_field = 2*R.x**2*R.y*R.z + >>> grad_field = gradient(scalar_field) + >>> scalar_potential(grad_field, R) + 2*R.x**2*R.y*R.z + + """ + + # Check whether field is conservative + if not is_conservative(field): + raise ValueError("Field is not conservative") + if field == Vector.zero: + return S.Zero + # Express the field exntirely in coord_sys + # Substitute coordinate variables also + if not isinstance(coord_sys, CoordSys3D): + raise TypeError("coord_sys must be a CoordSys3D") + field = express(field, coord_sys, variables=True) + dimensions = coord_sys.base_vectors() + scalars = coord_sys.base_scalars() + # Calculate scalar potential function + temp_function = integrate(field.dot(dimensions[0]), scalars[0]) + for i, dim in enumerate(dimensions[1:]): + partial_diff = diff(temp_function, scalars[i + 1]) + partial_diff = field.dot(dim) - partial_diff + temp_function += integrate(partial_diff, scalars[i + 1]) + return temp_function + + +def scalar_potential_difference(field, coord_sys, point1, point2): + """ + Returns the scalar potential difference between two points in a + certain coordinate system, wrt a given field. + + If a scalar field is provided, its values at the two points are + considered. If a conservative vector field is provided, the values + of its scalar potential function at the two points are used. + + Returns (potential at point2) - (potential at point1) + + The position vectors of the two Points are calculated wrt the + origin of the coordinate system provided. + + Parameters + ========== + + field : Vector/Expr + The field to calculate wrt + + coord_sys : CoordSys3D + The coordinate system to do the calculations in + + point1 : Point + The initial Point in given coordinate system + + position2 : Point + The second Point in the given coordinate system + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy.vector import scalar_potential_difference + >>> R = CoordSys3D('R') + >>> P = R.origin.locate_new('P', R.x*R.i + R.y*R.j + R.z*R.k) + >>> vectfield = 4*R.x*R.y*R.i + 2*R.x**2*R.j + >>> scalar_potential_difference(vectfield, R, R.origin, P) + 2*R.x**2*R.y + >>> Q = R.origin.locate_new('O', 3*R.i + R.j + 2*R.k) + >>> scalar_potential_difference(vectfield, R, P, Q) + -2*R.x**2*R.y + 18 + + """ + + if not isinstance(coord_sys, CoordSys3D): + raise TypeError("coord_sys must be a CoordSys3D") + if isinstance(field, Vector): + # Get the scalar potential function + scalar_fn = scalar_potential(field, coord_sys) + else: + # Field is a scalar + scalar_fn = field + # Express positions in required coordinate system + origin = coord_sys.origin + position1 = express(point1.position_wrt(origin), coord_sys, + variables=True) + position2 = express(point2.position_wrt(origin), coord_sys, + variables=True) + # Get the two positions as substitution dicts for coordinate variables + subs_dict1 = {} + subs_dict2 = {} + scalars = coord_sys.base_scalars() + for i, x in enumerate(coord_sys.base_vectors()): + subs_dict1[scalars[i]] = x.dot(position1) + subs_dict2[scalars[i]] = x.dot(position2) + return scalar_fn.subs(subs_dict2) - scalar_fn.subs(subs_dict1) + + +def matrix_to_vector(matrix, system): + """ + Converts a vector in matrix form to a Vector instance. + + It is assumed that the elements of the Matrix represent the + measure numbers of the components of the vector along basis + vectors of 'system'. + + Parameters + ========== + + matrix : SymPy Matrix, Dimensions: (3, 1) + The matrix to be converted to a vector + + system : CoordSys3D + The coordinate system the vector is to be defined in + + Examples + ======== + + >>> from sympy import ImmutableMatrix as Matrix + >>> m = Matrix([1, 2, 3]) + >>> from sympy.vector import CoordSys3D, matrix_to_vector + >>> C = CoordSys3D('C') + >>> v = matrix_to_vector(m, C) + >>> v + C.i + 2*C.j + 3*C.k + >>> v.to_matrix(C) == m + True + + """ + + outvec = Vector.zero + vects = system.base_vectors() + for i, x in enumerate(matrix): + outvec += x * vects[i] + return outvec + + +def _path(from_object, to_object): + """ + Calculates the 'path' of objects starting from 'from_object' + to 'to_object', along with the index of the first common + ancestor in the tree. + + Returns (index, list) tuple. + """ + + if from_object._root != to_object._root: + raise ValueError("No connecting path found between " + + str(from_object) + " and " + str(to_object)) + + other_path = [] + obj = to_object + while obj._parent is not None: + other_path.append(obj) + obj = obj._parent + other_path.append(obj) + object_set = set(other_path) + from_path = [] + obj = from_object + while obj not in object_set: + from_path.append(obj) + obj = obj._parent + index = len(from_path) + from_path.extend(other_path[other_path.index(obj)::-1]) + return index, from_path + + +def orthogonalize(*vlist, orthonormal=False): + """ + Takes a sequence of independent vectors and orthogonalizes them + using the Gram - Schmidt process. Returns a list of + orthogonal or orthonormal vectors. + + Parameters + ========== + + vlist : sequence of independent vectors to be made orthogonal. + + orthonormal : Optional parameter + Set to True if the vectors returned should be + orthonormal. + Default: False + + Examples + ======== + + >>> from sympy.vector.coordsysrect import CoordSys3D + >>> from sympy.vector.functions import orthogonalize + >>> C = CoordSys3D('C') + >>> i, j, k = C.base_vectors() + >>> v1 = i + 2*j + >>> v2 = 2*i + 3*j + >>> orthogonalize(v1, v2) + [C.i + 2*C.j, 2/5*C.i + (-1/5)*C.j] + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gram-Schmidt_process + + """ + + if not all(isinstance(vec, Vector) for vec in vlist): + raise TypeError('Each element must be of Type Vector') + + ortho_vlist = [] + for i, term in enumerate(vlist): + for j in range(i): + term -= ortho_vlist[j].projection(vlist[i]) + # TODO : The following line introduces a performance issue + # and needs to be changed once a good solution for issue #10279 is + # found. + if term.equals(Vector.zero): + raise ValueError("Vector set not linearly independent") + ortho_vlist.append(term) + + if orthonormal: + ortho_vlist = [vec.normalize() for vec in ortho_vlist] + + return ortho_vlist diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/implicitregion.py b/.venv/lib/python3.13/site-packages/sympy/vector/implicitregion.py new file mode 100644 index 0000000000000000000000000000000000000000..ed2d55a1be8b1eaca71d08b632a94886a2b0269c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/implicitregion.py @@ -0,0 +1,506 @@ +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.complexes import sign +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.polys.polytools import gcd +from sympy.sets.sets import Complement +from sympy.core import Basic, Tuple, diff, expand, Eq, Integer +from sympy.core.sorting import ordered +from sympy.core.symbol import _symbol +from sympy.solvers import solveset, nonlinsolve, diophantine +from sympy.polys import total_degree +from sympy.geometry import Point +from sympy.ntheory.factor_ import core + + +class ImplicitRegion(Basic): + """ + Represents an implicit region in space. + + Examples + ======== + + >>> from sympy import Eq + >>> from sympy.abc import x, y, z, t + >>> from sympy.vector import ImplicitRegion + + >>> ImplicitRegion((x, y), x**2 + y**2 - 4) + ImplicitRegion((x, y), x**2 + y**2 - 4) + >>> ImplicitRegion((x, y), Eq(y*x, 1)) + ImplicitRegion((x, y), x*y - 1) + + >>> parabola = ImplicitRegion((x, y), y**2 - 4*x) + >>> parabola.degree + 2 + >>> parabola.equation + -4*x + y**2 + >>> parabola.rational_parametrization(t) + (4/t**2, 4/t) + + >>> r = ImplicitRegion((x, y, z), Eq(z, x**2 + y**2)) + >>> r.variables + (x, y, z) + >>> r.singular_points() + EmptySet + >>> r.regular_point() + (-10, -10, 200) + + Parameters + ========== + + variables : tuple to map variables in implicit equation to base scalars. + + equation : An expression or Eq denoting the implicit equation of the region. + + """ + def __new__(cls, variables, equation): + if not isinstance(variables, Tuple): + variables = Tuple(*variables) + + if isinstance(equation, Eq): + equation = equation.lhs - equation.rhs + + return super().__new__(cls, variables, equation) + + @property + def variables(self): + return self.args[0] + + @property + def equation(self): + return self.args[1] + + @property + def degree(self): + return total_degree(self.equation) + + def regular_point(self): + """ + Returns a point on the implicit region. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.vector import ImplicitRegion + >>> circle = ImplicitRegion((x, y), (x + 2)**2 + (y - 3)**2 - 16) + >>> circle.regular_point() + (-2, -1) + >>> parabola = ImplicitRegion((x, y), x**2 - 4*y) + >>> parabola.regular_point() + (0, 0) + >>> r = ImplicitRegion((x, y, z), (x + y + z)**4) + >>> r.regular_point() + (-10, -10, 20) + + References + ========== + + - Erik Hillgarter, "Rational Points on Conics", Diploma Thesis, RISC-Linz, + J. Kepler Universitat Linz, 1996. Available: + https://www3.risc.jku.at/publications/download/risc_1355/Rational%20Points%20on%20Conics.pdf + + """ + equation = self.equation + + if len(self.variables) == 1: + return (list(solveset(equation, self.variables[0], domain=S.Reals))[0],) + elif len(self.variables) == 2: + + if self.degree == 2: + coeffs = a, b, c, d, e, f = conic_coeff(self.variables, equation) + + if b**2 == 4*a*c: + x_reg, y_reg = self._regular_point_parabola(*coeffs) + else: + x_reg, y_reg = self._regular_point_ellipse(*coeffs) + return x_reg, y_reg + + if len(self.variables) == 3: + x, y, z = self.variables + + for x_reg in range(-10, 10): + for y_reg in range(-10, 10): + if not solveset(equation.subs({x: x_reg, y: y_reg}), self.variables[2], domain=S.Reals).is_empty: + return (x_reg, y_reg, list(solveset(equation.subs({x: x_reg, y: y_reg})))[0]) + + if len(self.singular_points()) != 0: + return list[self.singular_points()][0] + + raise NotImplementedError() + + def _regular_point_parabola(self, a, b, c, d, e, f): + ok = (a, d) != (0, 0) and (c, e) != (0, 0) and b**2 == 4*a*c and (a, c) != (0, 0) + + if not ok: + raise ValueError("Rational Point on the conic does not exist") + + if a != 0: + d_dash, f_dash = (4*a*e - 2*b*d, 4*a*f - d**2) + if d_dash != 0: + y_reg = -f_dash/d_dash + x_reg = -(d + b*y_reg)/(2*a) + else: + ok = False + elif c != 0: + d_dash, f_dash = (4*c*d - 2*b*e, 4*c*f - e**2) + if d_dash != 0: + x_reg = -f_dash/d_dash + y_reg = -(e + b*x_reg)/(2*c) + else: + ok = False + + if ok: + return x_reg, y_reg + else: + raise ValueError("Rational Point on the conic does not exist") + + def _regular_point_ellipse(self, a, b, c, d, e, f): + D = 4*a*c - b**2 + ok = D + + if not ok: + raise ValueError("Rational Point on the conic does not exist") + + if a == 0 and c == 0: + K = -1 + L = 4*(d*e - b*f) + elif c != 0: + K = D + L = 4*c**2*d**2 - 4*b*c*d*e + 4*a*c*e**2 + 4*b**2*c*f - 16*a*c**2*f + else: + K = D + L = 4*a**2*e**2 - 4*b*a*d*e + 4*b**2*a*f + + ok = L != 0 and not(K > 0 and L < 0) + if not ok: + raise ValueError("Rational Point on the conic does not exist") + + K = Rational(K).limit_denominator(10**12) + L = Rational(L).limit_denominator(10**12) + + k1, k2 = K.p, K.q + l1, l2 = L.p, L.q + g = gcd(k2, l2) + + a1 = (l2*k2)/g + b1 = (k1*l2)/g + c1 = -(l1*k2)/g + a2 = sign(a1)*core(abs(a1), 2) + r1 = sqrt(a1/a2) + b2 = sign(b1)*core(abs(b1), 2) + r2 = sqrt(b1/b2) + c2 = sign(c1)*core(abs(c1), 2) + r3 = sqrt(c1/c2) + + g = gcd(gcd(a2, b2), c2) + a2 = a2/g + b2 = b2/g + c2 = c2/g + + g1 = gcd(a2, b2) + a2 = a2/g1 + b2 = b2/g1 + c2 = c2*g1 + + g2 = gcd(a2,c2) + a2 = a2/g2 + b2 = b2*g2 + c2 = c2/g2 + + g3 = gcd(b2, c2) + a2 = a2*g3 + b2 = b2/g3 + c2 = c2/g3 + + x, y, z = symbols("x y z") + eq = a2*x**2 + b2*y**2 + c2*z**2 + + solutions = diophantine(eq) + + if len(solutions) == 0: + raise ValueError("Rational Point on the conic does not exist") + + flag = False + for sol in solutions: + syms = Tuple(*sol).free_symbols + rep = dict.fromkeys(syms, 3) + sol_z = sol[2] + + if sol_z == 0: + flag = True + continue + + if not isinstance(sol_z, (int, Integer)): + syms_z = sol_z.free_symbols + + if len(syms_z) == 1: + p = next(iter(syms_z)) + p_values = Complement(S.Integers, solveset(Eq(sol_z, 0), p, S.Integers)) + rep[p] = next(iter(p_values)) + + if len(syms_z) == 2: + p, q = list(ordered(syms_z)) + + for i in S.Integers: + subs_sol_z = sol_z.subs(p, i) + q_values = Complement(S.Integers, solveset(Eq(subs_sol_z, 0), q, S.Integers)) + + if not q_values.is_empty: + rep[p] = i + rep[q] = next(iter(q_values)) + break + + if len(syms) != 0: + x, y, z = tuple(s.subs(rep) for s in sol) + else: + x, y, z = sol + flag = False + break + + if flag: + raise ValueError("Rational Point on the conic does not exist") + + x = (x*g3)/r1 + y = (y*g2)/r2 + z = (z*g1)/r3 + x = x/z + y = y/z + + if a == 0 and c == 0: + x_reg = (x + y - 2*e)/(2*b) + y_reg = (x - y - 2*d)/(2*b) + elif c != 0: + x_reg = (x - 2*d*c + b*e)/K + y_reg = (y - b*x_reg - e)/(2*c) + else: + y_reg = (x - 2*e*a + b*d)/K + x_reg = (y - b*y_reg - d)/(2*a) + + return x_reg, y_reg + + def singular_points(self): + """ + Returns a set of singular points of the region. + + The singular points are those points on the region + where all partial derivatives vanish. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy.vector import ImplicitRegion + >>> I = ImplicitRegion((x, y), (y-1)**2 -x**3 + 2*x**2 -x) + >>> I.singular_points() + {(1, 1)} + + """ + eq_list = [self.equation] + for var in self.variables: + eq_list += [diff(self.equation, var)] + + return nonlinsolve(eq_list, list(self.variables)) + + def multiplicity(self, point): + """ + Returns the multiplicity of a singular point on the region. + + A singular point (x,y) of region is said to be of multiplicity m + if all the partial derivatives off to order m - 1 vanish there. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.vector import ImplicitRegion + >>> I = ImplicitRegion((x, y, z), x**2 + y**3 - z**4) + >>> I.singular_points() + {(0, 0, 0)} + >>> I.multiplicity((0, 0, 0)) + 2 + + """ + if isinstance(point, Point): + point = point.args + + modified_eq = self.equation + + for i, var in enumerate(self.variables): + modified_eq = modified_eq.subs(var, var + point[i]) + modified_eq = expand(modified_eq) + + if len(modified_eq.args) != 0: + terms = modified_eq.args + m = min(total_degree(term) for term in terms) + else: + terms = modified_eq + m = total_degree(terms) + + return m + + def rational_parametrization(self, parameters=('t', 's'), reg_point=None): + """ + Returns the rational parametrization of implicit region. + + Examples + ======== + + >>> from sympy import Eq + >>> from sympy.abc import x, y, z, s, t + >>> from sympy.vector import ImplicitRegion + + >>> parabola = ImplicitRegion((x, y), y**2 - 4*x) + >>> parabola.rational_parametrization() + (4/t**2, 4/t) + + >>> circle = ImplicitRegion((x, y), Eq(x**2 + y**2, 4)) + >>> circle.rational_parametrization() + (4*t/(t**2 + 1), 4*t**2/(t**2 + 1) - 2) + + >>> I = ImplicitRegion((x, y), x**3 + x**2 - y**2) + >>> I.rational_parametrization() + (t**2 - 1, t*(t**2 - 1)) + + >>> cubic_curve = ImplicitRegion((x, y), x**3 + x**2 - y**2) + >>> cubic_curve.rational_parametrization(parameters=(t)) + (t**2 - 1, t*(t**2 - 1)) + + >>> sphere = ImplicitRegion((x, y, z), x**2 + y**2 + z**2 - 4) + >>> sphere.rational_parametrization(parameters=(t, s)) + (-2 + 4/(s**2 + t**2 + 1), 4*s/(s**2 + t**2 + 1), 4*t/(s**2 + t**2 + 1)) + + For some conics, regular_points() is unable to find a point on curve. + To calulcate the parametric representation in such cases, user need + to determine a point on the region and pass it using reg_point. + + >>> c = ImplicitRegion((x, y), (x - 1/2)**2 + (y)**2 - (1/4)**2) + >>> c.rational_parametrization(reg_point=(3/4, 0)) + (0.75 - 0.5/(t**2 + 1), -0.5*t/(t**2 + 1)) + + References + ========== + + - Christoph M. Hoffmann, "Conversion Methods between Parametric and + Implicit Curves and Surfaces", Purdue e-Pubs, 1990. Available: + https://docs.lib.purdue.edu/cgi/viewcontent.cgi?article=1827&context=cstech + + """ + equation = self.equation + degree = self.degree + + if degree == 1: + if len(self.variables) == 1: + return (equation,) + elif len(self.variables) == 2: + x, y = self.variables + y_par = list(solveset(equation, y))[0] + return x, y_par + else: + raise NotImplementedError() + + point = () + + # Finding the (n - 1) fold point of the monoid of degree + if degree == 2: + # For degree 2 curves, either a regular point or a singular point can be used. + if reg_point is not None: + # Using point provided by the user as regular point + point = reg_point + else: + if len(self.singular_points()) != 0: + point = list(self.singular_points())[0] + else: + point = self.regular_point() + + if len(self.singular_points()) != 0: + singular_points = self.singular_points() + for spoint in singular_points: + syms = Tuple(*spoint).free_symbols + rep = dict.fromkeys(syms, 2) + + if len(syms) != 0: + spoint = tuple(s.subs(rep) for s in spoint) + + if self.multiplicity(spoint) == degree - 1: + point = spoint + break + + if len(point) == 0: + # The region in not a monoid + raise NotImplementedError() + + modified_eq = equation + + # Shifting the region such that fold point moves to origin + for i, var in enumerate(self.variables): + modified_eq = modified_eq.subs(var, var + point[i]) + modified_eq = expand(modified_eq) + + hn = hn_1 = 0 + for term in modified_eq.args: + if total_degree(term) == degree: + hn += term + else: + hn_1 += term + + hn_1 = -1*hn_1 + + if not isinstance(parameters, tuple): + parameters = (parameters,) + + if len(self.variables) == 2: + + parameter1 = parameters[0] + if parameter1 == 's': + # To avoid name conflict between parameters + s = _symbol('s_', real=True) + else: + s = _symbol('s', real=True) + t = _symbol(parameter1, real=True) + + hn = hn.subs({self.variables[0]: s, self.variables[1]: t}) + hn_1 = hn_1.subs({self.variables[0]: s, self.variables[1]: t}) + + x_par = (s*(hn_1/hn)).subs(s, 1) + point[0] + y_par = (t*(hn_1/hn)).subs(s, 1) + point[1] + + return x_par, y_par + + elif len(self.variables) == 3: + + parameter1, parameter2 = parameters + if 'r' in parameters: + # To avoid name conflict between parameters + r = _symbol('r_', real=True) + else: + r = _symbol('r', real=True) + s = _symbol(parameter2, real=True) + t = _symbol(parameter1, real=True) + + hn = hn.subs({self.variables[0]: r, self.variables[1]: s, self.variables[2]: t}) + hn_1 = hn_1.subs({self.variables[0]: r, self.variables[1]: s, self.variables[2]: t}) + + x_par = (r*(hn_1/hn)).subs(r, 1) + point[0] + y_par = (s*(hn_1/hn)).subs(r, 1) + point[1] + z_par = (t*(hn_1/hn)).subs(r, 1) + point[2] + + return x_par, y_par, z_par + + raise NotImplementedError() + +def conic_coeff(variables, equation): + if total_degree(equation) != 2: + raise ValueError() + x = variables[0] + y = variables[1] + + equation = expand(equation) + a = equation.coeff(x**2) + b = equation.coeff(x*y) + c = equation.coeff(y**2) + d = equation.coeff(x, 1).coeff(y, 0) + e = equation.coeff(y, 1).coeff(x, 0) + f = equation.coeff(x, 0).coeff(y, 0) + return a, b, c, d, e, f diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/integrals.py b/.venv/lib/python3.13/site-packages/sympy/vector/integrals.py new file mode 100644 index 0000000000000000000000000000000000000000..a6451c182f214b20b1105eb0a4dc243455c9d126 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/integrals.py @@ -0,0 +1,206 @@ +from sympy.core import Basic, diff +from sympy.core.singleton import S +from sympy.core.sorting import default_sort_key +from sympy.matrices import Matrix +from sympy.integrals import Integral, integrate +from sympy.geometry.entity import GeometryEntity +from sympy.simplify.simplify import simplify +from sympy.utilities.iterables import topological_sort +from sympy.vector import (CoordSys3D, Vector, ParametricRegion, + parametric_region_list, ImplicitRegion) +from sympy.vector.operators import _get_coord_systems + + +class ParametricIntegral(Basic): + """ + Represents integral of a scalar or vector field + over a Parametric Region + + Examples + ======== + + >>> from sympy import cos, sin, pi + >>> from sympy.vector import CoordSys3D, ParametricRegion, ParametricIntegral + >>> from sympy.abc import r, t, theta, phi + + >>> C = CoordSys3D('C') + >>> curve = ParametricRegion((3*t - 2, t + 1), (t, 1, 2)) + >>> ParametricIntegral(C.x, curve) + 5*sqrt(10)/2 + >>> length = ParametricIntegral(1, curve) + >>> length + sqrt(10) + >>> semisphere = ParametricRegion((2*sin(phi)*cos(theta), 2*sin(phi)*sin(theta), 2*cos(phi)),\ + (theta, 0, 2*pi), (phi, 0, pi/2)) + >>> ParametricIntegral(C.z, semisphere) + 8*pi + + >>> ParametricIntegral(C.j + C.k, ParametricRegion((r*cos(theta), r*sin(theta)), r, theta)) + 0 + + """ + + def __new__(cls, field, parametricregion): + + coord_set = _get_coord_systems(field) + + if len(coord_set) == 0: + coord_sys = CoordSys3D('C') + elif len(coord_set) > 1: + raise ValueError + else: + coord_sys = next(iter(coord_set)) + + if parametricregion.dimensions == 0: + return S.Zero + + base_vectors = coord_sys.base_vectors() + base_scalars = coord_sys.base_scalars() + + parametricfield = field + + r = Vector.zero + for i in range(len(parametricregion.definition)): + r += base_vectors[i]*parametricregion.definition[i] + + if len(coord_set) != 0: + for i in range(len(parametricregion.definition)): + parametricfield = parametricfield.subs(base_scalars[i], parametricregion.definition[i]) + + if parametricregion.dimensions == 1: + parameter = parametricregion.parameters[0] + + r_diff = diff(r, parameter) + lower, upper = parametricregion.limits[parameter][0], parametricregion.limits[parameter][1] + + if isinstance(parametricfield, Vector): + integrand = simplify(r_diff.dot(parametricfield)) + else: + integrand = simplify(r_diff.magnitude()*parametricfield) + + result = integrate(integrand, (parameter, lower, upper)) + + elif parametricregion.dimensions == 2: + u, v = cls._bounds_case(parametricregion.parameters, parametricregion.limits) + + r_u = diff(r, u) + r_v = diff(r, v) + normal_vector = simplify(r_u.cross(r_v)) + + if isinstance(parametricfield, Vector): + integrand = parametricfield.dot(normal_vector) + else: + integrand = parametricfield*normal_vector.magnitude() + + integrand = simplify(integrand) + + lower_u, upper_u = parametricregion.limits[u][0], parametricregion.limits[u][1] + lower_v, upper_v = parametricregion.limits[v][0], parametricregion.limits[v][1] + + result = integrate(integrand, (u, lower_u, upper_u), (v, lower_v, upper_v)) + + else: + variables = cls._bounds_case(parametricregion.parameters, parametricregion.limits) + coeff = Matrix(parametricregion.definition).jacobian(variables).det() + integrand = simplify(parametricfield*coeff) + + l = [(var, parametricregion.limits[var][0], parametricregion.limits[var][1]) for var in variables] + result = integrate(integrand, *l) + + if not isinstance(result, Integral): + return result + else: + return super().__new__(cls, field, parametricregion) + + @classmethod + def _bounds_case(cls, parameters, limits): + + V = list(limits.keys()) + E = [] + + for p in V: + lower_p = limits[p][0] + upper_p = limits[p][1] + + lower_p = lower_p.atoms() + upper_p = upper_p.atoms() + E.extend((p, q) for q in V if p != q and + (lower_p.issuperset({q}) or upper_p.issuperset({q}))) + + if not E: + return parameters + else: + return topological_sort((V, E), key=default_sort_key) + + @property + def field(self): + return self.args[0] + + @property + def parametricregion(self): + return self.args[1] + + +def vector_integrate(field, *region): + """ + Compute the integral of a vector/scalar field + over a a region or a set of parameters. + + Examples + ======== + >>> from sympy.vector import CoordSys3D, ParametricRegion, vector_integrate + >>> from sympy.abc import x, y, t + >>> C = CoordSys3D('C') + + >>> region = ParametricRegion((t, t**2), (t, 1, 5)) + >>> vector_integrate(C.x*C.i, region) + 12 + + Integrals over some objects of geometry module can also be calculated. + + >>> from sympy.geometry import Point, Circle, Triangle + >>> c = Circle(Point(0, 2), 5) + >>> vector_integrate(C.x**2 + C.y**2, c) + 290*pi + >>> triangle = Triangle(Point(-2, 3), Point(2, 3), Point(0, 5)) + >>> vector_integrate(3*C.x**2*C.y*C.i + C.j, triangle) + -8 + + Integrals over some simple implicit regions can be computed. But in most cases, + it takes too long to compute over them. This is due to the expressions of parametric + representation becoming large. + + >>> from sympy.vector import ImplicitRegion + >>> c2 = ImplicitRegion((x, y), (x - 2)**2 + (y - 1)**2 - 9) + >>> vector_integrate(1, c2) + 6*pi + + Integral of fields with respect to base scalars: + + >>> vector_integrate(12*C.y**3, (C.y, 1, 3)) + 240 + >>> vector_integrate(C.x**2*C.z, C.x) + C.x**3*C.z/3 + >>> vector_integrate(C.x*C.i - C.y*C.k, C.x) + (Integral(C.x, C.x))*C.i + (Integral(-C.y, C.x))*C.k + >>> _.doit() + C.x**2/2*C.i + (-C.x*C.y)*C.k + + """ + if len(region) == 1: + if isinstance(region[0], ParametricRegion): + return ParametricIntegral(field, region[0]) + + if isinstance(region[0], ImplicitRegion): + region = parametric_region_list(region[0])[0] + return vector_integrate(field, region) + + if isinstance(region[0], GeometryEntity): + regions_list = parametric_region_list(region[0]) + + result = 0 + for reg in regions_list: + result += vector_integrate(field, reg) + return result + + return integrate(field, *region) diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/kind.py b/.venv/lib/python3.13/site-packages/sympy/vector/kind.py new file mode 100644 index 0000000000000000000000000000000000000000..c6c04896b34c9c92c3fb340d94985df859e5877d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/kind.py @@ -0,0 +1,67 @@ +#sympy.vector.kind + +from sympy.core.kind import Kind, _NumberKind, NumberKind +from sympy.core.mul import Mul + +class VectorKind(Kind): + """ + Kind for all vector objects in SymPy. + + Parameters + ========== + + element_kind : Kind + Kind of the element. Default is + :class:`sympy.core.kind.NumberKind`, + which means that the vector contains only numbers. + + Examples + ======== + + Any instance of Vector class has kind ``VectorKind``: + + >>> from sympy.vector.coordsysrect import CoordSys3D + >>> Sys = CoordSys3D('Sys') + >>> Sys.i.kind + VectorKind(NumberKind) + + Operations between instances of Vector keep also have the kind ``VectorKind``: + + >>> from sympy.core.add import Add + >>> v1 = Sys.i * 2 + Sys.j * 3 + Sys.k * 4 + >>> v2 = Sys.i * Sys.x + Sys.j * Sys.y + Sys.k * Sys.z + >>> v1.kind + VectorKind(NumberKind) + >>> v2.kind + VectorKind(NumberKind) + >>> Add(v1, v2).kind + VectorKind(NumberKind) + + Subclasses of Vector also have the kind ``VectorKind``, such as + Cross, VectorAdd, VectorMul or VectorZero. + + See Also + ======== + + sympy.core.kind.Kind + sympy.matrices.kind.MatrixKind + + """ + def __new__(cls, element_kind=NumberKind): + obj = super().__new__(cls, element_kind) + obj.element_kind = element_kind + return obj + + def __repr__(self): + return "VectorKind(%s)" % self.element_kind + +@Mul._kind_dispatcher.register(_NumberKind, VectorKind) +def num_vec_mul(k1, k2): + """ + The result of a multiplication between a number and a Vector should be of VectorKind. + The element kind is selected by recursive dispatching. + """ + if not isinstance(k2, VectorKind): + k1, k2 = k2, k1 + elemk = Mul._kind_dispatcher(k1, k2.element_kind) + return VectorKind(elemk) diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/operators.py b/.venv/lib/python3.13/site-packages/sympy/vector/operators.py new file mode 100644 index 0000000000000000000000000000000000000000..3ca42d20302f972cef66e0b4a35ac75606b2da94 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/operators.py @@ -0,0 +1,335 @@ +import collections +from sympy.core.expr import Expr +from sympy.core import sympify, S, preorder_traversal +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.vector import Vector, VectorMul, VectorAdd, Cross, Dot +from sympy.core.function import Derivative +from sympy.core.add import Add +from sympy.core.mul import Mul + + +def _get_coord_systems(expr): + g = preorder_traversal(expr) + ret = set() + for i in g: + if isinstance(i, CoordSys3D): + ret.add(i) + g.skip() + return frozenset(ret) + + +def _split_mul_args_wrt_coordsys(expr): + d = collections.defaultdict(lambda: S.One) + for i in expr.args: + d[_get_coord_systems(i)] *= i + return list(d.values()) + + +class Gradient(Expr): + """ + Represents unevaluated Gradient. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, Gradient + >>> R = CoordSys3D('R') + >>> s = R.x*R.y*R.z + >>> Gradient(s) + Gradient(R.x*R.y*R.z) + + """ + + def __new__(cls, expr): + expr = sympify(expr) + obj = Expr.__new__(cls, expr) + obj._expr = expr + return obj + + def doit(self, **hints): + return gradient(self._expr, doit=True) + + +class Divergence(Expr): + """ + Represents unevaluated Divergence. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, Divergence + >>> R = CoordSys3D('R') + >>> v = R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k + >>> Divergence(v) + Divergence(R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k) + + """ + + def __new__(cls, expr): + expr = sympify(expr) + obj = Expr.__new__(cls, expr) + obj._expr = expr + return obj + + def doit(self, **hints): + return divergence(self._expr, doit=True) + + +class Curl(Expr): + """ + Represents unevaluated Curl. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, Curl + >>> R = CoordSys3D('R') + >>> v = R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k + >>> Curl(v) + Curl(R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k) + + """ + + def __new__(cls, expr): + expr = sympify(expr) + obj = Expr.__new__(cls, expr) + obj._expr = expr + return obj + + def doit(self, **hints): + return curl(self._expr, doit=True) + + +def curl(vect, doit=True): + """ + Returns the curl of a vector field computed wrt the base scalars + of the given coordinate system. + + Parameters + ========== + + vect : Vector + The vector operand + + doit : bool + If True, the result is returned after calling .doit() on + each component. Else, the returned expression contains + Derivative instances + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, curl + >>> R = CoordSys3D('R') + >>> v1 = R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k + >>> curl(v1) + 0 + >>> v2 = R.x*R.y*R.z*R.i + >>> curl(v2) + R.x*R.y*R.j + (-R.x*R.z)*R.k + + """ + + coord_sys = _get_coord_systems(vect) + + if len(coord_sys) == 0: + return Vector.zero + elif len(coord_sys) == 1: + coord_sys = next(iter(coord_sys)) + i, j, k = coord_sys.base_vectors() + x, y, z = coord_sys.base_scalars() + h1, h2, h3 = coord_sys.lame_coefficients() + vectx = vect.dot(i) + vecty = vect.dot(j) + vectz = vect.dot(k) + outvec = Vector.zero + outvec += (Derivative(vectz * h3, y) - + Derivative(vecty * h2, z)) * i / (h2 * h3) + outvec += (Derivative(vectx * h1, z) - + Derivative(vectz * h3, x)) * j / (h1 * h3) + outvec += (Derivative(vecty * h2, x) - + Derivative(vectx * h1, y)) * k / (h2 * h1) + + if doit: + return outvec.doit() + return outvec + else: + if isinstance(vect, (Add, VectorAdd)): + from sympy.vector import express + try: + cs = next(iter(coord_sys)) + args = [express(i, cs, variables=True) for i in vect.args] + except ValueError: + args = vect.args + return VectorAdd.fromiter(curl(i, doit=doit) for i in args) + elif isinstance(vect, (Mul, VectorMul)): + vector = [i for i in vect.args if isinstance(i, (Vector, Cross, Gradient))][0] + scalar = Mul.fromiter(i for i in vect.args if not isinstance(i, (Vector, Cross, Gradient))) + res = Cross(gradient(scalar), vector).doit() + scalar*curl(vector, doit=doit) + if doit: + return res.doit() + return res + elif isinstance(vect, (Cross, Curl, Gradient)): + return Curl(vect) + else: + raise ValueError("Invalid argument for curl") + + +def divergence(vect, doit=True): + """ + Returns the divergence of a vector field computed wrt the base + scalars of the given coordinate system. + + Parameters + ========== + + vector : Vector + The vector operand + + doit : bool + If True, the result is returned after calling .doit() on + each component. Else, the returned expression contains + Derivative instances + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, divergence + >>> R = CoordSys3D('R') + >>> v1 = R.x*R.y*R.z * (R.i+R.j+R.k) + + >>> divergence(v1) + R.x*R.y + R.x*R.z + R.y*R.z + >>> v2 = 2*R.y*R.z*R.j + >>> divergence(v2) + 2*R.z + + """ + coord_sys = _get_coord_systems(vect) + if len(coord_sys) == 0: + return S.Zero + elif len(coord_sys) == 1: + if isinstance(vect, (Cross, Curl, Gradient)): + return Divergence(vect) + # TODO: is case of many coord systems, this gets a random one: + coord_sys = next(iter(coord_sys)) + i, j, k = coord_sys.base_vectors() + x, y, z = coord_sys.base_scalars() + h1, h2, h3 = coord_sys.lame_coefficients() + vx = _diff_conditional(vect.dot(i), x, h2, h3) \ + / (h1 * h2 * h3) + vy = _diff_conditional(vect.dot(j), y, h3, h1) \ + / (h1 * h2 * h3) + vz = _diff_conditional(vect.dot(k), z, h1, h2) \ + / (h1 * h2 * h3) + res = vx + vy + vz + if doit: + return res.doit() + return res + else: + if isinstance(vect, (Add, VectorAdd)): + return Add.fromiter(divergence(i, doit=doit) for i in vect.args) + elif isinstance(vect, (Mul, VectorMul)): + vector = [i for i in vect.args if isinstance(i, (Vector, Cross, Gradient))][0] + scalar = Mul.fromiter(i for i in vect.args if not isinstance(i, (Vector, Cross, Gradient))) + res = Dot(vector, gradient(scalar)) + scalar*divergence(vector, doit=doit) + if doit: + return res.doit() + return res + elif isinstance(vect, (Cross, Curl, Gradient)): + return Divergence(vect) + else: + raise ValueError("Invalid argument for divergence") + + +def gradient(scalar_field, doit=True): + """ + Returns the vector gradient of a scalar field computed wrt the + base scalars of the given coordinate system. + + Parameters + ========== + + scalar_field : SymPy Expr + The scalar field to compute the gradient of + + doit : bool + If True, the result is returned after calling .doit() on + each component. Else, the returned expression contains + Derivative instances + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, gradient + >>> R = CoordSys3D('R') + >>> s1 = R.x*R.y*R.z + >>> gradient(s1) + R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k + >>> s2 = 5*R.x**2*R.z + >>> gradient(s2) + 10*R.x*R.z*R.i + 5*R.x**2*R.k + + """ + coord_sys = _get_coord_systems(scalar_field) + + if len(coord_sys) == 0: + return Vector.zero + elif len(coord_sys) == 1: + coord_sys = next(iter(coord_sys)) + h1, h2, h3 = coord_sys.lame_coefficients() + i, j, k = coord_sys.base_vectors() + x, y, z = coord_sys.base_scalars() + vx = Derivative(scalar_field, x) / h1 + vy = Derivative(scalar_field, y) / h2 + vz = Derivative(scalar_field, z) / h3 + + if doit: + return (vx * i + vy * j + vz * k).doit() + return vx * i + vy * j + vz * k + else: + if isinstance(scalar_field, (Add, VectorAdd)): + return VectorAdd.fromiter(gradient(i) for i in scalar_field.args) + if isinstance(scalar_field, (Mul, VectorMul)): + s = _split_mul_args_wrt_coordsys(scalar_field) + return VectorAdd.fromiter(scalar_field / i * gradient(i) for i in s) + return Gradient(scalar_field) + + +class Laplacian(Expr): + """ + Represents unevaluated Laplacian. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, Laplacian + >>> R = CoordSys3D('R') + >>> v = 3*R.x**3*R.y**2*R.z**3 + >>> Laplacian(v) + Laplacian(3*R.x**3*R.y**2*R.z**3) + + """ + + def __new__(cls, expr): + expr = sympify(expr) + obj = Expr.__new__(cls, expr) + obj._expr = expr + return obj + + def doit(self, **hints): + from sympy.vector.functions import laplacian + return laplacian(self._expr) + + +def _diff_conditional(expr, base_scalar, coeff_1, coeff_2): + """ + First re-expresses expr in the system that base_scalar belongs to. + If base_scalar appears in the re-expressed form, differentiates + it wrt base_scalar. + Else, returns 0 + """ + from sympy.vector.functions import express + new_expr = express(expr, base_scalar.system, variables=True) + arg = coeff_1 * coeff_2 * new_expr + return Derivative(arg, base_scalar) if arg else S.Zero diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/orienters.py b/.venv/lib/python3.13/site-packages/sympy/vector/orienters.py new file mode 100644 index 0000000000000000000000000000000000000000..0c22089e568bc817c943c1beecebde0fea46b6ae --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/orienters.py @@ -0,0 +1,398 @@ +from sympy.core.basic import Basic +from sympy.core.sympify import sympify +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.dense import (eye, rot_axis1, rot_axis2, rot_axis3) +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.core.cache import cacheit +from sympy.core.symbol import Str +import sympy.vector + + +class Orienter(Basic): + """ + Super-class for all orienter classes. + """ + + def rotation_matrix(self): + """ + The rotation matrix corresponding to this orienter + instance. + """ + return self._parent_orient + + +class AxisOrienter(Orienter): + """ + Class to denote an axis orienter. + """ + + def __new__(cls, angle, axis): + if not isinstance(axis, sympy.vector.Vector): + raise TypeError("axis should be a Vector") + angle = sympify(angle) + + obj = super().__new__(cls, angle, axis) + obj._angle = angle + obj._axis = axis + + return obj + + def __init__(self, angle, axis): + """ + Axis rotation is a rotation about an arbitrary axis by + some angle. The angle is supplied as a SymPy expr scalar, and + the axis is supplied as a Vector. + + Parameters + ========== + + angle : Expr + The angle by which the new system is to be rotated + + axis : Vector + The axis around which the rotation has to be performed + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy import symbols + >>> q1 = symbols('q1') + >>> N = CoordSys3D('N') + >>> from sympy.vector import AxisOrienter + >>> orienter = AxisOrienter(q1, N.i + 2 * N.j) + >>> B = N.orient_new('B', (orienter, )) + + """ + # Dummy initializer for docstrings + pass + + @cacheit + def rotation_matrix(self, system): + """ + The rotation matrix corresponding to this orienter + instance. + + Parameters + ========== + + system : CoordSys3D + The coordinate system wrt which the rotation matrix + is to be computed + """ + + axis = sympy.vector.express(self.axis, system).normalize() + axis = axis.to_matrix(system) + theta = self.angle + parent_orient = ((eye(3) - axis * axis.T) * cos(theta) + + Matrix([[0, -axis[2], axis[1]], + [axis[2], 0, -axis[0]], + [-axis[1], axis[0], 0]]) * sin(theta) + + axis * axis.T) + parent_orient = parent_orient.T + return parent_orient + + @property + def angle(self): + return self._angle + + @property + def axis(self): + return self._axis + + +class ThreeAngleOrienter(Orienter): + """ + Super-class for Body and Space orienters. + """ + + def __new__(cls, angle1, angle2, angle3, rot_order): + if isinstance(rot_order, Str): + rot_order = rot_order.name + + approved_orders = ('123', '231', '312', '132', '213', + '321', '121', '131', '212', '232', + '313', '323', '') + original_rot_order = rot_order + rot_order = str(rot_order).upper() + if not (len(rot_order) == 3): + raise TypeError('rot_order should be a str of length 3') + rot_order = [i.replace('X', '1') for i in rot_order] + rot_order = [i.replace('Y', '2') for i in rot_order] + rot_order = [i.replace('Z', '3') for i in rot_order] + rot_order = ''.join(rot_order) + if rot_order not in approved_orders: + raise TypeError('Invalid rot_type parameter') + a1 = int(rot_order[0]) + a2 = int(rot_order[1]) + a3 = int(rot_order[2]) + angle1 = sympify(angle1) + angle2 = sympify(angle2) + angle3 = sympify(angle3) + if cls._in_order: + parent_orient = (_rot(a1, angle1) * + _rot(a2, angle2) * + _rot(a3, angle3)) + else: + parent_orient = (_rot(a3, angle3) * + _rot(a2, angle2) * + _rot(a1, angle1)) + parent_orient = parent_orient.T + + obj = super().__new__( + cls, angle1, angle2, angle3, Str(rot_order)) + obj._angle1 = angle1 + obj._angle2 = angle2 + obj._angle3 = angle3 + obj._rot_order = original_rot_order + obj._parent_orient = parent_orient + + return obj + + @property + def angle1(self): + return self._angle1 + + @property + def angle2(self): + return self._angle2 + + @property + def angle3(self): + return self._angle3 + + @property + def rot_order(self): + return self._rot_order + + +class BodyOrienter(ThreeAngleOrienter): + """ + Class to denote a body-orienter. + """ + + _in_order = True + + def __new__(cls, angle1, angle2, angle3, rot_order): + obj = ThreeAngleOrienter.__new__(cls, angle1, angle2, angle3, + rot_order) + return obj + + def __init__(self, angle1, angle2, angle3, rot_order): + """ + Body orientation takes this coordinate system through three + successive simple rotations. + + Body fixed rotations include both Euler Angles and + Tait-Bryan Angles, see https://en.wikipedia.org/wiki/Euler_angles. + + Parameters + ========== + + angle1, angle2, angle3 : Expr + Three successive angles to rotate the coordinate system by + + rotation_order : string + String defining the order of axes for rotation + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, BodyOrienter + >>> from sympy import symbols + >>> q1, q2, q3 = symbols('q1 q2 q3') + >>> N = CoordSys3D('N') + + A 'Body' fixed rotation is described by three angles and + three body-fixed rotation axes. To orient a coordinate system D + with respect to N, each sequential rotation is always about + the orthogonal unit vectors fixed to D. For example, a '123' + rotation will specify rotations about N.i, then D.j, then + D.k. (Initially, D.i is same as N.i) + Therefore, + + >>> body_orienter = BodyOrienter(q1, q2, q3, '123') + >>> D = N.orient_new('D', (body_orienter, )) + + is same as + + >>> from sympy.vector import AxisOrienter + >>> axis_orienter1 = AxisOrienter(q1, N.i) + >>> D = N.orient_new('D', (axis_orienter1, )) + >>> axis_orienter2 = AxisOrienter(q2, D.j) + >>> D = D.orient_new('D', (axis_orienter2, )) + >>> axis_orienter3 = AxisOrienter(q3, D.k) + >>> D = D.orient_new('D', (axis_orienter3, )) + + Acceptable rotation orders are of length 3, expressed in XYZ or + 123, and cannot have a rotation about about an axis twice in a row. + + >>> body_orienter1 = BodyOrienter(q1, q2, q3, '123') + >>> body_orienter2 = BodyOrienter(q1, q2, 0, 'ZXZ') + >>> body_orienter3 = BodyOrienter(0, 0, 0, 'XYX') + + """ + # Dummy initializer for docstrings + pass + + +class SpaceOrienter(ThreeAngleOrienter): + """ + Class to denote a space-orienter. + """ + + _in_order = False + + def __new__(cls, angle1, angle2, angle3, rot_order): + obj = ThreeAngleOrienter.__new__(cls, angle1, angle2, angle3, + rot_order) + return obj + + def __init__(self, angle1, angle2, angle3, rot_order): + """ + Space rotation is similar to Body rotation, but the rotations + are applied in the opposite order. + + Parameters + ========== + + angle1, angle2, angle3 : Expr + Three successive angles to rotate the coordinate system by + + rotation_order : string + String defining the order of axes for rotation + + See Also + ======== + + BodyOrienter : Orienter to orient systems wrt Euler angles. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, SpaceOrienter + >>> from sympy import symbols + >>> q1, q2, q3 = symbols('q1 q2 q3') + >>> N = CoordSys3D('N') + + To orient a coordinate system D with respect to N, each + sequential rotation is always about N's orthogonal unit vectors. + For example, a '123' rotation will specify rotations about + N.i, then N.j, then N.k. + Therefore, + + >>> space_orienter = SpaceOrienter(q1, q2, q3, '312') + >>> D = N.orient_new('D', (space_orienter, )) + + is same as + + >>> from sympy.vector import AxisOrienter + >>> axis_orienter1 = AxisOrienter(q1, N.i) + >>> B = N.orient_new('B', (axis_orienter1, )) + >>> axis_orienter2 = AxisOrienter(q2, N.j) + >>> C = B.orient_new('C', (axis_orienter2, )) + >>> axis_orienter3 = AxisOrienter(q3, N.k) + >>> D = C.orient_new('C', (axis_orienter3, )) + + """ + # Dummy initializer for docstrings + pass + + +class QuaternionOrienter(Orienter): + """ + Class to denote a quaternion-orienter. + """ + + def __new__(cls, q0, q1, q2, q3): + q0 = sympify(q0) + q1 = sympify(q1) + q2 = sympify(q2) + q3 = sympify(q3) + parent_orient = (Matrix([[q0 ** 2 + q1 ** 2 - q2 ** 2 - + q3 ** 2, + 2 * (q1 * q2 - q0 * q3), + 2 * (q0 * q2 + q1 * q3)], + [2 * (q1 * q2 + q0 * q3), + q0 ** 2 - q1 ** 2 + + q2 ** 2 - q3 ** 2, + 2 * (q2 * q3 - q0 * q1)], + [2 * (q1 * q3 - q0 * q2), + 2 * (q0 * q1 + q2 * q3), + q0 ** 2 - q1 ** 2 - + q2 ** 2 + q3 ** 2]])) + parent_orient = parent_orient.T + + obj = super().__new__(cls, q0, q1, q2, q3) + obj._q0 = q0 + obj._q1 = q1 + obj._q2 = q2 + obj._q3 = q3 + obj._parent_orient = parent_orient + + return obj + + def __init__(self, angle1, angle2, angle3, rot_order): + """ + Quaternion orientation orients the new CoordSys3D with + Quaternions, defined as a finite rotation about lambda, a unit + vector, by some amount theta. + + This orientation is described by four parameters: + + q0 = cos(theta/2) + + q1 = lambda_x sin(theta/2) + + q2 = lambda_y sin(theta/2) + + q3 = lambda_z sin(theta/2) + + Quaternion does not take in a rotation order. + + Parameters + ========== + + q0, q1, q2, q3 : Expr + The quaternions to rotate the coordinate system by + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy import symbols + >>> q0, q1, q2, q3 = symbols('q0 q1 q2 q3') + >>> N = CoordSys3D('N') + >>> from sympy.vector import QuaternionOrienter + >>> q_orienter = QuaternionOrienter(q0, q1, q2, q3) + >>> B = N.orient_new('B', (q_orienter, )) + + """ + # Dummy initializer for docstrings + pass + + @property + def q0(self): + return self._q0 + + @property + def q1(self): + return self._q1 + + @property + def q2(self): + return self._q2 + + @property + def q3(self): + return self._q3 + + +def _rot(axis, angle): + """DCM for simple axis 1, 2 or 3 rotations. """ + if axis == 1: + return Matrix(rot_axis1(angle).T) + elif axis == 2: + return Matrix(rot_axis2(angle).T) + elif axis == 3: + return Matrix(rot_axis3(angle).T) diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/parametricregion.py b/.venv/lib/python3.13/site-packages/sympy/vector/parametricregion.py new file mode 100644 index 0000000000000000000000000000000000000000..5246769dabe208fe630f7c33b2da3ef4e11b3f67 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/parametricregion.py @@ -0,0 +1,189 @@ +from functools import singledispatch +from sympy.core.numbers import pi +from sympy.functions.elementary.trigonometric import tan +from sympy.simplify import trigsimp +from sympy.core import Basic, Tuple +from sympy.core.symbol import _symbol +from sympy.solvers import solve +from sympy.geometry import Point, Segment, Curve, Ellipse, Polygon +from sympy.vector import ImplicitRegion + + +class ParametricRegion(Basic): + """ + Represents a parametric region in space. + + Examples + ======== + + >>> from sympy import cos, sin, pi + >>> from sympy.abc import r, theta, t, a, b, x, y + >>> from sympy.vector import ParametricRegion + + >>> ParametricRegion((t, t**2), (t, -1, 2)) + ParametricRegion((t, t**2), (t, -1, 2)) + >>> ParametricRegion((x, y), (x, 3, 4), (y, 5, 6)) + ParametricRegion((x, y), (x, 3, 4), (y, 5, 6)) + >>> ParametricRegion((r*cos(theta), r*sin(theta)), (r, -2, 2), (theta, 0, pi)) + ParametricRegion((r*cos(theta), r*sin(theta)), (r, -2, 2), (theta, 0, pi)) + >>> ParametricRegion((a*cos(t), b*sin(t)), t) + ParametricRegion((a*cos(t), b*sin(t)), t) + + >>> circle = ParametricRegion((r*cos(theta), r*sin(theta)), r, (theta, 0, pi)) + >>> circle.parameters + (r, theta) + >>> circle.definition + (r*cos(theta), r*sin(theta)) + >>> circle.limits + {theta: (0, pi)} + + Dimension of a parametric region determines whether a region is a curve, surface + or volume region. It does not represent its dimensions in space. + + >>> circle.dimensions + 1 + + Parameters + ========== + + definition : tuple to define base scalars in terms of parameters. + + bounds : Parameter or a tuple of length 3 to define parameter and corresponding lower and upper bound. + + """ + def __new__(cls, definition, *bounds): + parameters = () + limits = {} + + if not isinstance(bounds, Tuple): + bounds = Tuple(*bounds) + + for bound in bounds: + if isinstance(bound, (tuple, Tuple)): + if len(bound) != 3: + raise ValueError("Tuple should be in the form (parameter, lowerbound, upperbound)") + parameters += (bound[0],) + limits[bound[0]] = (bound[1], bound[2]) + else: + parameters += (bound,) + + if not isinstance(definition, (tuple, Tuple)): + definition = (definition,) + + obj = super().__new__(cls, Tuple(*definition), *bounds) + obj._parameters = parameters + obj._limits = limits + + return obj + + @property + def definition(self): + return self.args[0] + + @property + def limits(self): + return self._limits + + @property + def parameters(self): + return self._parameters + + @property + def dimensions(self): + return len(self.limits) + + +@singledispatch +def parametric_region_list(reg): + """ + Returns a list of ParametricRegion objects representing the geometric region. + + Examples + ======== + + >>> from sympy.abc import t + >>> from sympy.vector import parametric_region_list + >>> from sympy.geometry import Point, Curve, Ellipse, Segment, Polygon + + >>> p = Point(2, 5) + >>> parametric_region_list(p) + [ParametricRegion((2, 5))] + + >>> c = Curve((t**3, 4*t), (t, -3, 4)) + >>> parametric_region_list(c) + [ParametricRegion((t**3, 4*t), (t, -3, 4))] + + >>> e = Ellipse(Point(1, 3), 2, 3) + >>> parametric_region_list(e) + [ParametricRegion((2*cos(t) + 1, 3*sin(t) + 3), (t, 0, 2*pi))] + + >>> s = Segment(Point(1, 3), Point(2, 6)) + >>> parametric_region_list(s) + [ParametricRegion((t + 1, 3*t + 3), (t, 0, 1))] + + >>> p1, p2, p3, p4 = [(0, 1), (2, -3), (5, 3), (-2, 3)] + >>> poly = Polygon(p1, p2, p3, p4) + >>> parametric_region_list(poly) + [ParametricRegion((2*t, 1 - 4*t), (t, 0, 1)), ParametricRegion((3*t + 2, 6*t - 3), (t, 0, 1)),\ + ParametricRegion((5 - 7*t, 3), (t, 0, 1)), ParametricRegion((2*t - 2, 3 - 2*t), (t, 0, 1))] + + """ + raise ValueError("SymPy cannot determine parametric representation of the region.") + + +@parametric_region_list.register(Point) +def _(obj): + return [ParametricRegion(obj.args)] + + +@parametric_region_list.register(Curve) # type: ignore +def _(obj): + definition = obj.arbitrary_point(obj.parameter).args + bounds = obj.limits + return [ParametricRegion(definition, bounds)] + + +@parametric_region_list.register(Ellipse) # type: ignore +def _(obj, parameter='t'): + definition = obj.arbitrary_point(parameter).args + t = _symbol(parameter, real=True) + bounds = (t, 0, 2*pi) + return [ParametricRegion(definition, bounds)] + + +@parametric_region_list.register(Segment) # type: ignore +def _(obj, parameter='t'): + t = _symbol(parameter, real=True) + definition = obj.arbitrary_point(t).args + + for i in range(0, 3): + lower_bound = solve(definition[i] - obj.points[0].args[i], t) + upper_bound = solve(definition[i] - obj.points[1].args[i], t) + + if len(lower_bound) == 1 and len(upper_bound) == 1: + bounds = t, lower_bound[0], upper_bound[0] + break + + definition_tuple = obj.arbitrary_point(parameter).args + return [ParametricRegion(definition_tuple, bounds)] + + +@parametric_region_list.register(Polygon) # type: ignore +def _(obj, parameter='t'): + l = [parametric_region_list(side, parameter)[0] for side in obj.sides] + return l + + +@parametric_region_list.register(ImplicitRegion) # type: ignore +def _(obj, parameters=('t', 's')): + definition = obj.rational_parametrization(parameters) + bounds = [] + + for i in range(len(obj.variables) - 1): + # Each parameter is replaced by its tangent to simplify integration + parameter = _symbol(parameters[i], real=True) + definition = [trigsimp(elem.subs(parameter, tan(parameter/2))) for elem in definition] + bounds.append((parameter, 0, 2*pi),) + + definition = Tuple(*definition) + return [ParametricRegion(definition, *bounds)] diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/point.py b/.venv/lib/python3.13/site-packages/sympy/vector/point.py new file mode 100644 index 0000000000000000000000000000000000000000..442ea4e8edc0e33c4f83f774ea3f11a01725ac3a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/point.py @@ -0,0 +1,148 @@ +from sympy.core.basic import Basic +from sympy.core.symbol import Str +from sympy.vector.vector import Vector +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.functions import _path +from sympy.core.cache import cacheit + + +class Point(Basic): + """ + Represents a point in 3-D space. + """ + + def __new__(cls, name, position=Vector.zero, parent_point=None): + name = str(name) + # Check the args first + if not isinstance(position, Vector): + raise TypeError( + "position should be an instance of Vector, not %s" % type( + position)) + if (not isinstance(parent_point, Point) and + parent_point is not None): + raise TypeError( + "parent_point should be an instance of Point, not %s" % type( + parent_point)) + # Super class construction + if parent_point is None: + obj = super().__new__(cls, Str(name), position) + else: + obj = super().__new__(cls, Str(name), position, parent_point) + # Decide the object parameters + obj._name = name + obj._pos = position + if parent_point is None: + obj._parent = None + obj._root = obj + else: + obj._parent = parent_point + obj._root = parent_point._root + # Return object + return obj + + @cacheit + def position_wrt(self, other): + """ + Returns the position vector of this Point with respect to + another Point/CoordSys3D. + + Parameters + ========== + + other : Point/CoordSys3D + If other is a Point, the position of this Point wrt it is + returned. If its an instance of CoordSyRect, the position + wrt its origin is returned. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> N = CoordSys3D('N') + >>> p1 = N.origin.locate_new('p1', 10 * N.i) + >>> N.origin.position_wrt(p1) + (-10)*N.i + + """ + + if (not isinstance(other, Point) and + not isinstance(other, CoordSys3D)): + raise TypeError(str(other) + + "is not a Point or CoordSys3D") + if isinstance(other, CoordSys3D): + other = other.origin + # Handle special cases + if other == self: + return Vector.zero + elif other == self._parent: + return self._pos + elif other._parent == self: + return -1 * other._pos + # Else, use point tree to calculate position + rootindex, path = _path(self, other) + result = Vector.zero + for i in range(rootindex): + result += path[i]._pos + for i in range(rootindex + 1, len(path)): + result -= path[i]._pos + return result + + def locate_new(self, name, position): + """ + Returns a new Point located at the given position wrt this + Point. + Thus, the position vector of the new Point wrt this one will + be equal to the given 'position' parameter. + + Parameters + ========== + + name : str + Name of the new point + + position : Vector + The position vector of the new Point wrt this one + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> N = CoordSys3D('N') + >>> p1 = N.origin.locate_new('p1', 10 * N.i) + >>> p1.position_wrt(N.origin) + 10*N.i + + """ + return Point(name, position, self) + + def express_coordinates(self, coordinate_system): + """ + Returns the Cartesian/rectangular coordinates of this point + wrt the origin of the given CoordSys3D instance. + + Parameters + ========== + + coordinate_system : CoordSys3D + The coordinate system to express the coordinates of this + Point in. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> N = CoordSys3D('N') + >>> p1 = N.origin.locate_new('p1', 10 * N.i) + >>> p2 = p1.locate_new('p2', 5 * N.j) + >>> p2.express_coordinates(N) + (10, 5, 0) + + """ + + # Determine the position vector + pos_vect = self.position_wrt(coordinate_system.origin) + # Express it in the given coordinate system + return tuple(pos_vect.to_matrix(coordinate_system)) + + def _sympystr(self, printer): + return self._name diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/scalar.py b/.venv/lib/python3.13/site-packages/sympy/vector/scalar.py new file mode 100644 index 0000000000000000000000000000000000000000..bcfb56cf177b9378a24f81ca1e6524fe048a5f94 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/scalar.py @@ -0,0 +1,72 @@ +from sympy.core import AtomicExpr, Symbol, S +from sympy.core.sympify import _sympify +from sympy.printing.pretty.stringpict import prettyForm +from sympy.printing.precedence import PRECEDENCE +from sympy.core.kind import NumberKind + + +class BaseScalar(AtomicExpr): + """ + A coordinate symbol/base scalar. + + Ideally, users should not instantiate this class. + + """ + + kind = NumberKind + + def __new__(cls, index, system, pretty_str=None, latex_str=None): + from sympy.vector.coordsysrect import CoordSys3D + if pretty_str is None: + pretty_str = "x{}".format(index) + elif isinstance(pretty_str, Symbol): + pretty_str = pretty_str.name + if latex_str is None: + latex_str = "x_{}".format(index) + elif isinstance(latex_str, Symbol): + latex_str = latex_str.name + + index = _sympify(index) + system = _sympify(system) + obj = super().__new__(cls, index, system) + if not isinstance(system, CoordSys3D): + raise TypeError("system should be a CoordSys3D") + if index not in range(0, 3): + raise ValueError("Invalid index specified.") + # The _id is used for equating purposes, and for hashing + obj._id = (index, system) + obj._name = obj.name = system._name + '.' + system._variable_names[index] + obj._pretty_form = '' + pretty_str + obj._latex_form = latex_str + obj._system = system + + return obj + + is_commutative = True + is_symbol = True + + @property + def free_symbols(self): + return {self} + + _diff_wrt = True + + def _eval_derivative(self, s): + if self == s: + return S.One + return S.Zero + + def _latex(self, printer=None): + return self._latex_form + + def _pretty(self, printer=None): + return prettyForm(self._pretty_form) + + precedence = PRECEDENCE['Atom'] + + @property + def system(self): + return self._system + + def _sympystr(self, printer): + return self._name diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/vector/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_coordsysrect.py b/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_coordsysrect.py new file mode 100644 index 0000000000000000000000000000000000000000..53eb8c89ec1643a71800efe3e370acff3cb6f9c0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_coordsysrect.py @@ -0,0 +1,464 @@ +from sympy.testing.pytest import raises +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.scalar import BaseScalar +from sympy.core.function import expand +from sympy.core.numbers import pi +from sympy.core.symbol import symbols +from sympy.functions.elementary.hyperbolic import (cosh, sinh) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, atan2, cos, sin) +from sympy.matrices.dense import zeros +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.simplify.simplify import simplify +from sympy.vector.functions import express +from sympy.vector.point import Point +from sympy.vector.vector import Vector +from sympy.vector.orienters import (AxisOrienter, BodyOrienter, + SpaceOrienter, QuaternionOrienter) + + +x, y, z = symbols('x y z') +a, b, c, q = symbols('a b c q') +q1, q2, q3, q4 = symbols('q1 q2 q3 q4') + + +def test_func_args(): + A = CoordSys3D('A') + assert A.x.func(*A.x.args) == A.x + expr = 3*A.x + 4*A.y + assert expr.func(*expr.args) == expr + assert A.i.func(*A.i.args) == A.i + v = A.x*A.i + A.y*A.j + A.z*A.k + assert v.func(*v.args) == v + assert A.origin.func(*A.origin.args) == A.origin + + +def test_coordsys3d_equivalence(): + A = CoordSys3D('A') + A1 = CoordSys3D('A') + assert A1 == A + B = CoordSys3D('B') + assert A != B + + +def test_orienters(): + A = CoordSys3D('A') + axis_orienter = AxisOrienter(a, A.k) + body_orienter = BodyOrienter(a, b, c, '123') + space_orienter = SpaceOrienter(a, b, c, '123') + q_orienter = QuaternionOrienter(q1, q2, q3, q4) + assert axis_orienter.rotation_matrix(A) == Matrix([ + [ cos(a), sin(a), 0], + [-sin(a), cos(a), 0], + [ 0, 0, 1]]) + assert body_orienter.rotation_matrix() == Matrix([ + [ cos(b)*cos(c), sin(a)*sin(b)*cos(c) + sin(c)*cos(a), + sin(a)*sin(c) - sin(b)*cos(a)*cos(c)], + [-sin(c)*cos(b), -sin(a)*sin(b)*sin(c) + cos(a)*cos(c), + sin(a)*cos(c) + sin(b)*sin(c)*cos(a)], + [ sin(b), -sin(a)*cos(b), + cos(a)*cos(b)]]) + assert space_orienter.rotation_matrix() == Matrix([ + [cos(b)*cos(c), sin(c)*cos(b), -sin(b)], + [sin(a)*sin(b)*cos(c) - sin(c)*cos(a), + sin(a)*sin(b)*sin(c) + cos(a)*cos(c), sin(a)*cos(b)], + [sin(a)*sin(c) + sin(b)*cos(a)*cos(c), -sin(a)*cos(c) + + sin(b)*sin(c)*cos(a), cos(a)*cos(b)]]) + assert q_orienter.rotation_matrix() == Matrix([ + [q1**2 + q2**2 - q3**2 - q4**2, 2*q1*q4 + 2*q2*q3, + -2*q1*q3 + 2*q2*q4], + [-2*q1*q4 + 2*q2*q3, q1**2 - q2**2 + q3**2 - q4**2, + 2*q1*q2 + 2*q3*q4], + [2*q1*q3 + 2*q2*q4, + -2*q1*q2 + 2*q3*q4, q1**2 - q2**2 - q3**2 + q4**2]]) + + +def test_coordinate_vars(): + """ + Tests the coordinate variables functionality with respect to + reorientation of coordinate systems. + """ + A = CoordSys3D('A') + # Note that the name given on the lhs is different from A.x._name + assert BaseScalar(0, A, 'A_x', r'\mathbf{{x}_{A}}') == A.x + assert BaseScalar(1, A, 'A_y', r'\mathbf{{y}_{A}}') == A.y + assert BaseScalar(2, A, 'A_z', r'\mathbf{{z}_{A}}') == A.z + assert BaseScalar(0, A, 'A_x', r'\mathbf{{x}_{A}}').__hash__() == A.x.__hash__() + assert isinstance(A.x, BaseScalar) and \ + isinstance(A.y, BaseScalar) and \ + isinstance(A.z, BaseScalar) + assert A.x*A.y == A.y*A.x + assert A.scalar_map(A) == {A.x: A.x, A.y: A.y, A.z: A.z} + assert A.x.system == A + assert A.x.diff(A.x) == 1 + B = A.orient_new_axis('B', q, A.k) + assert B.scalar_map(A) == {B.z: A.z, B.y: -A.x*sin(q) + A.y*cos(q), + B.x: A.x*cos(q) + A.y*sin(q)} + assert A.scalar_map(B) == {A.x: B.x*cos(q) - B.y*sin(q), + A.y: B.x*sin(q) + B.y*cos(q), A.z: B.z} + assert express(B.x, A, variables=True) == A.x*cos(q) + A.y*sin(q) + assert express(B.y, A, variables=True) == -A.x*sin(q) + A.y*cos(q) + assert express(B.z, A, variables=True) == A.z + assert expand(express(B.x*B.y*B.z, A, variables=True)) == \ + expand(A.z*(-A.x*sin(q) + A.y*cos(q))*(A.x*cos(q) + A.y*sin(q))) + assert express(B.x*B.i + B.y*B.j + B.z*B.k, A) == \ + (B.x*cos(q) - B.y*sin(q))*A.i + (B.x*sin(q) + \ + B.y*cos(q))*A.j + B.z*A.k + assert simplify(express(B.x*B.i + B.y*B.j + B.z*B.k, A, \ + variables=True)) == \ + A.x*A.i + A.y*A.j + A.z*A.k + assert express(A.x*A.i + A.y*A.j + A.z*A.k, B) == \ + (A.x*cos(q) + A.y*sin(q))*B.i + \ + (-A.x*sin(q) + A.y*cos(q))*B.j + A.z*B.k + assert simplify(express(A.x*A.i + A.y*A.j + A.z*A.k, B, \ + variables=True)) == \ + B.x*B.i + B.y*B.j + B.z*B.k + N = B.orient_new_axis('N', -q, B.k) + assert N.scalar_map(A) == \ + {N.x: A.x, N.z: A.z, N.y: A.y} + C = A.orient_new_axis('C', q, A.i + A.j + A.k) + mapping = A.scalar_map(C) + assert mapping[A.x].equals(C.x*(2*cos(q) + 1)/3 + + C.y*(-2*sin(q + pi/6) + 1)/3 + + C.z*(-2*cos(q + pi/3) + 1)/3) + assert mapping[A.y].equals(C.x*(-2*cos(q + pi/3) + 1)/3 + + C.y*(2*cos(q) + 1)/3 + + C.z*(-2*sin(q + pi/6) + 1)/3) + assert mapping[A.z].equals(C.x*(-2*sin(q + pi/6) + 1)/3 + + C.y*(-2*cos(q + pi/3) + 1)/3 + + C.z*(2*cos(q) + 1)/3) + D = A.locate_new('D', a*A.i + b*A.j + c*A.k) + assert D.scalar_map(A) == {D.z: A.z - c, D.x: A.x - a, D.y: A.y - b} + E = A.orient_new_axis('E', a, A.k, a*A.i + b*A.j + c*A.k) + assert A.scalar_map(E) == {A.z: E.z + c, + A.x: E.x*cos(a) - E.y*sin(a) + a, + A.y: E.x*sin(a) + E.y*cos(a) + b} + assert E.scalar_map(A) == {E.x: (A.x - a)*cos(a) + (A.y - b)*sin(a), + E.y: (-A.x + a)*sin(a) + (A.y - b)*cos(a), + E.z: A.z - c} + F = A.locate_new('F', Vector.zero) + assert A.scalar_map(F) == {A.z: F.z, A.x: F.x, A.y: F.y} + + +def test_rotation_matrix(): + N = CoordSys3D('N') + A = N.orient_new_axis('A', q1, N.k) + B = A.orient_new_axis('B', q2, A.i) + C = B.orient_new_axis('C', q3, B.j) + D = N.orient_new_axis('D', q4, N.j) + E = N.orient_new_space('E', q1, q2, q3, '123') + F = N.orient_new_quaternion('F', q1, q2, q3, q4) + G = N.orient_new_body('G', q1, q2, q3, '123') + assert N.rotation_matrix(C) == Matrix([ + [- sin(q1) * sin(q2) * sin(q3) + cos(q1) * cos(q3), - sin(q1) * + cos(q2), sin(q1) * sin(q2) * cos(q3) + sin(q3) * cos(q1)], \ + [sin(q1) * cos(q3) + sin(q2) * sin(q3) * cos(q1), \ + cos(q1) * cos(q2), sin(q1) * sin(q3) - sin(q2) * cos(q1) * \ + cos(q3)], [- sin(q3) * cos(q2), sin(q2), cos(q2) * cos(q3)]]) + test_mat = D.rotation_matrix(C) - Matrix( + [[cos(q1) * cos(q3) * cos(q4) - sin(q3) * (- sin(q4) * cos(q2) + + sin(q1) * sin(q2) * cos(q4)), - sin(q2) * sin(q4) - sin(q1) * + cos(q2) * cos(q4), sin(q3) * cos(q1) * cos(q4) + cos(q3) * \ + (- sin(q4) * cos(q2) + sin(q1) * sin(q2) * cos(q4))], \ + [sin(q1) * cos(q3) + sin(q2) * sin(q3) * cos(q1), cos(q1) * \ + cos(q2), sin(q1) * sin(q3) - sin(q2) * cos(q1) * cos(q3)], \ + [sin(q4) * cos(q1) * cos(q3) - sin(q3) * (cos(q2) * cos(q4) + \ + sin(q1) * sin(q2) * \ + sin(q4)), sin(q2) * + cos(q4) - sin(q1) * sin(q4) * cos(q2), sin(q3) * \ + sin(q4) * cos(q1) + cos(q3) * (cos(q2) * cos(q4) + \ + sin(q1) * sin(q2) * sin(q4))]]) + assert test_mat.expand() == zeros(3, 3) + assert E.rotation_matrix(N) == Matrix( + [[cos(q2)*cos(q3), sin(q3)*cos(q2), -sin(q2)], + [sin(q1)*sin(q2)*cos(q3) - sin(q3)*cos(q1), \ + sin(q1)*sin(q2)*sin(q3) + cos(q1)*cos(q3), sin(q1)*cos(q2)], \ + [sin(q1)*sin(q3) + sin(q2)*cos(q1)*cos(q3), - \ + sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1), cos(q1)*cos(q2)]]) + assert F.rotation_matrix(N) == Matrix([[ + q1**2 + q2**2 - q3**2 - q4**2, + 2*q1*q4 + 2*q2*q3, -2*q1*q3 + 2*q2*q4],[ -2*q1*q4 + 2*q2*q3, + q1**2 - q2**2 + q3**2 - q4**2, 2*q1*q2 + 2*q3*q4], + [2*q1*q3 + 2*q2*q4, + -2*q1*q2 + 2*q3*q4, + q1**2 - q2**2 - q3**2 + q4**2]]) + assert G.rotation_matrix(N) == Matrix([[ + cos(q2)*cos(q3), sin(q1)*sin(q2)*cos(q3) + sin(q3)*cos(q1), + sin(q1)*sin(q3) - sin(q2)*cos(q1)*cos(q3)], [ + -sin(q3)*cos(q2), -sin(q1)*sin(q2)*sin(q3) + cos(q1)*cos(q3), + sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1)],[ + sin(q2), -sin(q1)*cos(q2), cos(q1)*cos(q2)]]) + + +def test_vector_with_orientation(): + """ + Tests the effects of orientation of coordinate systems on + basic vector operations. + """ + N = CoordSys3D('N') + A = N.orient_new_axis('A', q1, N.k) + B = A.orient_new_axis('B', q2, A.i) + C = B.orient_new_axis('C', q3, B.j) + + # Test to_matrix + v1 = a*N.i + b*N.j + c*N.k + assert v1.to_matrix(A) == Matrix([[ a*cos(q1) + b*sin(q1)], + [-a*sin(q1) + b*cos(q1)], + [ c]]) + + # Test dot + assert N.i.dot(A.i) == cos(q1) + assert N.i.dot(A.j) == -sin(q1) + assert N.i.dot(A.k) == 0 + assert N.j.dot(A.i) == sin(q1) + assert N.j.dot(A.j) == cos(q1) + assert N.j.dot(A.k) == 0 + assert N.k.dot(A.i) == 0 + assert N.k.dot(A.j) == 0 + assert N.k.dot(A.k) == 1 + + assert N.i.dot(A.i + A.j) == -sin(q1) + cos(q1) == \ + (A.i + A.j).dot(N.i) + + assert A.i.dot(C.i) == cos(q3) + assert A.i.dot(C.j) == 0 + assert A.i.dot(C.k) == sin(q3) + assert A.j.dot(C.i) == sin(q2)*sin(q3) + assert A.j.dot(C.j) == cos(q2) + assert A.j.dot(C.k) == -sin(q2)*cos(q3) + assert A.k.dot(C.i) == -cos(q2)*sin(q3) + assert A.k.dot(C.j) == sin(q2) + assert A.k.dot(C.k) == cos(q2)*cos(q3) + + # Test cross + assert N.i.cross(A.i) == sin(q1)*A.k + assert N.i.cross(A.j) == cos(q1)*A.k + assert N.i.cross(A.k) == -sin(q1)*A.i - cos(q1)*A.j + assert N.j.cross(A.i) == -cos(q1)*A.k + assert N.j.cross(A.j) == sin(q1)*A.k + assert N.j.cross(A.k) == cos(q1)*A.i - sin(q1)*A.j + assert N.k.cross(A.i) == A.j + assert N.k.cross(A.j) == -A.i + assert N.k.cross(A.k) == Vector.zero + + assert N.i.cross(A.i) == sin(q1)*A.k + assert N.i.cross(A.j) == cos(q1)*A.k + assert N.i.cross(A.i + A.j) == sin(q1)*A.k + cos(q1)*A.k + assert (A.i + A.j).cross(N.i) == (-sin(q1) - cos(q1))*N.k + + assert A.i.cross(C.i) == sin(q3)*C.j + assert A.i.cross(C.j) == -sin(q3)*C.i + cos(q3)*C.k + assert A.i.cross(C.k) == -cos(q3)*C.j + assert C.i.cross(A.i) == (-sin(q3)*cos(q2))*A.j + \ + (-sin(q2)*sin(q3))*A.k + assert C.j.cross(A.i) == (sin(q2))*A.j + (-cos(q2))*A.k + assert express(C.k.cross(A.i), C).trigsimp() == cos(q3)*C.j + + +def test_orient_new_methods(): + N = CoordSys3D('N') + orienter1 = AxisOrienter(q4, N.j) + orienter2 = SpaceOrienter(q1, q2, q3, '123') + orienter3 = QuaternionOrienter(q1, q2, q3, q4) + orienter4 = BodyOrienter(q1, q2, q3, '123') + D = N.orient_new('D', (orienter1, )) + E = N.orient_new('E', (orienter2, )) + F = N.orient_new('F', (orienter3, )) + G = N.orient_new('G', (orienter4, )) + assert D == N.orient_new_axis('D', q4, N.j) + assert E == N.orient_new_space('E', q1, q2, q3, '123') + assert F == N.orient_new_quaternion('F', q1, q2, q3, q4) + assert G == N.orient_new_body('G', q1, q2, q3, '123') + + +def test_locatenew_point(): + """ + Tests Point class, and locate_new method in CoordSys3D. + """ + A = CoordSys3D('A') + assert isinstance(A.origin, Point) + v = a*A.i + b*A.j + c*A.k + C = A.locate_new('C', v) + assert C.origin.position_wrt(A) == \ + C.position_wrt(A) == \ + C.origin.position_wrt(A.origin) == v + assert A.origin.position_wrt(C) == \ + A.position_wrt(C) == \ + A.origin.position_wrt(C.origin) == -v + assert A.origin.express_coordinates(C) == (-a, -b, -c) + p = A.origin.locate_new('p', -v) + assert p.express_coordinates(A) == (-a, -b, -c) + assert p.position_wrt(C.origin) == p.position_wrt(C) == \ + -2 * v + p1 = p.locate_new('p1', 2*v) + assert p1.position_wrt(C.origin) == Vector.zero + assert p1.express_coordinates(C) == (0, 0, 0) + p2 = p.locate_new('p2', A.i) + assert p1.position_wrt(p2) == 2*v - A.i + assert p2.express_coordinates(C) == (-2*a + 1, -2*b, -2*c) + + +def test_create_new(): + a = CoordSys3D('a') + c = a.create_new('c', transformation='spherical') + assert c._parent == a + assert c.transformation_to_parent() == \ + (c.r*sin(c.theta)*cos(c.phi), c.r*sin(c.theta)*sin(c.phi), c.r*cos(c.theta)) + assert c.transformation_from_parent() == \ + (sqrt(a.x**2 + a.y**2 + a.z**2), acos(a.z/sqrt(a.x**2 + a.y**2 + a.z**2)), atan2(a.y, a.x)) + + +def test_evalf(): + A = CoordSys3D('A') + v = 3*A.i + 4*A.j + a*A.k + assert v.n() == v.evalf() + assert v.evalf(subs={a:1}) == v.subs(a, 1).evalf() + + +def test_lame_coefficients(): + a = CoordSys3D('a', 'spherical') + assert a.lame_coefficients() == (1, a.r, sin(a.theta)*a.r) + a = CoordSys3D('a') + assert a.lame_coefficients() == (1, 1, 1) + a = CoordSys3D('a', 'cartesian') + assert a.lame_coefficients() == (1, 1, 1) + a = CoordSys3D('a', 'cylindrical') + assert a.lame_coefficients() == (1, a.r, 1) + + +def test_transformation_equations(): + + x, y, z = symbols('x y z') + # Str + a = CoordSys3D('a', transformation='spherical', + variable_names=["r", "theta", "phi"]) + r, theta, phi = a.base_scalars() + + assert r == a.r + assert theta == a.theta + assert phi == a.phi + + raises(AttributeError, lambda: a.x) + raises(AttributeError, lambda: a.y) + raises(AttributeError, lambda: a.z) + + assert a.transformation_to_parent() == ( + r*sin(theta)*cos(phi), + r*sin(theta)*sin(phi), + r*cos(theta) + ) + assert a.lame_coefficients() == (1, r, r*sin(theta)) + assert a.transformation_from_parent_function()(x, y, z) == ( + sqrt(x ** 2 + y ** 2 + z ** 2), + acos((z) / sqrt(x**2 + y**2 + z**2)), + atan2(y, x) + ) + a = CoordSys3D('a', transformation='cylindrical', + variable_names=["r", "theta", "z"]) + r, theta, z = a.base_scalars() + assert a.transformation_to_parent() == ( + r*cos(theta), + r*sin(theta), + z + ) + assert a.lame_coefficients() == (1, a.r, 1) + assert a.transformation_from_parent_function()(x, y, z) == (sqrt(x**2 + y**2), + atan2(y, x), z) + + a = CoordSys3D('a', 'cartesian') + assert a.transformation_to_parent() == (a.x, a.y, a.z) + assert a.lame_coefficients() == (1, 1, 1) + assert a.transformation_from_parent_function()(x, y, z) == (x, y, z) + + # Variables and expressions + + # Cartesian with equation tuple: + x, y, z = symbols('x y z') + a = CoordSys3D('a', ((x, y, z), (x, y, z))) + a._calculate_inv_trans_equations() + assert a.transformation_to_parent() == (a.x1, a.x2, a.x3) + assert a.lame_coefficients() == (1, 1, 1) + assert a.transformation_from_parent_function()(x, y, z) == (x, y, z) + r, theta, z = symbols("r theta z") + + # Cylindrical with equation tuple: + a = CoordSys3D('a', [(r, theta, z), (r*cos(theta), r*sin(theta), z)], + variable_names=["r", "theta", "z"]) + r, theta, z = a.base_scalars() + assert a.transformation_to_parent() == ( + r*cos(theta), r*sin(theta), z + ) + assert a.lame_coefficients() == ( + sqrt(sin(theta)**2 + cos(theta)**2), + sqrt(r**2*sin(theta)**2 + r**2*cos(theta)**2), + 1 + ) # ==> this should simplify to (1, r, 1), tests are too slow with `simplify`. + + # Definitions with `lambda`: + + # Cartesian with `lambda` + a = CoordSys3D('a', lambda x, y, z: (x, y, z)) + assert a.transformation_to_parent() == (a.x1, a.x2, a.x3) + assert a.lame_coefficients() == (1, 1, 1) + a._calculate_inv_trans_equations() + assert a.transformation_from_parent_function()(x, y, z) == (x, y, z) + + # Spherical with `lambda` + a = CoordSys3D('a', lambda r, theta, phi: (r*sin(theta)*cos(phi), r*sin(theta)*sin(phi), r*cos(theta)), + variable_names=["r", "theta", "phi"]) + r, theta, phi = a.base_scalars() + assert a.transformation_to_parent() == ( + r*sin(theta)*cos(phi), r*sin(phi)*sin(theta), r*cos(theta) + ) + assert a.lame_coefficients() == ( + sqrt(sin(phi)**2*sin(theta)**2 + sin(theta)**2*cos(phi)**2 + cos(theta)**2), + sqrt(r**2*sin(phi)**2*cos(theta)**2 + r**2*sin(theta)**2 + r**2*cos(phi)**2*cos(theta)**2), + sqrt(r**2*sin(phi)**2*sin(theta)**2 + r**2*sin(theta)**2*cos(phi)**2) + ) # ==> this should simplify to (1, r, sin(theta)*r), `simplify` is too slow. + + # Cylindrical with `lambda` + a = CoordSys3D('a', lambda r, theta, z: + (r*cos(theta), r*sin(theta), z), + variable_names=["r", "theta", "z"] + ) + r, theta, z = a.base_scalars() + assert a.transformation_to_parent() == (r*cos(theta), r*sin(theta), z) + assert a.lame_coefficients() == ( + sqrt(sin(theta)**2 + cos(theta)**2), + sqrt(r**2*sin(theta)**2 + r**2*cos(theta)**2), + 1 + ) # ==> this should simplify to (1, a.x, 1) + + raises(TypeError, lambda: CoordSys3D('a', transformation={ + x: x*sin(y)*cos(z), y:x*sin(y)*sin(z), z: x*cos(y)})) + + +def test_check_orthogonality(): + x, y, z = symbols('x y z') + u,v = symbols('u, v') + a = CoordSys3D('a', transformation=((x, y, z), (x*sin(y)*cos(z), x*sin(y)*sin(z), x*cos(y)))) + assert a._check_orthogonality(a._transformation) is True + a = CoordSys3D('a', transformation=((x, y, z), (x * cos(y), x * sin(y), z))) + assert a._check_orthogonality(a._transformation) is True + a = CoordSys3D('a', transformation=((u, v, z), (cosh(u) * cos(v), sinh(u) * sin(v), z))) + assert a._check_orthogonality(a._transformation) is True + + raises(ValueError, lambda: CoordSys3D('a', transformation=((x, y, z), (x, x, z)))) + raises(ValueError, lambda: CoordSys3D('a', transformation=( + (x, y, z), (x*sin(y/2)*cos(z), x*sin(y)*sin(z), x*cos(y))))) + + +def test_rotation_trans_equations(): + a = CoordSys3D('a') + from sympy.core.symbol import symbols + q0 = symbols('q0') + assert a._rotation_trans_equations(a._parent_rotation_matrix, a.base_scalars()) == (a.x, a.y, a.z) + assert a._rotation_trans_equations(a._inverse_rotation_matrix(), a.base_scalars()) == (a.x, a.y, a.z) + b = a.orient_new_axis('b', 0, -a.k) + assert b._rotation_trans_equations(b._parent_rotation_matrix, b.base_scalars()) == (b.x, b.y, b.z) + assert b._rotation_trans_equations(b._inverse_rotation_matrix(), b.base_scalars()) == (b.x, b.y, b.z) + c = a.orient_new_axis('c', q0, -a.k) + assert c._rotation_trans_equations(c._parent_rotation_matrix, c.base_scalars()) == \ + (-sin(q0) * c.y + cos(q0) * c.x, sin(q0) * c.x + cos(q0) * c.y, c.z) + assert c._rotation_trans_equations(c._inverse_rotation_matrix(), c.base_scalars()) == \ + (sin(q0) * c.y + cos(q0) * c.x, -sin(q0) * c.x + cos(q0) * c.y, c.z) diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_dyadic.py b/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_dyadic.py new file mode 100644 index 0000000000000000000000000000000000000000..2e396fcf2a81af897b59c0065f6b15f5c6933222 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_dyadic.py @@ -0,0 +1,134 @@ +from sympy.core.numbers import pi +from sympy.core.symbol import symbols +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.simplify.simplify import simplify +from sympy.vector import (CoordSys3D, Vector, Dyadic, + DyadicAdd, DyadicMul, DyadicZero, + BaseDyadic, express) + + +A = CoordSys3D('A') + + +def test_dyadic(): + a, b = symbols('a, b') + assert Dyadic.zero != 0 + assert isinstance(Dyadic.zero, DyadicZero) + assert BaseDyadic(A.i, A.j) != BaseDyadic(A.j, A.i) + assert (BaseDyadic(Vector.zero, A.i) == + BaseDyadic(A.i, Vector.zero) == Dyadic.zero) + + d1 = A.i | A.i + d2 = A.j | A.j + d3 = A.i | A.j + + assert isinstance(d1, BaseDyadic) + d_mul = a*d1 + assert isinstance(d_mul, DyadicMul) + assert d_mul.base_dyadic == d1 + assert d_mul.measure_number == a + assert isinstance(a*d1 + b*d3, DyadicAdd) + assert d1 == A.i.outer(A.i) + assert d3 == A.i.outer(A.j) + v1 = a*A.i - A.k + v2 = A.i + b*A.j + assert v1 | v2 == v1.outer(v2) == a * (A.i|A.i) + (a*b) * (A.i|A.j) +\ + - (A.k|A.i) - b * (A.k|A.j) + assert d1 * 0 == Dyadic.zero + assert d1 != Dyadic.zero + assert d1 * 2 == 2 * (A.i | A.i) + assert d1 / 2. == 0.5 * d1 + + assert d1.dot(0 * d1) == Vector.zero + assert d1 & d2 == Dyadic.zero + assert d1.dot(A.i) == A.i == d1 & A.i + + assert d1.cross(Vector.zero) == Dyadic.zero + assert d1.cross(A.i) == Dyadic.zero + assert d1 ^ A.j == d1.cross(A.j) + assert d1.cross(A.k) == - A.i | A.j + assert d2.cross(A.i) == - A.j | A.k == d2 ^ A.i + + assert A.i ^ d1 == Dyadic.zero + assert A.j.cross(d1) == - A.k | A.i == A.j ^ d1 + assert Vector.zero.cross(d1) == Dyadic.zero + assert A.k ^ d1 == A.j | A.i + assert A.i.dot(d1) == A.i & d1 == A.i + assert A.j.dot(d1) == Vector.zero + assert Vector.zero.dot(d1) == Vector.zero + assert A.j & d2 == A.j + + assert d1.dot(d3) == d1 & d3 == A.i | A.j == d3 + assert d3 & d1 == Dyadic.zero + + q = symbols('q') + B = A.orient_new_axis('B', q, A.k) + assert express(d1, B) == express(d1, B, B) + + expr1 = ((cos(q)**2) * (B.i | B.i) + (-sin(q) * cos(q)) * + (B.i | B.j) + (-sin(q) * cos(q)) * (B.j | B.i) + (sin(q)**2) * + (B.j | B.j)) + assert (express(d1, B) - expr1).simplify() == Dyadic.zero + + expr2 = (cos(q)) * (B.i | A.i) + (-sin(q)) * (B.j | A.i) + assert (express(d1, B, A) - expr2).simplify() == Dyadic.zero + + expr3 = (cos(q)) * (A.i | B.i) + (-sin(q)) * (A.i | B.j) + assert (express(d1, A, B) - expr3).simplify() == Dyadic.zero + + assert d1.to_matrix(A) == Matrix([[1, 0, 0], [0, 0, 0], [0, 0, 0]]) + assert d1.to_matrix(A, B) == Matrix([[cos(q), -sin(q), 0], + [0, 0, 0], + [0, 0, 0]]) + assert d3.to_matrix(A) == Matrix([[0, 1, 0], [0, 0, 0], [0, 0, 0]]) + a, b, c, d, e, f = symbols('a, b, c, d, e, f') + v1 = a * A.i + b * A.j + c * A.k + v2 = d * A.i + e * A.j + f * A.k + d4 = v1.outer(v2) + assert d4.to_matrix(A) == Matrix([[a * d, a * e, a * f], + [b * d, b * e, b * f], + [c * d, c * e, c * f]]) + d5 = v1.outer(v1) + C = A.orient_new_axis('C', q, A.i) + for expected, actual in zip(C.rotation_matrix(A) * d5.to_matrix(A) * \ + C.rotation_matrix(A).T, d5.to_matrix(C)): + assert (expected - actual).simplify() == 0 + + +def test_dyadic_simplify(): + x, y, z, k, n, m, w, f, s, A = symbols('x, y, z, k, n, m, w, f, s, A') + N = CoordSys3D('N') + + dy = N.i | N.i + test1 = (1 / x + 1 / y) * dy + assert (N.i & test1 & N.i) != (x + y) / (x * y) + test1 = test1.simplify() + assert test1.simplify() == simplify(test1) + assert (N.i & test1 & N.i) == (x + y) / (x * y) + + test2 = (A**2 * s**4 / (4 * pi * k * m**3)) * dy + test2 = test2.simplify() + assert (N.i & test2 & N.i) == (A**2 * s**4 / (4 * pi * k * m**3)) + + test3 = ((4 + 4 * x - 2 * (2 + 2 * x)) / (2 + 2 * x)) * dy + test3 = test3.simplify() + assert (N.i & test3 & N.i) == 0 + + test4 = ((-4 * x * y**2 - 2 * y**3 - 2 * x**2 * y) / (x + y)**2) * dy + test4 = test4.simplify() + assert (N.i & test4 & N.i) == -2 * y + + +def test_dyadic_srepr(): + from sympy.printing.repr import srepr + N = CoordSys3D('N') + + dy = N.i | N.j + res = "BaseDyadic(CoordSys3D(Str('N'), Tuple(ImmutableDenseMatrix([["\ + "Integer(1), Integer(0), Integer(0)], [Integer(0), Integer(1), "\ + "Integer(0)], [Integer(0), Integer(0), Integer(1)]]), "\ + "VectorZero())).i, CoordSys3D(Str('N'), Tuple(ImmutableDenseMatrix("\ + "[[Integer(1), Integer(0), Integer(0)], [Integer(0), Integer(1), "\ + "Integer(0)], [Integer(0), Integer(0), Integer(1)]]), VectorZero())).j)" + assert srepr(dy) == res diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_field_functions.py b/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_field_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..035c2ce0234b81069c5ad8dcb1c74f4de0164a8f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_field_functions.py @@ -0,0 +1,321 @@ +from sympy.core.function import Derivative +from sympy.vector.vector import Vector +from sympy.vector.coordsysrect import CoordSys3D +from sympy.simplify import simplify +from sympy.core.symbol import symbols +from sympy.core import S +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.vector.vector import Dot +from sympy.vector.operators import curl, divergence, gradient, Gradient, Divergence, Cross +from sympy.vector.deloperator import Del +from sympy.vector.functions import (is_conservative, is_solenoidal, + scalar_potential, directional_derivative, + laplacian, scalar_potential_difference) +from sympy.testing.pytest import raises + +C = CoordSys3D('C') +i, j, k = C.base_vectors() +x, y, z = C.base_scalars() +delop = Del() +a, b, c, q = symbols('a b c q') + + +def test_del_operator(): + # Tests for curl + + assert delop ^ Vector.zero == Vector.zero + assert ((delop ^ Vector.zero).doit() == Vector.zero == + curl(Vector.zero)) + assert delop.cross(Vector.zero) == delop ^ Vector.zero + assert (delop ^ i).doit() == Vector.zero + assert delop.cross(2*y**2*j, doit=True) == Vector.zero + assert delop.cross(2*y**2*j) == delop ^ 2*y**2*j + v = x*y*z * (i + j + k) + assert ((delop ^ v).doit() == + (-x*y + x*z)*i + (x*y - y*z)*j + (-x*z + y*z)*k == + curl(v)) + assert delop ^ v == delop.cross(v) + assert (delop.cross(2*x**2*j) == + (Derivative(0, C.y) - Derivative(2*C.x**2, C.z))*C.i + + (-Derivative(0, C.x) + Derivative(0, C.z))*C.j + + (-Derivative(0, C.y) + Derivative(2*C.x**2, C.x))*C.k) + assert (delop.cross(2*x**2*j, doit=True) == 4*x*k == + curl(2*x**2*j)) + + #Tests for divergence + assert delop & Vector.zero is S.Zero == divergence(Vector.zero) + assert (delop & Vector.zero).doit() is S.Zero + assert delop.dot(Vector.zero) == delop & Vector.zero + assert (delop & i).doit() is S.Zero + assert (delop & x**2*i).doit() == 2*x == divergence(x**2*i) + assert (delop.dot(v, doit=True) == x*y + y*z + z*x == + divergence(v)) + assert delop & v == delop.dot(v) + assert delop.dot(1/(x*y*z) * (i + j + k), doit=True) == \ + - 1 / (x*y*z**2) - 1 / (x*y**2*z) - 1 / (x**2*y*z) + v = x*i + y*j + z*k + assert (delop & v == Derivative(C.x, C.x) + + Derivative(C.y, C.y) + Derivative(C.z, C.z)) + assert delop.dot(v, doit=True) == 3 == divergence(v) + assert delop & v == delop.dot(v) + assert simplify((delop & v).doit()) == 3 + + #Tests for gradient + assert (delop.gradient(0, doit=True) == Vector.zero == + gradient(0)) + assert delop.gradient(0) == delop(0) + assert (delop(S.Zero)).doit() == Vector.zero + assert (delop(x) == (Derivative(C.x, C.x))*C.i + + (Derivative(C.x, C.y))*C.j + (Derivative(C.x, C.z))*C.k) + assert (delop(x)).doit() == i == gradient(x) + assert (delop(x*y*z) == + (Derivative(C.x*C.y*C.z, C.x))*C.i + + (Derivative(C.x*C.y*C.z, C.y))*C.j + + (Derivative(C.x*C.y*C.z, C.z))*C.k) + assert (delop.gradient(x*y*z, doit=True) == + y*z*i + z*x*j + x*y*k == + gradient(x*y*z)) + assert delop(x*y*z) == delop.gradient(x*y*z) + assert (delop(2*x**2)).doit() == 4*x*i + assert ((delop(a*sin(y) / x)).doit() == + -a*sin(y)/x**2 * i + a*cos(y)/x * j) + + #Tests for directional derivative + assert (Vector.zero & delop)(a) is S.Zero + assert ((Vector.zero & delop)(a)).doit() is S.Zero + assert ((v & delop)(Vector.zero)).doit() == Vector.zero + assert ((v & delop)(S.Zero)).doit() is S.Zero + assert ((i & delop)(x)).doit() == 1 + assert ((j & delop)(y)).doit() == 1 + assert ((k & delop)(z)).doit() == 1 + assert ((i & delop)(x*y*z)).doit() == y*z + assert ((v & delop)(x)).doit() == x + assert ((v & delop)(x*y*z)).doit() == 3*x*y*z + assert (v & delop)(x + y + z) == C.x + C.y + C.z + assert ((v & delop)(x + y + z)).doit() == x + y + z + assert ((v & delop)(v)).doit() == v + assert ((i & delop)(v)).doit() == i + assert ((j & delop)(v)).doit() == j + assert ((k & delop)(v)).doit() == k + assert ((v & delop)(Vector.zero)).doit() == Vector.zero + + # Tests for laplacian on scalar fields + assert laplacian(x*y*z) is S.Zero + assert laplacian(x**2) == S(2) + assert laplacian(x**2*y**2*z**2) == \ + 2*y**2*z**2 + 2*x**2*z**2 + 2*x**2*y**2 + A = CoordSys3D('A', transformation="spherical", variable_names=["r", "theta", "phi"]) + B = CoordSys3D('B', transformation='cylindrical', variable_names=["r", "theta", "z"]) + assert laplacian(A.r + A.theta + A.phi) == 2/A.r + cos(A.theta)/(A.r**2*sin(A.theta)) + assert laplacian(B.r + B.theta + B.z) == 1/B.r + + # Tests for laplacian on vector fields + assert laplacian(x*y*z*(i + j + k)) == Vector.zero + assert laplacian(x*y**2*z*(i + j + k)) == \ + 2*x*z*i + 2*x*z*j + 2*x*z*k + + +def test_product_rules(): + """ + Tests the six product rules defined with respect to the Del + operator + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Del + + """ + + #Define the scalar and vector functions + f = 2*x*y*z + g = x*y + y*z + z*x + u = x**2*i + 4*j - y**2*z*k + v = 4*i + x*y*z*k + + # First product rule + lhs = delop(f * g, doit=True) + rhs = (f * delop(g) + g * delop(f)).doit() + assert simplify(lhs) == simplify(rhs) + + # Second product rule + lhs = delop(u & v).doit() + rhs = ((u ^ (delop ^ v)) + (v ^ (delop ^ u)) + \ + ((u & delop)(v)) + ((v & delop)(u))).doit() + assert simplify(lhs) == simplify(rhs) + + # Third product rule + lhs = (delop & (f*v)).doit() + rhs = ((f * (delop & v)) + (v & (delop(f)))).doit() + assert simplify(lhs) == simplify(rhs) + + # Fourth product rule + lhs = (delop & (u ^ v)).doit() + rhs = ((v & (delop ^ u)) - (u & (delop ^ v))).doit() + assert simplify(lhs) == simplify(rhs) + + # Fifth product rule + lhs = (delop ^ (f * v)).doit() + rhs = (((delop(f)) ^ v) + (f * (delop ^ v))).doit() + assert simplify(lhs) == simplify(rhs) + + # Sixth product rule + lhs = (delop ^ (u ^ v)).doit() + rhs = (u * (delop & v) - v * (delop & u) + + (v & delop)(u) - (u & delop)(v)).doit() + assert simplify(lhs) == simplify(rhs) + + +P = C.orient_new_axis('P', q, C.k) # type: ignore +scalar_field = 2*x**2*y*z +grad_field = gradient(scalar_field) +vector_field = y**2*i + 3*x*j + 5*y*z*k +curl_field = curl(vector_field) + + +def test_conservative(): + assert is_conservative(Vector.zero) is True + assert is_conservative(i) is True + assert is_conservative(2 * i + 3 * j + 4 * k) is True + assert (is_conservative(y*z*i + x*z*j + x*y*k) is + True) + assert is_conservative(x * j) is False + assert is_conservative(grad_field) is True + assert is_conservative(curl_field) is False + assert (is_conservative(4*x*y*z*i + 2*x**2*z*j) is + False) + assert is_conservative(z*P.i + P.x*k) is True + + +def test_solenoidal(): + assert is_solenoidal(Vector.zero) is True + assert is_solenoidal(i) is True + assert is_solenoidal(2 * i + 3 * j + 4 * k) is True + assert (is_solenoidal(y*z*i + x*z*j + x*y*k) is + True) + assert is_solenoidal(y * j) is False + assert is_solenoidal(grad_field) is False + assert is_solenoidal(curl_field) is True + assert is_solenoidal((-2*y + 3)*k) is True + assert is_solenoidal(cos(q)*i + sin(q)*j + cos(q)*P.k) is True + assert is_solenoidal(z*P.i + P.x*k) is True + + +def test_directional_derivative(): + assert directional_derivative(C.x*C.y*C.z, 3*C.i + 4*C.j + C.k) == C.x*C.y + 4*C.x*C.z + 3*C.y*C.z + assert directional_derivative(5*C.x**2*C.z, 3*C.i + 4*C.j + C.k) == 5*C.x**2 + 30*C.x*C.z + assert directional_derivative(5*C.x**2*C.z, 4*C.j) is S.Zero + + D = CoordSys3D("D", "spherical", variable_names=["r", "theta", "phi"], + vector_names=["e_r", "e_theta", "e_phi"]) + r, theta, phi = D.base_scalars() + e_r, e_theta, e_phi = D.base_vectors() + assert directional_derivative(r**2*e_r, e_r) == 2*r*e_r + assert directional_derivative(5*r**2*phi, 3*e_r + 4*e_theta + e_phi) == 5*r**2 + 30*r*phi + + +def test_scalar_potential(): + assert scalar_potential(Vector.zero, C) == 0 + assert scalar_potential(i, C) == x + assert scalar_potential(j, C) == y + assert scalar_potential(k, C) == z + assert scalar_potential(y*z*i + x*z*j + x*y*k, C) == x*y*z + assert scalar_potential(grad_field, C) == scalar_field + assert scalar_potential(z*P.i + P.x*k, C) == x*z*cos(q) + y*z*sin(q) + assert scalar_potential(z*P.i + P.x*k, P) == P.x*P.z + raises(ValueError, lambda: scalar_potential(x*j, C)) + + +def test_scalar_potential_difference(): + point1 = C.origin.locate_new('P1', 1*i + 2*j + 3*k) + point2 = C.origin.locate_new('P2', 4*i + 5*j + 6*k) + genericpointC = C.origin.locate_new('RP', x*i + y*j + z*k) + genericpointP = P.origin.locate_new('PP', P.x*P.i + P.y*P.j + P.z*P.k) + assert scalar_potential_difference(S.Zero, C, point1, point2) == 0 + assert (scalar_potential_difference(scalar_field, C, C.origin, + genericpointC) == + scalar_field) + assert (scalar_potential_difference(grad_field, C, C.origin, + genericpointC) == + scalar_field) + assert scalar_potential_difference(grad_field, C, point1, point2) == 948 + assert (scalar_potential_difference(y*z*i + x*z*j + + x*y*k, C, point1, + genericpointC) == + x*y*z - 6) + potential_diff_P = (2*P.z*(P.x*sin(q) + P.y*cos(q))* + (P.x*cos(q) - P.y*sin(q))**2) + assert (scalar_potential_difference(grad_field, P, P.origin, + genericpointP).simplify() == + potential_diff_P.simplify()) + + +def test_differential_operators_curvilinear_system(): + A = CoordSys3D('A', transformation="spherical", variable_names=["r", "theta", "phi"]) + B = CoordSys3D('B', transformation='cylindrical', variable_names=["r", "theta", "z"]) + # Test for spherical coordinate system and gradient + assert gradient(3*A.r + 4*A.theta) == 3*A.i + 4/A.r*A.j + assert gradient(3*A.r*A.phi + 4*A.theta) == 3*A.phi*A.i + 4/A.r*A.j + (3/sin(A.theta))*A.k + assert gradient(0*A.r + 0*A.theta+0*A.phi) == Vector.zero + assert gradient(A.r*A.theta*A.phi) == A.theta*A.phi*A.i + A.phi*A.j + (A.theta/sin(A.theta))*A.k + # Test for spherical coordinate system and divergence + assert divergence(A.r * A.i + A.theta * A.j + A.phi * A.k) == \ + (sin(A.theta)*A.r + cos(A.theta)*A.r*A.theta)/(sin(A.theta)*A.r**2) + 3 + 1/(sin(A.theta)*A.r) + assert divergence(3*A.r*A.phi*A.i + A.theta*A.j + A.r*A.theta*A.phi*A.k) == \ + (sin(A.theta)*A.r + cos(A.theta)*A.r*A.theta)/(sin(A.theta)*A.r**2) + 9*A.phi + A.theta/sin(A.theta) + assert divergence(Vector.zero) == 0 + assert divergence(0*A.i + 0*A.j + 0*A.k) == 0 + # Test for spherical coordinate system and curl + assert curl(A.r*A.i + A.theta*A.j + A.phi*A.k) == \ + (cos(A.theta)*A.phi/(sin(A.theta)*A.r))*A.i + (-A.phi/A.r)*A.j + A.theta/A.r*A.k + assert curl(A.r*A.j + A.phi*A.k) == (cos(A.theta)*A.phi/(sin(A.theta)*A.r))*A.i + (-A.phi/A.r)*A.j + 2*A.k + + # Test for cylindrical coordinate system and gradient + assert gradient(0*B.r + 0*B.theta+0*B.z) == Vector.zero + assert gradient(B.r*B.theta*B.z) == B.theta*B.z*B.i + B.z*B.j + B.r*B.theta*B.k + assert gradient(3*B.r) == 3*B.i + assert gradient(2*B.theta) == 2/B.r * B.j + assert gradient(4*B.z) == 4*B.k + # Test for cylindrical coordinate system and divergence + assert divergence(B.r*B.i + B.theta*B.j + B.z*B.k) == 3 + 1/B.r + assert divergence(B.r*B.j + B.z*B.k) == 1 + # Test for cylindrical coordinate system and curl + assert curl(B.r*B.j + B.z*B.k) == 2*B.k + assert curl(3*B.i + 2/B.r*B.j + 4*B.k) == Vector.zero + +def test_mixed_coordinates(): + # gradient + a = CoordSys3D('a') + b = CoordSys3D('b') + c = CoordSys3D('c') + assert gradient(a.x*b.y) == b.y*a.i + a.x*b.j + assert gradient(3*cos(q)*a.x*b.x+a.y*(a.x+(cos(q)+b.x))) ==\ + (a.y + 3*b.x*cos(q))*a.i + (a.x + b.x + cos(q))*a.j + (3*a.x*cos(q) + a.y)*b.i + # Some tests need further work: + # assert gradient(a.x*(cos(a.x+b.x))) == (cos(a.x + b.x))*a.i + a.x*Gradient(cos(a.x + b.x)) + # assert gradient(cos(a.x + b.x)*cos(a.x + b.z)) == Gradient(cos(a.x + b.x)*cos(a.x + b.z)) + assert gradient(a.x**b.y) == Gradient(a.x**b.y) + # assert gradient(cos(a.x+b.y)*a.z) == None + assert gradient(cos(a.x*b.y)) == Gradient(cos(a.x*b.y)) + assert gradient(3*cos(q)*a.x*b.x*a.z*a.y+ b.y*b.z + cos(a.x+a.y)*b.z) == \ + (3*a.y*a.z*b.x*cos(q) - b.z*sin(a.x + a.y))*a.i + \ + (3*a.x*a.z*b.x*cos(q) - b.z*sin(a.x + a.y))*a.j + (3*a.x*a.y*b.x*cos(q))*a.k + \ + (3*a.x*a.y*a.z*cos(q))*b.i + b.z*b.j + (b.y + cos(a.x + a.y))*b.k + # divergence + assert divergence(a.i*a.x+a.j*a.y+a.z*a.k + b.i*b.x+b.j*b.y+b.z*b.k + c.i*c.x+c.j*c.y+c.z*c.k) == S(9) + # assert divergence(3*a.i*a.x*cos(a.x+b.z) + a.j*b.x*c.z) == None + assert divergence(3*a.i*a.x*a.z + b.j*b.x*c.z + 3*a.j*a.z*a.y) == \ + 6*a.z + b.x*Dot(b.j, c.k) + assert divergence(3*cos(q)*a.x*b.x*b.i*c.x) == \ + 3*a.x*b.x*cos(q)*Dot(b.i, c.i) + 3*a.x*c.x*cos(q) + 3*b.x*c.x*cos(q)*Dot(b.i, a.i) + assert divergence(a.x*b.x*c.x*Cross(a.x*a.i, a.y*b.j)) ==\ + a.x*b.x*c.x*Divergence(Cross(a.x*a.i, a.y*b.j)) + \ + b.x*c.x*Dot(Cross(a.x*a.i, a.y*b.j), a.i) + \ + a.x*c.x*Dot(Cross(a.x*a.i, a.y*b.j), b.i) + \ + a.x*b.x*Dot(Cross(a.x*a.i, a.y*b.j), c.i) + assert divergence(a.x*b.x*c.x*(a.x*a.i + b.x*b.i)) == \ + 4*a.x*b.x*c.x +\ + a.x**2*c.x*Dot(a.i, b.i) +\ + a.x**2*b.x*Dot(a.i, c.i) +\ + b.x**2*c.x*Dot(b.i, a.i) +\ + a.x*b.x**2*Dot(b.i, c.i) diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_functions.py b/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..dfdf9821b6c853755ce12d0cbdfa599bd4f312e4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_functions.py @@ -0,0 +1,184 @@ +from sympy.vector.vector import Vector +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.functions import express, matrix_to_vector, orthogonalize +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.testing.pytest import raises + +N = CoordSys3D('N') +q1, q2, q3, q4, q5 = symbols('q1 q2 q3 q4 q5') +A = N.orient_new_axis('A', q1, N.k) # type: ignore +B = A.orient_new_axis('B', q2, A.i) +C = B.orient_new_axis('C', q3, B.j) + + +def test_express(): + assert express(Vector.zero, N) == Vector.zero + assert express(S.Zero, N) is S.Zero + assert express(A.i, C) == cos(q3)*C.i + sin(q3)*C.k + assert express(A.j, C) == sin(q2)*sin(q3)*C.i + cos(q2)*C.j - \ + sin(q2)*cos(q3)*C.k + assert express(A.k, C) == -sin(q3)*cos(q2)*C.i + sin(q2)*C.j + \ + cos(q2)*cos(q3)*C.k + assert express(A.i, N) == cos(q1)*N.i + sin(q1)*N.j + assert express(A.j, N) == -sin(q1)*N.i + cos(q1)*N.j + assert express(A.k, N) == N.k + assert express(A.i, A) == A.i + assert express(A.j, A) == A.j + assert express(A.k, A) == A.k + assert express(A.i, B) == B.i + assert express(A.j, B) == cos(q2)*B.j - sin(q2)*B.k + assert express(A.k, B) == sin(q2)*B.j + cos(q2)*B.k + assert express(A.i, C) == cos(q3)*C.i + sin(q3)*C.k + assert express(A.j, C) == sin(q2)*sin(q3)*C.i + cos(q2)*C.j - \ + sin(q2)*cos(q3)*C.k + assert express(A.k, C) == -sin(q3)*cos(q2)*C.i + sin(q2)*C.j + \ + cos(q2)*cos(q3)*C.k + # Check to make sure UnitVectors get converted properly + assert express(N.i, N) == N.i + assert express(N.j, N) == N.j + assert express(N.k, N) == N.k + assert express(N.i, A) == (cos(q1)*A.i - sin(q1)*A.j) + assert express(N.j, A) == (sin(q1)*A.i + cos(q1)*A.j) + assert express(N.k, A) == A.k + assert express(N.i, B) == (cos(q1)*B.i - sin(q1)*cos(q2)*B.j + + sin(q1)*sin(q2)*B.k) + assert express(N.j, B) == (sin(q1)*B.i + cos(q1)*cos(q2)*B.j - + sin(q2)*cos(q1)*B.k) + assert express(N.k, B) == (sin(q2)*B.j + cos(q2)*B.k) + assert express(N.i, C) == ( + (cos(q1)*cos(q3) - sin(q1)*sin(q2)*sin(q3))*C.i - + sin(q1)*cos(q2)*C.j + + (sin(q3)*cos(q1) + sin(q1)*sin(q2)*cos(q3))*C.k) + assert express(N.j, C) == ( + (sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1))*C.i + + cos(q1)*cos(q2)*C.j + + (sin(q1)*sin(q3) - sin(q2)*cos(q1)*cos(q3))*C.k) + assert express(N.k, C) == (-sin(q3)*cos(q2)*C.i + sin(q2)*C.j + + cos(q2)*cos(q3)*C.k) + + assert express(A.i, N) == (cos(q1)*N.i + sin(q1)*N.j) + assert express(A.j, N) == (-sin(q1)*N.i + cos(q1)*N.j) + assert express(A.k, N) == N.k + assert express(A.i, A) == A.i + assert express(A.j, A) == A.j + assert express(A.k, A) == A.k + assert express(A.i, B) == B.i + assert express(A.j, B) == (cos(q2)*B.j - sin(q2)*B.k) + assert express(A.k, B) == (sin(q2)*B.j + cos(q2)*B.k) + assert express(A.i, C) == (cos(q3)*C.i + sin(q3)*C.k) + assert express(A.j, C) == (sin(q2)*sin(q3)*C.i + cos(q2)*C.j - + sin(q2)*cos(q3)*C.k) + assert express(A.k, C) == (-sin(q3)*cos(q2)*C.i + sin(q2)*C.j + + cos(q2)*cos(q3)*C.k) + + assert express(B.i, N) == (cos(q1)*N.i + sin(q1)*N.j) + assert express(B.j, N) == (-sin(q1)*cos(q2)*N.i + + cos(q1)*cos(q2)*N.j + sin(q2)*N.k) + assert express(B.k, N) == (sin(q1)*sin(q2)*N.i - + sin(q2)*cos(q1)*N.j + cos(q2)*N.k) + assert express(B.i, A) == A.i + assert express(B.j, A) == (cos(q2)*A.j + sin(q2)*A.k) + assert express(B.k, A) == (-sin(q2)*A.j + cos(q2)*A.k) + assert express(B.i, B) == B.i + assert express(B.j, B) == B.j + assert express(B.k, B) == B.k + assert express(B.i, C) == (cos(q3)*C.i + sin(q3)*C.k) + assert express(B.j, C) == C.j + assert express(B.k, C) == (-sin(q3)*C.i + cos(q3)*C.k) + + assert express(C.i, N) == ( + (cos(q1)*cos(q3) - sin(q1)*sin(q2)*sin(q3))*N.i + + (sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1))*N.j - + sin(q3)*cos(q2)*N.k) + assert express(C.j, N) == ( + -sin(q1)*cos(q2)*N.i + cos(q1)*cos(q2)*N.j + sin(q2)*N.k) + assert express(C.k, N) == ( + (sin(q3)*cos(q1) + sin(q1)*sin(q2)*cos(q3))*N.i + + (sin(q1)*sin(q3) - sin(q2)*cos(q1)*cos(q3))*N.j + + cos(q2)*cos(q3)*N.k) + assert express(C.i, A) == (cos(q3)*A.i + sin(q2)*sin(q3)*A.j - + sin(q3)*cos(q2)*A.k) + assert express(C.j, A) == (cos(q2)*A.j + sin(q2)*A.k) + assert express(C.k, A) == (sin(q3)*A.i - sin(q2)*cos(q3)*A.j + + cos(q2)*cos(q3)*A.k) + assert express(C.i, B) == (cos(q3)*B.i - sin(q3)*B.k) + assert express(C.j, B) == B.j + assert express(C.k, B) == (sin(q3)*B.i + cos(q3)*B.k) + assert express(C.i, C) == C.i + assert express(C.j, C) == C.j + assert express(C.k, C) == C.k == (C.k) + + # Check to make sure Vectors get converted back to UnitVectors + assert N.i == express((cos(q1)*A.i - sin(q1)*A.j), N).simplify() + assert N.j == express((sin(q1)*A.i + cos(q1)*A.j), N).simplify() + assert N.i == express((cos(q1)*B.i - sin(q1)*cos(q2)*B.j + + sin(q1)*sin(q2)*B.k), N).simplify() + assert N.j == express((sin(q1)*B.i + cos(q1)*cos(q2)*B.j - + sin(q2)*cos(q1)*B.k), N).simplify() + assert N.k == express((sin(q2)*B.j + cos(q2)*B.k), N).simplify() + + + assert A.i == express((cos(q1)*N.i + sin(q1)*N.j), A).simplify() + assert A.j == express((-sin(q1)*N.i + cos(q1)*N.j), A).simplify() + + assert A.j == express((cos(q2)*B.j - sin(q2)*B.k), A).simplify() + assert A.k == express((sin(q2)*B.j + cos(q2)*B.k), A).simplify() + + assert A.i == express((cos(q3)*C.i + sin(q3)*C.k), A).simplify() + assert A.j == express((sin(q2)*sin(q3)*C.i + cos(q2)*C.j - + sin(q2)*cos(q3)*C.k), A).simplify() + + assert A.k == express((-sin(q3)*cos(q2)*C.i + sin(q2)*C.j + + cos(q2)*cos(q3)*C.k), A).simplify() + assert B.i == express((cos(q1)*N.i + sin(q1)*N.j), B).simplify() + assert B.j == express((-sin(q1)*cos(q2)*N.i + + cos(q1)*cos(q2)*N.j + sin(q2)*N.k), B).simplify() + + assert B.k == express((sin(q1)*sin(q2)*N.i - + sin(q2)*cos(q1)*N.j + cos(q2)*N.k), B).simplify() + + assert B.j == express((cos(q2)*A.j + sin(q2)*A.k), B).simplify() + assert B.k == express((-sin(q2)*A.j + cos(q2)*A.k), B).simplify() + assert B.i == express((cos(q3)*C.i + sin(q3)*C.k), B).simplify() + assert B.k == express((-sin(q3)*C.i + cos(q3)*C.k), B).simplify() + assert C.i == express((cos(q3)*A.i + sin(q2)*sin(q3)*A.j - + sin(q3)*cos(q2)*A.k), C).simplify() + assert C.j == express((cos(q2)*A.j + sin(q2)*A.k), C).simplify() + assert C.k == express((sin(q3)*A.i - sin(q2)*cos(q3)*A.j + + cos(q2)*cos(q3)*A.k), C).simplify() + assert C.i == express((cos(q3)*B.i - sin(q3)*B.k), C).simplify() + assert C.k == express((sin(q3)*B.i + cos(q3)*B.k), C).simplify() + + +def test_matrix_to_vector(): + m = Matrix([[1], [2], [3]]) + assert matrix_to_vector(m, C) == C.i + 2*C.j + 3*C.k + m = Matrix([[0], [0], [0]]) + assert matrix_to_vector(m, N) == matrix_to_vector(m, C) == \ + Vector.zero + m = Matrix([[q1], [q2], [q3]]) + assert matrix_to_vector(m, N) == q1*N.i + q2*N.j + q3*N.k + + +def test_orthogonalize(): + C = CoordSys3D('C') + a, b = symbols('a b', integer=True) + i, j, k = C.base_vectors() + v1 = i + 2*j + v2 = 2*i + 3*j + v3 = 3*i + 5*j + v4 = 3*i + j + v5 = 2*i + 2*j + v6 = a*i + b*j + v7 = 4*a*i + 4*b*j + assert orthogonalize(v1, v2) == [C.i + 2*C.j, C.i*Rational(2, 5) + -C.j/5] + # from wikipedia + assert orthogonalize(v4, v5, orthonormal=True) == \ + [(3*sqrt(10))*C.i/10 + (sqrt(10))*C.j/10, (-sqrt(10))*C.i/10 + (3*sqrt(10))*C.j/10] + raises(ValueError, lambda: orthogonalize(v1, v2, v3)) + raises(ValueError, lambda: orthogonalize(v6, v7)) diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_implicitregion.py b/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_implicitregion.py new file mode 100644 index 0000000000000000000000000000000000000000..3686d847a7f165cb5ba9aeb813e5922aaa17e1e0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_implicitregion.py @@ -0,0 +1,90 @@ +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.abc import x, y, z, s, t +from sympy.sets import FiniteSet, EmptySet +from sympy.geometry import Point +from sympy.vector import ImplicitRegion +from sympy.testing.pytest import raises + + +def test_ImplicitRegion(): + ellipse = ImplicitRegion((x, y), (x**2/4 + y**2/16 - 1)) + assert ellipse.equation == x**2/4 + y**2/16 - 1 + assert ellipse.variables == (x, y) + assert ellipse.degree == 2 + r = ImplicitRegion((x, y, z), Eq(x**4 + y**2 - x*y, 6)) + assert r.equation == x**4 + y**2 - x*y - 6 + assert r.variables == (x, y, z) + assert r.degree == 4 + + +def test_regular_point(): + r1 = ImplicitRegion((x,), x**2 - 16) + assert r1.regular_point() == (-4,) + c1 = ImplicitRegion((x, y), x**2 + y**2 - 4) + assert c1.regular_point() == (0, -2) + c2 = ImplicitRegion((x, y), (x - S(5)/2)**2 + y**2 - (S(1)/4)**2) + assert c2.regular_point() == (S(5)/2, -S(1)/4) + c3 = ImplicitRegion((x, y), (y - 5)**2 - 16*(x - 5)) + assert c3.regular_point() == (5, 5) + r2 = ImplicitRegion((x, y), x**2 - 4*x*y - 3*y**2 + 4*x + 8*y - 5) + assert r2.regular_point() == (S(4)/7, S(9)/7) + r3 = ImplicitRegion((x, y), x**2 - 2*x*y + 3*y**2 - 2*x - 5*y + 3/2) + raises(ValueError, lambda: r3.regular_point()) + + +def test_singular_points_and_multiplicty(): + r1 = ImplicitRegion((x, y, z), Eq(x + y + z, 0)) + assert r1.singular_points() == EmptySet + r2 = ImplicitRegion((x, y, z), x*y*z + y**4 -x**2*z**2) + assert r2.singular_points() == FiniteSet((0, 0, z), (x, 0, 0)) + assert r2.multiplicity((0, 0, 0)) == 3 + assert r2.multiplicity((0, 0, 6)) == 2 + r3 = ImplicitRegion((x, y, z), z**2 - x**2 - y**2) + assert r3.singular_points() == FiniteSet((0, 0, 0)) + assert r3.multiplicity((0, 0, 0)) == 2 + r4 = ImplicitRegion((x, y), x**2 + y**2 - 2*x) + assert r4.singular_points() == EmptySet + assert r4.multiplicity(Point(1, 3)) == 0 + + +def test_rational_parametrization(): + p = ImplicitRegion((x,), x - 2) + assert p.rational_parametrization() == (x - 2,) + + line = ImplicitRegion((x, y), Eq(y, 3*x + 2)) + assert line.rational_parametrization() == (x, 3*x + 2) + + circle1 = ImplicitRegion((x, y), (x-2)**2 + (y+3)**2 - 4) + assert circle1.rational_parametrization(parameters=t) == (4*t/(t**2 + 1) + 2, 4*t**2/(t**2 + 1) - 5) + circle2 = ImplicitRegion((x, y), (x - S.Half)**2 + y**2 - (S(1)/2)**2) + + assert circle2.rational_parametrization(parameters=t) == (t/(t**2 + 1) + S(1)/2, t**2/(t**2 + 1) - S(1)/2) + circle3 = ImplicitRegion((x, y), Eq(x**2 + y**2, 2*x)) + assert circle3.rational_parametrization(parameters=(t,)) == (2*t/(t**2 + 1) + 1, 2*t**2/(t**2 + 1) - 1) + + parabola = ImplicitRegion((x, y), (y - 3)**2 - 4*(x + 6)) + assert parabola.rational_parametrization(t) == (-6 + 4/t**2, 3 + 4/t) + + rect_hyperbola = ImplicitRegion((x, y), x*y - 1) + assert rect_hyperbola.rational_parametrization(t) == (-1 + (t + 1)/t, t) + + cubic_curve = ImplicitRegion((x, y), x**3 + x**2 - y**2) + assert cubic_curve.rational_parametrization(parameters=(t)) == (t**2 - 1, t*(t**2 - 1)) + cuspidal = ImplicitRegion((x, y), (x**3 - y**2)) + assert cuspidal.rational_parametrization(t) == (t**2, t**3) + + I = ImplicitRegion((x, y), x**3 + x**2 - y**2) + assert I.rational_parametrization(t) == (t**2 - 1, t*(t**2 - 1)) + + sphere = ImplicitRegion((x, y, z), Eq(x**2 + y**2 + z**2, 2*x)) + assert sphere.rational_parametrization(parameters=(s, t)) == (2/(s**2 + t**2 + 1), 2*t/(s**2 + t**2 + 1), 2*s/(s**2 + t**2 + 1)) + + conic = ImplicitRegion((x, y), Eq(x**2 + 4*x*y + 3*y**2 + x - y + 10, 0)) + assert conic.rational_parametrization(t) == ( + S(17)/2 + 4/(3*t**2 + 4*t + 1), 4*t/(3*t**2 + 4*t + 1) - S(11)/2) + + r1 = ImplicitRegion((x, y), y**2 - x**3 + x) + raises(NotImplementedError, lambda: r1.rational_parametrization()) + r2 = ImplicitRegion((x, y), y**2 - x**3 - x**2 + 1) + raises(NotImplementedError, lambda: r2.rational_parametrization()) diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_integrals.py b/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_integrals.py new file mode 100644 index 0000000000000000000000000000000000000000..84c900d038e214df1ea59a8cd8fb2929005c3674 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_integrals.py @@ -0,0 +1,106 @@ +from sympy.core.numbers import pi +from sympy.core.singleton import S +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.testing.pytest import raises +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.integrals import ParametricIntegral, vector_integrate +from sympy.vector.parametricregion import ParametricRegion +from sympy.vector.implicitregion import ImplicitRegion +from sympy.abc import x, y, z, u, v, r, t, theta, phi +from sympy.geometry import Point, Segment, Curve, Circle, Polygon, Plane + +C = CoordSys3D('C') + +def test_parametric_lineintegrals(): + halfcircle = ParametricRegion((4*cos(theta), 4*sin(theta)), (theta, -pi/2, pi/2)) + assert ParametricIntegral(C.x*C.y**4, halfcircle) == S(8192)/5 + + curve = ParametricRegion((t, t**2, t**3), (t, 0, 1)) + field1 = 8*C.x**2*C.y*C.z*C.i + 5*C.z*C.j - 4*C.x*C.y*C.k + assert ParametricIntegral(field1, curve) == 1 + line = ParametricRegion((4*t - 1, 2 - 2*t, t), (t, 0, 1)) + assert ParametricIntegral(C.x*C.z*C.i - C.y*C.z*C.k, line) == 3 + + assert ParametricIntegral(4*C.x**3, ParametricRegion((1, t), (t, 0, 2))) == 8 + + helix = ParametricRegion((cos(t), sin(t), 3*t), (t, 0, 4*pi)) + assert ParametricIntegral(C.x*C.y*C.z, helix) == -3*sqrt(10)*pi + + field2 = C.y*C.i + C.z*C.j + C.z*C.k + assert ParametricIntegral(field2, ParametricRegion((cos(t), sin(t), t**2), (t, 0, pi))) == -5*pi/2 + pi**4/2 + +def test_parametric_surfaceintegrals(): + + semisphere = ParametricRegion((2*sin(phi)*cos(theta), 2*sin(phi)*sin(theta), 2*cos(phi)),\ + (theta, 0, 2*pi), (phi, 0, pi/2)) + assert ParametricIntegral(C.z, semisphere) == 8*pi + + cylinder = ParametricRegion((sqrt(3)*cos(theta), sqrt(3)*sin(theta), z), (z, 0, 6), (theta, 0, 2*pi)) + assert ParametricIntegral(C.y, cylinder) == 0 + + cone = ParametricRegion((v*cos(u), v*sin(u), v), (u, 0, 2*pi), (v, 0, 1)) + assert ParametricIntegral(C.x*C.i + C.y*C.j + C.z**4*C.k, cone) == pi/3 + + triangle1 = ParametricRegion((x, y), (x, 0, 2), (y, 0, 10 - 5*x)) + triangle2 = ParametricRegion((x, y), (y, 0, 10 - 5*x), (x, 0, 2)) + assert ParametricIntegral(-15.6*C.y*C.k, triangle1) == ParametricIntegral(-15.6*C.y*C.k, triangle2) + assert ParametricIntegral(C.z, triangle1) == 10*C.z + +def test_parametric_volumeintegrals(): + + cube = ParametricRegion((x, y, z), (x, 0, 1), (y, 0, 1), (z, 0, 1)) + assert ParametricIntegral(1, cube) == 1 + + solidsphere1 = ParametricRegion((r*sin(phi)*cos(theta), r*sin(phi)*sin(theta), r*cos(phi)),\ + (r, 0, 2), (theta, 0, 2*pi), (phi, 0, pi)) + solidsphere2 = ParametricRegion((r*sin(phi)*cos(theta), r*sin(phi)*sin(theta), r*cos(phi)),\ + (r, 0, 2), (phi, 0, pi), (theta, 0, 2*pi)) + assert ParametricIntegral(C.x**2 + C.y**2, solidsphere1) == -256*pi/15 + assert ParametricIntegral(C.x**2 + C.y**2, solidsphere2) == 256*pi/15 + + region_under_plane1 = ParametricRegion((x, y, z), (x, 0, 3), (y, 0, -2*x/3 + 2),\ + (z, 0, 6 - 2*x - 3*y)) + region_under_plane2 = ParametricRegion((x, y, z), (x, 0, 3), (z, 0, 6 - 2*x - 3*y),\ + (y, 0, -2*x/3 + 2)) + + assert ParametricIntegral(C.x*C.i + C.j - 100*C.k, region_under_plane1) == \ + ParametricIntegral(C.x*C.i + C.j - 100*C.k, region_under_plane2) + assert ParametricIntegral(2*C.x, region_under_plane2) == -9 + +def test_vector_integrate(): + halfdisc = ParametricRegion((r*cos(theta), r* sin(theta)), (r, -2, 2), (theta, 0, pi)) + assert vector_integrate(C.x**2, halfdisc) == 4*pi + assert vector_integrate(C.x, ParametricRegion((t, t**2), (t, 2, 3))) == -17*sqrt(17)/12 + 37*sqrt(37)/12 + + assert vector_integrate(C.y**3*C.z, (C.x, 0, 3), (C.y, -1, 4)) == 765*C.z/4 + + s1 = Segment(Point(0, 0), Point(0, 1)) + assert vector_integrate(-15*C.y, s1) == S(-15)/2 + s2 = Segment(Point(4, 3, 9), Point(1, 1, 7)) + assert vector_integrate(C.y*C.i, s2) == -6 + + curve = Curve((sin(t), cos(t)), (t, 0, 2)) + assert vector_integrate(5*C.z, curve) == 10*C.z + + c1 = Circle(Point(2, 3), 6) + assert vector_integrate(C.x*C.y, c1) == 72*pi + c2 = Circle(Point(0, 0), Point(1, 1), Point(1, 0)) + assert vector_integrate(1, c2) == c2.circumference + + triangle = Polygon((0, 0), (1, 0), (1, 1)) + assert vector_integrate(C.x*C.i - 14*C.y*C.j, triangle) == 0 + p1, p2, p3, p4 = [(0, 0), (1, 0), (5, 1), (0, 1)] + poly = Polygon(p1, p2, p3, p4) + assert vector_integrate(-23*C.z, poly) == -161*C.z - 23*sqrt(17)*C.z + + point = Point(2, 3) + assert vector_integrate(C.i*C.y, point) == ParametricIntegral(C.y*C.i, ParametricRegion((2, 3))) + + c3 = ImplicitRegion((x, y), x**2 + y**2 - 4) + assert vector_integrate(45, c3) == 180*pi + c4 = ImplicitRegion((x, y), (x - 3)**2 + (y - 4)**2 - 9) + assert vector_integrate(1, c4) == 6*pi + + pl = Plane(Point(1, 1, 1), Point(2, 3, 4), Point(2, 2, 2)) + raises(ValueError, lambda: vector_integrate(C.x*C.z*C.i + C.k, pl)) diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_operators.py b/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_operators.py new file mode 100644 index 0000000000000000000000000000000000000000..5734edadd00547c67d6f864b50afd966ad8392a6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_operators.py @@ -0,0 +1,43 @@ +from sympy.vector import CoordSys3D, Gradient, Divergence, Curl, VectorZero, Laplacian +from sympy.printing.repr import srepr + +R = CoordSys3D('R') +s1 = R.x*R.y*R.z # type: ignore +s2 = R.x + 3*R.y**2 # type: ignore +s3 = R.x**2 + R.y**2 + R.z**2 # type: ignore +v1 = R.x*R.i + R.z*R.z*R.j # type: ignore +v2 = R.x*R.i + R.y*R.j + R.z*R.k # type: ignore +v3 = R.x**2*R.i + R.y**2*R.j + R.z**2*R.k # type: ignore + + +def test_Gradient(): + assert Gradient(s1) == Gradient(R.x*R.y*R.z) + assert Gradient(s2) == Gradient(R.x + 3*R.y**2) + assert Gradient(s1).doit() == R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k + assert Gradient(s2).doit() == R.i + 6*R.y*R.j + + +def test_Divergence(): + assert Divergence(v1) == Divergence(R.x*R.i + R.z*R.z*R.j) + assert Divergence(v2) == Divergence(R.x*R.i + R.y*R.j + R.z*R.k) + assert Divergence(v1).doit() == 1 + assert Divergence(v2).doit() == 3 + # issue 22384 + Rc = CoordSys3D('R', transformation='cylindrical') + assert Divergence(Rc.i).doit() == 1/Rc.r + + +def test_Curl(): + assert Curl(v1) == Curl(R.x*R.i + R.z*R.z*R.j) + assert Curl(v2) == Curl(R.x*R.i + R.y*R.j + R.z*R.k) + assert Curl(v1).doit() == (-2*R.z)*R.i + assert Curl(v2).doit() == VectorZero() + + +def test_Laplacian(): + assert Laplacian(s3) == Laplacian(R.x**2 + R.y**2 + R.z**2) + assert Laplacian(v3) == Laplacian(R.x**2*R.i + R.y**2*R.j + R.z**2*R.k) + assert Laplacian(s3).doit() == 6 + assert Laplacian(v3).doit() == 2*R.i + 2*R.j + 2*R.k + assert srepr(Laplacian(s3)) == \ + 'Laplacian(Add(Pow(R.x, Integer(2)), Pow(R.y, Integer(2)), Pow(R.z, Integer(2))))' diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_parametricregion.py b/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_parametricregion.py new file mode 100644 index 0000000000000000000000000000000000000000..e785b96744f9e2c39e91b997fcb70f8a921256bd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_parametricregion.py @@ -0,0 +1,97 @@ +from sympy.core.numbers import pi +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.parametricregion import ParametricRegion, parametric_region_list +from sympy.geometry import Point, Segment, Curve, Ellipse, Line, Parabola, Polygon +from sympy.testing.pytest import raises +from sympy.abc import a, b, r, t, x, y, z, theta, phi + + +C = CoordSys3D('C') + +def test_ParametricRegion(): + + point = ParametricRegion((3, 4)) + assert point.definition == (3, 4) + assert point.parameters == () + assert point.limits == {} + assert point.dimensions == 0 + + # line x = y + line_xy = ParametricRegion((y, y), (y, 1, 5)) + assert line_xy .definition == (y, y) + assert line_xy.parameters == (y,) + assert line_xy.dimensions == 1 + + # line y = z + line_yz = ParametricRegion((x,t,t), x, (t, 1, 2)) + assert line_yz.definition == (x,t,t) + assert line_yz.parameters == (x, t) + assert line_yz.limits == {t: (1, 2)} + assert line_yz.dimensions == 1 + + p1 = ParametricRegion((9*a, -16*b), (a, 0, 2), (b, -1, 5)) + assert p1.definition == (9*a, -16*b) + assert p1.parameters == (a, b) + assert p1.limits == {a: (0, 2), b: (-1, 5)} + assert p1.dimensions == 2 + + p2 = ParametricRegion((t, t**3), t) + assert p2.parameters == (t,) + assert p2.limits == {} + assert p2.dimensions == 0 + + circle = ParametricRegion((r*cos(theta), r*sin(theta)), r, (theta, 0, 2*pi)) + assert circle.definition == (r*cos(theta), r*sin(theta)) + assert circle.dimensions == 1 + + halfdisc = ParametricRegion((r*cos(theta), r*sin(theta)), (r, -2, 2), (theta, 0, pi)) + assert halfdisc.definition == (r*cos(theta), r*sin(theta)) + assert halfdisc.parameters == (r, theta) + assert halfdisc.limits == {r: (-2, 2), theta: (0, pi)} + assert halfdisc.dimensions == 2 + + ellipse = ParametricRegion((a*cos(t), b*sin(t)), (t, 0, 8)) + assert ellipse.parameters == (t,) + assert ellipse.limits == {t: (0, 8)} + assert ellipse.dimensions == 1 + + cylinder = ParametricRegion((r*cos(theta), r*sin(theta), z), (r, 0, 1), (theta, 0, 2*pi), (z, 0, 4)) + assert cylinder.parameters == (r, theta, z) + assert cylinder.dimensions == 3 + + sphere = ParametricRegion((r*sin(phi)*cos(theta),r*sin(phi)*sin(theta), r*cos(phi)), + r, (theta, 0, 2*pi), (phi, 0, pi)) + assert sphere.definition == (r*sin(phi)*cos(theta),r*sin(phi)*sin(theta), r*cos(phi)) + assert sphere.parameters == (r, theta, phi) + assert sphere.dimensions == 2 + + raises(ValueError, lambda: ParametricRegion((a*t**2, 2*a*t), (a, -2))) + raises(ValueError, lambda: ParametricRegion((a, b), (a**2, sin(b)), (a, 2, 4, 6))) + + +def test_parametric_region_list(): + + point = Point(-5, 12) + assert parametric_region_list(point) == [ParametricRegion((-5, 12))] + + e = Ellipse(Point(2, 8), 2, 6) + assert parametric_region_list(e, t) == [ParametricRegion((2*cos(t) + 2, 6*sin(t) + 8), (t, 0, 2*pi))] + + c = Curve((t, t**3), (t, 5, 3)) + assert parametric_region_list(c) == [ParametricRegion((t, t**3), (t, 5, 3))] + + s = Segment(Point(2, 11, -6), Point(0, 2, 5)) + assert parametric_region_list(s, t) == [ParametricRegion((2 - 2*t, 11 - 9*t, 11*t - 6), (t, 0, 1))] + s1 = Segment(Point(0, 0), (1, 0)) + assert parametric_region_list(s1, t) == [ParametricRegion((t, 0), (t, 0, 1))] + s2 = Segment(Point(1, 2, 3), Point(1, 2, 5)) + assert parametric_region_list(s2, t) == [ParametricRegion((1, 2, 2*t + 3), (t, 0, 1))] + s3 = Segment(Point(12, 56), Point(12, 56)) + assert parametric_region_list(s3) == [ParametricRegion((12, 56))] + + poly = Polygon((1,3), (-3, 8), (2, 4)) + assert parametric_region_list(poly, t) == [ParametricRegion((1 - 4*t, 5*t + 3), (t, 0, 1)), ParametricRegion((5*t - 3, 8 - 4*t), (t, 0, 1)), ParametricRegion((2 - t, 4 - t), (t, 0, 1))] + + p1 = Parabola(Point(0, 0), Line(Point(5, 8), Point(7,8))) + raises(ValueError, lambda: parametric_region_list(p1)) diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_printing.py b/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_printing.py new file mode 100644 index 0000000000000000000000000000000000000000..ae76905e967bdf93485f135c6a69f968e1208986 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_printing.py @@ -0,0 +1,221 @@ +# -*- coding: utf-8 -*- +from sympy.core.function import Function +from sympy.integrals.integrals import Integral +from sympy.printing.latex import latex +from sympy.printing.pretty import pretty as xpretty +from sympy.vector import CoordSys3D, Del, Vector, express +from sympy.abc import a, b, c +from sympy.testing.pytest import XFAIL + + +def pretty(expr): + """ASCII pretty-printing""" + return xpretty(expr, use_unicode=False, wrap_line=False) + + +def upretty(expr): + """Unicode pretty-printing""" + return xpretty(expr, use_unicode=True, wrap_line=False) + + +# Initialize the basic and tedious vector/dyadic expressions +# needed for testing. +# Some of the pretty forms shown denote how the expressions just +# above them should look with pretty printing. +N = CoordSys3D('N') +C = N.orient_new_axis('C', a, N.k) # type: ignore +v = [] +d = [] +v.append(Vector.zero) +v.append(N.i) # type: ignore +v.append(-N.i) # type: ignore +v.append(N.i + N.j) # type: ignore +v.append(a*N.i) # type: ignore +v.append(a*N.i - b*N.j) # type: ignore +v.append((a**2 + N.x)*N.i + N.k) # type: ignore +v.append((a**2 + b)*N.i + 3*(C.y - c)*N.k) # type: ignore +f = Function('f') +v.append(N.j - (Integral(f(b)) - C.x**2)*N.k) # type: ignore +upretty_v_8 = """\ + ⎛ 2 ⌠ ⎞ \n\ +j_N + ⎜x_C - ⎮ f(b) db⎟ k_N\n\ + ⎝ ⌡ ⎠ \ +""" +pretty_v_8 = """\ +j_N + / / \\\n\ + | 2 | |\n\ + |x_C - | f(b) db|\n\ + | | |\n\ + \\ / / \ +""" + +v.append(N.i + C.k) # type: ignore +v.append(express(N.i, C)) # type: ignore +v.append((a**2 + b)*N.i + (Integral(f(b)))*N.k) # type: ignore +upretty_v_11 = """\ +⎛ 2 ⎞ ⎛⌠ ⎞ \n\ +⎝a + b⎠ i_N + ⎜⎮ f(b) db⎟ k_N\n\ + ⎝⌡ ⎠ \ +""" +pretty_v_11 = """\ +/ 2 \\ + / / \\\n\ +\\a + b/ i_N| | |\n\ + | | f(b) db|\n\ + | | |\n\ + \\/ / \ +""" + +for x in v: + d.append(x | N.k) # type: ignore +s = 3*N.x**2*C.y # type: ignore +upretty_s = """\ + 2\n\ +3⋅y_C⋅x_N \ +""" +pretty_s = """\ + 2\n\ +3*y_C*x_N \ +""" + +# This is the pretty form for ((a**2 + b)*N.i + 3*(C.y - c)*N.k) | N.k +upretty_d_7 = """\ +⎛ 2 ⎞ \n\ +⎝a + b⎠ (i_N|k_N) + (3⋅y_C - 3⋅c) (k_N|k_N)\ +""" +pretty_d_7 = """\ +/ 2 \\ (i_N|k_N) + (3*y_C - 3*c) (k_N|k_N)\n\ +\\a + b/ \ +""" + + +def test_str_printing(): + assert str(v[0]) == '0' + assert str(v[1]) == 'N.i' + assert str(v[2]) == '(-1)*N.i' + assert str(v[3]) == 'N.i + N.j' + assert str(v[8]) == 'N.j + (C.x**2 - Integral(f(b), b))*N.k' + assert str(v[9]) == 'C.k + N.i' + assert str(s) == '3*C.y*N.x**2' + assert str(d[0]) == '0' + assert str(d[1]) == '(N.i|N.k)' + assert str(d[4]) == 'a*(N.i|N.k)' + assert str(d[5]) == 'a*(N.i|N.k) + (-b)*(N.j|N.k)' + assert str(d[8]) == ('(N.j|N.k) + (C.x**2 - ' + + 'Integral(f(b), b))*(N.k|N.k)') + + +@XFAIL +def test_pretty_printing_ascii(): + assert pretty(v[0]) == '0' + assert pretty(v[1]) == 'i_N' + assert pretty(v[5]) == '(a) i_N + (-b) j_N' + assert pretty(v[8]) == pretty_v_8 + assert pretty(v[2]) == '(-1) i_N' + assert pretty(v[11]) == pretty_v_11 + assert pretty(s) == pretty_s + assert pretty(d[0]) == '(0|0)' + assert pretty(d[5]) == '(a) (i_N|k_N) + (-b) (j_N|k_N)' + assert pretty(d[7]) == pretty_d_7 + assert pretty(d[10]) == '(cos(a)) (i_C|k_N) + (-sin(a)) (j_C|k_N)' + + +def test_pretty_print_unicode_v(): + assert upretty(v[0]) == '0' + assert upretty(v[1]) == 'i_N' + assert upretty(v[5]) == '(a) i_N + (-b) j_N' + # Make sure the printing works in other objects + assert upretty(v[5].args) == '((a) i_N, (-b) j_N)' + assert upretty(v[8]) == upretty_v_8 + assert upretty(v[2]) == '(-1) i_N' + assert upretty(v[11]) == upretty_v_11 + assert upretty(s) == upretty_s + assert upretty(d[0]) == '(0|0)' + assert upretty(d[5]) == '(a) (i_N|k_N) + (-b) (j_N|k_N)' + assert upretty(d[7]) == upretty_d_7 + assert upretty(d[10]) == '(cos(a)) (i_C|k_N) + (-sin(a)) (j_C|k_N)' + + +def test_latex_printing(): + assert latex(v[0]) == '\\mathbf{\\hat{0}}' + assert latex(v[1]) == '\\mathbf{\\hat{i}_{N}}' + assert latex(v[2]) == '- \\mathbf{\\hat{i}_{N}}' + assert latex(v[5]) == ('\\left(a\\right)\\mathbf{\\hat{i}_{N}} + ' + + '\\left(- b\\right)\\mathbf{\\hat{j}_{N}}') + assert latex(v[6]) == ('\\left(\\mathbf{{x}_{N}} + a^{2}\\right)\\mathbf{\\hat{i}_' + + '{N}} + \\mathbf{\\hat{k}_{N}}') + assert latex(v[8]) == ('\\mathbf{\\hat{j}_{N}} + \\left(\\mathbf{{x}_' + + '{C}}^{2} - \\int f{\\left(b \\right)}\\,' + + ' db\\right)\\mathbf{\\hat{k}_{N}}') + assert latex(s) == '3 \\mathbf{{y}_{C}} \\mathbf{{x}_{N}}^{2}' + assert latex(d[0]) == '(\\mathbf{\\hat{0}}|\\mathbf{\\hat{0}})' + assert latex(d[4]) == ('\\left(a\\right)\\left(\\mathbf{\\hat{i}_{N}}{\\middle|}' + + '\\mathbf{\\hat{k}_{N}}\\right)') + assert latex(d[9]) == ('\\left(\\mathbf{\\hat{k}_{C}}{\\middle|}' + + '\\mathbf{\\hat{k}_{N}}\\right) + \\left(' + + '\\mathbf{\\hat{i}_{N}}{\\middle|}\\mathbf{' + + '\\hat{k}_{N}}\\right)') + assert latex(d[11]) == ('\\left(a^{2} + b\\right)\\left(\\mathbf{\\hat{i}_{N}}' + + '{\\middle|}\\mathbf{\\hat{k}_{N}}\\right) + ' + + '\\left(\\int f{\\left(b \\right)}\\, db\\right)\\left(' + + '\\mathbf{\\hat{k}_{N}}{\\middle|}\\mathbf{' + + '\\hat{k}_{N}}\\right)') + +def test_issue_23058(): + from sympy import symbols, sin, cos, pi, UnevaluatedExpr + + delop = Del() + CC_ = CoordSys3D("C") + y = CC_.y + xhat = CC_.i + + t = symbols("t") + ten = symbols("10", positive=True) + eps, mu = 4*pi*ten**(-11), ten**(-5) + + Bx = 2 * ten**(-4) * cos(ten**5 * t) * sin(ten**(-3) * y) + vecB = Bx * xhat + vecE = (1/eps) * Integral(delop.cross(vecB/mu).doit(), t) + vecE = vecE.doit() + + vecB_str = """\ +⎛ ⎛y_C⎞ ⎛ 5 ⎞⎞ \n\ +⎜2⋅sin⎜───⎟⋅cos⎝10 ⋅t⎠⎟ i_C\n\ +⎜ ⎜ 3⎟ ⎟ \n\ +⎜ ⎝10 ⎠ ⎟ \n\ +⎜─────────────────────⎟ \n\ +⎜ 4 ⎟ \n\ +⎝ 10 ⎠ \ +""" + vecE_str = """\ +⎛ 4 ⎛ 5 ⎞ ⎛y_C⎞ ⎞ \n\ +⎜-10 ⋅sin⎝10 ⋅t⎠⋅cos⎜───⎟ ⎟ k_C\n\ +⎜ ⎜ 3⎟ ⎟ \n\ +⎜ ⎝10 ⎠ ⎟ \n\ +⎜─────────────────────────⎟ \n\ +⎝ 2⋅π ⎠ \ +""" + + assert upretty(vecB) == vecB_str + assert upretty(vecE) == vecE_str + + ten = UnevaluatedExpr(10) + eps, mu = 4*pi*ten**(-11), ten**(-5) + + Bx = 2 * ten**(-4) * cos(ten**5 * t) * sin(ten**(-3) * y) + vecB = Bx * xhat + + vecB_str = """\ +⎛ -4 ⎛ 5⎞ ⎛ -3⎞⎞ \n\ +⎝2⋅10 ⋅cos⎝t⋅10 ⎠⋅sin⎝y_C⋅10 ⎠⎠ i_C \ +""" + assert upretty(vecB) == vecB_str + +def test_custom_names(): + A = CoordSys3D('A', vector_names=['x', 'y', 'z'], + variable_names=['i', 'j', 'k']) + assert A.i.__str__() == 'A.i' + assert A.x.__str__() == 'A.x' + assert A.i._pretty_form == 'i_A' + assert A.x._pretty_form == 'x_A' + assert A.i._latex_form == r'\mathbf{{i}_{A}}' + assert A.x._latex_form == r"\mathbf{\hat{x}_{A}}" diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_vector.py b/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_vector.py new file mode 100644 index 0000000000000000000000000000000000000000..daba6d6a02c87b41a8bf801eee9b9045897d0003 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/tests/test_vector.py @@ -0,0 +1,342 @@ +from sympy.core import Rational, S, Add, Mul, I +from sympy.simplify import simplify, trigsimp +from sympy.core.function import (Derivative, Function, diff) +from sympy.core.numbers import pi +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.integrals.integrals import Integral +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.vector.vector import Vector, BaseVector, VectorAdd, \ + VectorMul, VectorZero +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.vector import Cross, Dot, cross +from sympy.testing.pytest import raises +from sympy.vector.kind import VectorKind +from sympy.core.kind import NumberKind +from sympy.testing.pytest import XFAIL + + +C = CoordSys3D('C') + +i, j, k = C.base_vectors() +a, b, c = symbols('a b c') + + +def test_cross(): + v1 = C.x * i + C.z * C.z * j + v2 = C.x * i + C.y * j + C.z * k + assert Cross(v1, v2) == Cross(C.x*C.i + C.z**2*C.j, C.x*C.i + C.y*C.j + C.z*C.k) + assert Cross(v1, v2).doit() == C.z**3*C.i + (-C.x*C.z)*C.j + (C.x*C.y - C.x*C.z**2)*C.k + assert cross(v1, v2) == C.z**3*C.i + (-C.x*C.z)*C.j + (C.x*C.y - C.x*C.z**2)*C.k + assert Cross(v1, v2) == -Cross(v2, v1) + # XXX: Cannot use Cross here. See XFAIL test below: + assert cross(v1, v2) + cross(v2, v1) == Vector.zero + + +@XFAIL +def test_cross_xfail(): + v1 = C.x * i + C.z * C.z * j + v2 = C.x * i + C.y * j + C.z * k + assert Cross(v1, v2) + Cross(v2, v1) == Vector.zero + + +def test_dot(): + v1 = C.x * i + C.z * C.z * j + v2 = C.x * i + C.y * j + C.z * k + assert Dot(v1, v2) == Dot(C.x*C.i + C.z**2*C.j, C.x*C.i + C.y*C.j + C.z*C.k) + assert Dot(v1, v2).doit() == C.x**2 + C.y*C.z**2 + assert Dot(v2, v1).doit() == C.x**2 + C.y*C.z**2 + assert Dot(v1, v2) == Dot(v2, v1) + + +def test_vector_sympy(): + """ + Test whether the Vector framework confirms to the hashing + and equality testing properties of SymPy. + """ + v1 = 3*j + assert v1 == j*3 + assert v1.components == {j: 3} + v2 = 3*i + 4*j + 5*k + v3 = 2*i + 4*j + i + 4*k + k + assert v3 == v2 + assert v3.__hash__() == v2.__hash__() + + +def test_kind(): + assert C.i.kind is VectorKind(NumberKind) + assert C.j.kind is VectorKind(NumberKind) + assert C.k.kind is VectorKind(NumberKind) + + assert C.x.kind is NumberKind + assert C.y.kind is NumberKind + assert C.z.kind is NumberKind + + assert Mul._kind_dispatcher(NumberKind, VectorKind(NumberKind)) is VectorKind(NumberKind) + assert Mul(2, C.i).kind is VectorKind(NumberKind) + + v1 = C.x * i + C.z * C.z * j + v2 = C.x * i + C.y * j + C.z * k + assert v1.kind is VectorKind(NumberKind) + assert v2.kind is VectorKind(NumberKind) + + assert (v1 + v2).kind is VectorKind(NumberKind) + assert Add(v1, v2).kind is VectorKind(NumberKind) + assert Cross(v1, v2).doit().kind is VectorKind(NumberKind) + assert VectorAdd(v1, v2).kind is VectorKind(NumberKind) + assert VectorMul(2, v1).kind is VectorKind(NumberKind) + assert VectorZero().kind is VectorKind(NumberKind) + + assert v1.projection(v2).kind is VectorKind(NumberKind) + assert v2.projection(v1).kind is VectorKind(NumberKind) + + +def test_vectoradd(): + assert isinstance(Add(C.i, C.j), VectorAdd) + v1 = C.x * i + C.z * C.z * j + v2 = C.x * i + C.y * j + C.z * k + assert isinstance(Add(v1, v2), VectorAdd) + + # https://github.com/sympy/sympy/issues/26121 + + E = Matrix([C.i, C.j, C.k]).T + a = Matrix([1, 2, 3]) + av = E*a + + assert av[0].kind == VectorKind() + assert isinstance(av[0], VectorAdd) + + +def test_vector(): + assert isinstance(i, BaseVector) + assert i != j + assert j != k + assert k != i + assert i - i == Vector.zero + assert i + Vector.zero == i + assert i - Vector.zero == i + assert Vector.zero != 0 + assert -Vector.zero == Vector.zero + + v1 = a*i + b*j + c*k + v2 = a**2*i + b**2*j + c**2*k + v3 = v1 + v2 + v4 = 2 * v1 + v5 = a * i + + assert isinstance(v1, VectorAdd) + assert v1 - v1 == Vector.zero + assert v1 + Vector.zero == v1 + assert v1.dot(i) == a + assert v1.dot(j) == b + assert v1.dot(k) == c + assert i.dot(v2) == a**2 + assert j.dot(v2) == b**2 + assert k.dot(v2) == c**2 + assert v3.dot(i) == a**2 + a + assert v3.dot(j) == b**2 + b + assert v3.dot(k) == c**2 + c + + assert v1 + v2 == v2 + v1 + assert v1 - v2 == -1 * (v2 - v1) + assert a * v1 == v1 * a + + assert isinstance(v5, VectorMul) + assert v5.base_vector == i + assert v5.measure_number == a + assert isinstance(v4, Vector) + assert isinstance(v4, VectorAdd) + assert isinstance(v4, Vector) + assert isinstance(Vector.zero, VectorZero) + assert isinstance(Vector.zero, Vector) + assert isinstance(v1 * 0, VectorZero) + + assert v1.to_matrix(C) == Matrix([[a], [b], [c]]) + + assert i.components == {i: 1} + assert v5.components == {i: a} + assert v1.components == {i: a, j: b, k: c} + + assert VectorAdd(v1, Vector.zero) == v1 + assert VectorMul(a, v1) == v1*a + assert VectorMul(1, i) == i + assert VectorAdd(v1, Vector.zero) == v1 + assert VectorMul(0, Vector.zero) == Vector.zero + raises(TypeError, lambda: v1.outer(1)) + raises(TypeError, lambda: v1.dot(1)) + + +def test_vector_magnitude_normalize(): + assert Vector.zero.magnitude() == 0 + assert Vector.zero.normalize() == Vector.zero + + assert i.magnitude() == 1 + assert j.magnitude() == 1 + assert k.magnitude() == 1 + assert i.normalize() == i + assert j.normalize() == j + assert k.normalize() == k + + v1 = a * i + assert v1.normalize() == (a/sqrt(a**2))*i + assert v1.magnitude() == sqrt(a**2) + + v2 = a*i + b*j + c*k + assert v2.magnitude() == sqrt(a**2 + b**2 + c**2) + assert v2.normalize() == v2 / v2.magnitude() + + v3 = i + j + assert v3.normalize() == (sqrt(2)/2)*C.i + (sqrt(2)/2)*C.j + + +def test_vector_simplify(): + A, s, k, m = symbols('A, s, k, m') + + test1 = (1 / a + 1 / b) * i + assert (test1 & i) != (a + b) / (a * b) + test1 = simplify(test1) + assert (test1 & i) == (a + b) / (a * b) + assert test1.simplify() == simplify(test1) + + test2 = (A**2 * s**4 / (4 * pi * k * m**3)) * i + test2 = simplify(test2) + assert (test2 & i) == (A**2 * s**4 / (4 * pi * k * m**3)) + + test3 = ((4 + 4 * a - 2 * (2 + 2 * a)) / (2 + 2 * a)) * i + test3 = simplify(test3) + assert (test3 & i) == 0 + + test4 = ((-4 * a * b**2 - 2 * b**3 - 2 * a**2 * b) / (a + b)**2) * i + test4 = simplify(test4) + assert (test4 & i) == -2 * b + + v = (sin(a)+cos(a))**2*i - j + assert trigsimp(v) == (2*sin(a + pi/4)**2)*i + (-1)*j + assert trigsimp(v) == v.trigsimp() + + assert simplify(Vector.zero) == Vector.zero + + +def test_vector_equals(): + assert (2*i).equals(j) is False + assert i.equals(i) is True + + # https://github.com/sympy/sympy/issues/25915 + A = (sqrt(2) + sqrt(6)) / sqrt(sqrt(3) + 2) + assert (A*i).equals(2*i) is True + assert (A*i).equals(3*i) is False + + # Test comparing vectors in different coordinate systems + D = C.orient_new_axis('D', pi/2, C.k) + assert (D.i).equals(C.j) is True + assert (D.i).equals(C.i) is False + + +def test_vector_conjugate(): + # https://github.com/sympy/sympy/issues/27094 + assert (I*i + (1 + I)*j + 2*k).conjugate() == -I*i + (1 - I)*j + 2*k + + +def test_vector_dot(): + assert i.dot(Vector.zero) == 0 + assert Vector.zero.dot(i) == 0 + assert i & Vector.zero == 0 + + assert i.dot(i) == 1 + assert i.dot(j) == 0 + assert i.dot(k) == 0 + assert i & i == 1 + assert i & j == 0 + assert i & k == 0 + + assert j.dot(i) == 0 + assert j.dot(j) == 1 + assert j.dot(k) == 0 + assert j & i == 0 + assert j & j == 1 + assert j & k == 0 + + assert k.dot(i) == 0 + assert k.dot(j) == 0 + assert k.dot(k) == 1 + assert k & i == 0 + assert k & j == 0 + assert k & k == 1 + + raises(TypeError, lambda: k.dot(1)) + + +def test_vector_cross(): + assert i.cross(Vector.zero) == Vector.zero + assert Vector.zero.cross(i) == Vector.zero + + assert i.cross(i) == Vector.zero + assert i.cross(j) == k + assert i.cross(k) == -j + assert i ^ i == Vector.zero + assert i ^ j == k + assert i ^ k == -j + + assert j.cross(i) == -k + assert j.cross(j) == Vector.zero + assert j.cross(k) == i + assert j ^ i == -k + assert j ^ j == Vector.zero + assert j ^ k == i + + assert k.cross(i) == j + assert k.cross(j) == -i + assert k.cross(k) == Vector.zero + assert k ^ i == j + assert k ^ j == -i + assert k ^ k == Vector.zero + + assert k.cross(1) == Cross(k, 1) + + +def test_projection(): + v1 = i + j + k + v2 = 3*i + 4*j + v3 = 0*i + 0*j + assert v1.projection(v1) == i + j + k + assert v1.projection(v2) == Rational(7, 3)*C.i + Rational(7, 3)*C.j + Rational(7, 3)*C.k + assert v1.projection(v1, scalar=True) == S.One + assert v1.projection(v2, scalar=True) == Rational(7, 3) + assert v3.projection(v1) == Vector.zero + assert v3.projection(v1, scalar=True) == S.Zero + + +def test_vector_diff_integrate(): + f = Function('f') + v = f(a)*C.i + a**2*C.j - C.k + assert Derivative(v, a) == Derivative((f(a))*C.i + + a**2*C.j + (-1)*C.k, a) + assert (diff(v, a) == v.diff(a) == Derivative(v, a).doit() == + (Derivative(f(a), a))*C.i + 2*a*C.j) + assert (Integral(v, a) == (Integral(f(a), a))*C.i + + (Integral(a**2, a))*C.j + (Integral(-1, a))*C.k) + + +def test_vector_args(): + raises(ValueError, lambda: BaseVector(3, C)) + raises(TypeError, lambda: BaseVector(0, Vector.zero)) + + +def test_srepr(): + from sympy.printing.repr import srepr + res = "CoordSys3D(Str('C'), Tuple(ImmutableDenseMatrix([[Integer(1), "\ + "Integer(0), Integer(0)], [Integer(0), Integer(1), Integer(0)], "\ + "[Integer(0), Integer(0), Integer(1)]]), VectorZero())).i" + assert srepr(C.i) == res + + +def test_scalar(): + from sympy.vector import CoordSys3D + C = CoordSys3D('C') + v1 = 3*C.i + 4*C.j + 5*C.k + v2 = 3*C.i - 4*C.j + 5*C.k + assert v1.is_Vector is True + assert v1.is_scalar is False + assert (v1.dot(v2)).is_scalar is True + assert (v1.cross(v2)).is_scalar is False diff --git a/.venv/lib/python3.13/site-packages/sympy/vector/vector.py b/.venv/lib/python3.13/site-packages/sympy/vector/vector.py new file mode 100644 index 0000000000000000000000000000000000000000..c035ef48d2edd511f6cdbca19e965d99a2c8c66e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/vector/vector.py @@ -0,0 +1,714 @@ +from __future__ import annotations +from itertools import product + +from sympy.core import Add, Basic +from sympy.core.assumptions import StdFactKB +from sympy.core.expr import AtomicExpr, Expr +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.sorting import default_sort_key +from sympy.core.sympify import sympify +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.vector.basisdependent import (BasisDependentZero, + BasisDependent, BasisDependentMul, BasisDependentAdd) +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.dyadic import Dyadic, BaseDyadic, DyadicAdd +from sympy.vector.kind import VectorKind + + +class Vector(BasisDependent): + """ + Super class for all Vector classes. + Ideally, neither this class nor any of its subclasses should be + instantiated by the user. + """ + + is_scalar = False + is_Vector = True + _op_priority = 12.0 + + _expr_type: type[Vector] + _mul_func: type[Vector] + _add_func: type[Vector] + _zero_func: type[Vector] + _base_func: type[Vector] + zero: VectorZero + + kind: VectorKind = VectorKind() + + @property + def components(self): + """ + Returns the components of this vector in the form of a + Python dictionary mapping BaseVector instances to the + corresponding measure numbers. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> C = CoordSys3D('C') + >>> v = 3*C.i + 4*C.j + 5*C.k + >>> v.components + {C.i: 3, C.j: 4, C.k: 5} + + """ + # The '_components' attribute is defined according to the + # subclass of Vector the instance belongs to. + return self._components + + def magnitude(self): + """ + Returns the magnitude of this vector. + """ + return sqrt(self & self) + + def normalize(self): + """ + Returns the normalized version of this vector. + """ + return self / self.magnitude() + + def equals(self, other): + """ + Check if ``self`` and ``other`` are identically equal vectors. + + Explanation + =========== + + Checks if two vector expressions are equal for all possible values of + the symbols present in the expressions. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy.abc import x, y + >>> from sympy import pi + >>> C = CoordSys3D('C') + + Compare vectors that are equal or not: + + >>> C.i.equals(C.j) + False + >>> C.i.equals(C.i) + True + + These two vectors are equal if `x = y` but are not identically equal + as expressions since for some values of `x` and `y` they are unequal: + + >>> v1 = x*C.i + C.j + >>> v2 = y*C.i + C.j + >>> v1.equals(v1) + True + >>> v1.equals(v2) + False + + Vectors from different coordinate systems can be compared: + + >>> D = C.orient_new_axis('D', pi/2, C.i) + >>> D.j.equals(C.j) + False + >>> D.j.equals(C.k) + True + + Parameters + ========== + + other: Vector + The other vector expression to compare with. + + Returns + ======= + + ``True``, ``False`` or ``None``. A return value of ``True`` indicates + that the two vectors are identically equal. A return value of ``False`` + indicates that they are not. In some cases it is not possible to + determine if the two vectors are identically equal and ``None`` is + returned. + + See Also + ======== + + sympy.core.expr.Expr.equals + """ + diff = self - other + diff_mag2 = diff.dot(diff) + return diff_mag2.equals(0) + + def dot(self, other): + """ + Returns the dot product of this Vector, either with another + Vector, or a Dyadic, or a Del operator. + If 'other' is a Vector, returns the dot product scalar (SymPy + expression). + If 'other' is a Dyadic, the dot product is returned as a Vector. + If 'other' is an instance of Del, returns the directional + derivative operator as a Python function. If this function is + applied to a scalar expression, it returns the directional + derivative of the scalar field wrt this Vector. + + Parameters + ========== + + other: Vector/Dyadic/Del + The Vector or Dyadic we are dotting with, or a Del operator . + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, Del + >>> C = CoordSys3D('C') + >>> delop = Del() + >>> C.i.dot(C.j) + 0 + >>> C.i & C.i + 1 + >>> v = 3*C.i + 4*C.j + 5*C.k + >>> v.dot(C.k) + 5 + >>> (C.i & delop)(C.x*C.y*C.z) + C.y*C.z + >>> d = C.i.outer(C.i) + >>> C.i.dot(d) + C.i + + """ + + # Check special cases + if isinstance(other, Dyadic): + if isinstance(self, VectorZero): + return Vector.zero + outvec = Vector.zero + for k, v in other.components.items(): + vect_dot = k.args[0].dot(self) + outvec += vect_dot * v * k.args[1] + return outvec + from sympy.vector.deloperator import Del + if not isinstance(other, (Del, Vector)): + raise TypeError(str(other) + " is not a vector, dyadic or " + + "del operator") + + # Check if the other is a del operator + if isinstance(other, Del): + def directional_derivative(field): + from sympy.vector.functions import directional_derivative + return directional_derivative(field, self) + return directional_derivative + + return dot(self, other) + + def __and__(self, other): + return self.dot(other) + + __and__.__doc__ = dot.__doc__ + + def cross(self, other): + """ + Returns the cross product of this Vector with another Vector or + Dyadic instance. + The cross product is a Vector, if 'other' is a Vector. If 'other' + is a Dyadic, this returns a Dyadic instance. + + Parameters + ========== + + other: Vector/Dyadic + The Vector or Dyadic we are crossing with. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> C = CoordSys3D('C') + >>> C.i.cross(C.j) + C.k + >>> C.i ^ C.i + 0 + >>> v = 3*C.i + 4*C.j + 5*C.k + >>> v ^ C.i + 5*C.j + (-4)*C.k + >>> d = C.i.outer(C.i) + >>> C.j.cross(d) + (-1)*(C.k|C.i) + + """ + + # Check special cases + if isinstance(other, Dyadic): + if isinstance(self, VectorZero): + return Dyadic.zero + outdyad = Dyadic.zero + for k, v in other.components.items(): + cross_product = self.cross(k.args[0]) + outer = cross_product.outer(k.args[1]) + outdyad += v * outer + return outdyad + + return cross(self, other) + + def __xor__(self, other): + return self.cross(other) + + __xor__.__doc__ = cross.__doc__ + + def outer(self, other): + """ + Returns the outer product of this vector with another, in the + form of a Dyadic instance. + + Parameters + ========== + + other : Vector + The Vector with respect to which the outer product is to + be computed. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> N = CoordSys3D('N') + >>> N.i.outer(N.j) + (N.i|N.j) + + """ + + # Handle the special cases + if not isinstance(other, Vector): + raise TypeError("Invalid operand for outer product") + elif (isinstance(self, VectorZero) or + isinstance(other, VectorZero)): + return Dyadic.zero + + # Iterate over components of both the vectors to generate + # the required Dyadic instance + args = [(v1 * v2) * BaseDyadic(k1, k2) for (k1, v1), (k2, v2) + in product(self.components.items(), other.components.items())] + + return DyadicAdd(*args) + + def projection(self, other, scalar=False): + """ + Returns the vector or scalar projection of the 'other' on 'self'. + + Examples + ======== + + >>> from sympy.vector.coordsysrect import CoordSys3D + >>> C = CoordSys3D('C') + >>> i, j, k = C.base_vectors() + >>> v1 = i + j + k + >>> v2 = 3*i + 4*j + >>> v1.projection(v2) + 7/3*C.i + 7/3*C.j + 7/3*C.k + >>> v1.projection(v2, scalar=True) + 7/3 + + """ + if self.equals(Vector.zero): + return S.Zero if scalar else Vector.zero + + if scalar: + return self.dot(other) / self.dot(self) + else: + return self.dot(other) / self.dot(self) * self + + @property + def _projections(self): + """ + Returns the components of this vector but the output includes + also zero values components. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, Vector + >>> C = CoordSys3D('C') + >>> v1 = 3*C.i + 4*C.j + 5*C.k + >>> v1._projections + (3, 4, 5) + >>> v2 = C.x*C.y*C.z*C.i + >>> v2._projections + (C.x*C.y*C.z, 0, 0) + >>> v3 = Vector.zero + >>> v3._projections + (0, 0, 0) + """ + + from sympy.vector.operators import _get_coord_systems + if isinstance(self, VectorZero): + return (S.Zero, S.Zero, S.Zero) + base_vec = next(iter(_get_coord_systems(self))).base_vectors() + return tuple([self.dot(i) for i in base_vec]) + + def __or__(self, other): + return self.outer(other) + + __or__.__doc__ = outer.__doc__ + + def to_matrix(self, system): + """ + Returns the matrix form of this vector with respect to the + specified coordinate system. + + Parameters + ========== + + system : CoordSys3D + The system wrt which the matrix form is to be computed + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> C = CoordSys3D('C') + >>> from sympy.abc import a, b, c + >>> v = a*C.i + b*C.j + c*C.k + >>> v.to_matrix(C) + Matrix([ + [a], + [b], + [c]]) + + """ + + return Matrix([self.dot(unit_vec) for unit_vec in + system.base_vectors()]) + + def separate(self): + """ + The constituents of this vector in different coordinate systems, + as per its definition. + + Returns a dict mapping each CoordSys3D to the corresponding + constituent Vector. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> R1 = CoordSys3D('R1') + >>> R2 = CoordSys3D('R2') + >>> v = R1.i + R2.i + >>> v.separate() == {R1: R1.i, R2: R2.i} + True + + """ + + parts = {} + for vect, measure in self.components.items(): + parts[vect.system] = (parts.get(vect.system, Vector.zero) + + vect * measure) + return parts + + def _div_helper(one, other): + """ Helper for division involving vectors. """ + if isinstance(one, Vector) and isinstance(other, Vector): + raise TypeError("Cannot divide two vectors") + elif isinstance(one, Vector): + if other == S.Zero: + raise ValueError("Cannot divide a vector by zero") + return VectorMul(one, Pow(other, S.NegativeOne)) + else: + raise TypeError("Invalid division involving a vector") + +# The following is adapted from the matrices.expressions.matexpr file + +def get_postprocessor(cls): + def _postprocessor(expr): + vec_class = {Add: VectorAdd}[cls] + vectors = [] + for term in expr.args: + if isinstance(term.kind, VectorKind): + vectors.append(term) + + if vec_class == VectorAdd: + return VectorAdd(*vectors).doit(deep=False) + return _postprocessor + + +Basic._constructor_postprocessor_mapping[Vector] = { + "Add": [get_postprocessor(Add)], +} + +class BaseVector(Vector, AtomicExpr): + """ + Class to denote a base vector. + + """ + + def __new__(cls, index, system, pretty_str=None, latex_str=None): + if pretty_str is None: + pretty_str = "x{}".format(index) + if latex_str is None: + latex_str = "x_{}".format(index) + pretty_str = str(pretty_str) + latex_str = str(latex_str) + # Verify arguments + if index not in range(0, 3): + raise ValueError("index must be 0, 1 or 2") + if not isinstance(system, CoordSys3D): + raise TypeError("system should be a CoordSys3D") + name = system._vector_names[index] + # Initialize an object + obj = super().__new__(cls, S(index), system) + # Assign important attributes + obj._base_instance = obj + obj._components = {obj: S.One} + obj._measure_number = S.One + obj._name = system._name + '.' + name + obj._pretty_form = '' + pretty_str + obj._latex_form = latex_str + obj._system = system + # The _id is used for printing purposes + obj._id = (index, system) + assumptions = {'commutative': True} + obj._assumptions = StdFactKB(assumptions) + + # This attr is used for re-expression to one of the systems + # involved in the definition of the Vector. Applies to + # VectorMul and VectorAdd too. + obj._sys = system + + return obj + + @property + def system(self): + return self._system + + def _sympystr(self, printer): + return self._name + + def _sympyrepr(self, printer): + index, system = self._id + return printer._print(system) + '.' + system._vector_names[index] + + @property + def free_symbols(self): + return {self} + + def _eval_conjugate(self): + return self + + +class VectorAdd(BasisDependentAdd, Vector): + """ + Class to denote sum of Vector instances. + """ + + def __new__(cls, *args, **options): + obj = BasisDependentAdd.__new__(cls, *args, **options) + return obj + + def _sympystr(self, printer): + ret_str = '' + items = list(self.separate().items()) + items.sort(key=lambda x: x[0].__str__()) + for system, vect in items: + base_vects = system.base_vectors() + for x in base_vects: + if x in vect.components: + temp_vect = self.components[x] * x + ret_str += printer._print(temp_vect) + " + " + return ret_str[:-3] + + +class VectorMul(BasisDependentMul, Vector): + """ + Class to denote products of scalars and BaseVectors. + """ + + def __new__(cls, *args, **options): + obj = BasisDependentMul.__new__(cls, *args, **options) + return obj + + @property + def base_vector(self): + """ The BaseVector involved in the product. """ + return self._base_instance + + @property + def measure_number(self): + """ The scalar expression involved in the definition of + this VectorMul. + """ + return self._measure_number + + +class VectorZero(BasisDependentZero, Vector): + """ + Class to denote a zero vector + """ + + _op_priority = 12.1 + _pretty_form = '0' + _latex_form = r'\mathbf{\hat{0}}' + + def __new__(cls): + obj = BasisDependentZero.__new__(cls) + return obj + + +class Cross(Vector): + """ + Represents unevaluated Cross product. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, Cross + >>> R = CoordSys3D('R') + >>> v1 = R.i + R.j + R.k + >>> v2 = R.x * R.i + R.y * R.j + R.z * R.k + >>> Cross(v1, v2) + Cross(R.i + R.j + R.k, R.x*R.i + R.y*R.j + R.z*R.k) + >>> Cross(v1, v2).doit() + (-R.y + R.z)*R.i + (R.x - R.z)*R.j + (-R.x + R.y)*R.k + + """ + + def __new__(cls, expr1, expr2): + expr1 = sympify(expr1) + expr2 = sympify(expr2) + if default_sort_key(expr1) > default_sort_key(expr2): + return -Cross(expr2, expr1) + obj = Expr.__new__(cls, expr1, expr2) + obj._expr1 = expr1 + obj._expr2 = expr2 + return obj + + def doit(self, **hints): + return cross(self._expr1, self._expr2) + + +class Dot(Expr): + """ + Represents unevaluated Dot product. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D, Dot + >>> from sympy import symbols + >>> R = CoordSys3D('R') + >>> a, b, c = symbols('a b c') + >>> v1 = R.i + R.j + R.k + >>> v2 = a * R.i + b * R.j + c * R.k + >>> Dot(v1, v2) + Dot(R.i + R.j + R.k, a*R.i + b*R.j + c*R.k) + >>> Dot(v1, v2).doit() + a + b + c + + """ + + def __new__(cls, expr1, expr2): + expr1 = sympify(expr1) + expr2 = sympify(expr2) + expr1, expr2 = sorted([expr1, expr2], key=default_sort_key) + obj = Expr.__new__(cls, expr1, expr2) + obj._expr1 = expr1 + obj._expr2 = expr2 + return obj + + def doit(self, **hints): + return dot(self._expr1, self._expr2) + + +def cross(vect1, vect2): + """ + Returns cross product of two vectors. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy.vector.vector import cross + >>> R = CoordSys3D('R') + >>> v1 = R.i + R.j + R.k + >>> v2 = R.x * R.i + R.y * R.j + R.z * R.k + >>> cross(v1, v2) + (-R.y + R.z)*R.i + (R.x - R.z)*R.j + (-R.x + R.y)*R.k + + """ + if isinstance(vect1, Add): + return VectorAdd.fromiter(cross(i, vect2) for i in vect1.args) + if isinstance(vect2, Add): + return VectorAdd.fromiter(cross(vect1, i) for i in vect2.args) + if isinstance(vect1, BaseVector) and isinstance(vect2, BaseVector): + if vect1._sys == vect2._sys: + n1 = vect1.args[0] + n2 = vect2.args[0] + if n1 == n2: + return Vector.zero + n3 = ({0,1,2}.difference({n1, n2})).pop() + sign = 1 if ((n1 + 1) % 3 == n2) else -1 + return sign*vect1._sys.base_vectors()[n3] + from .functions import express + try: + v = express(vect1, vect2._sys) + except ValueError: + return Cross(vect1, vect2) + else: + return cross(v, vect2) + if isinstance(vect1, VectorZero) or isinstance(vect2, VectorZero): + return Vector.zero + if isinstance(vect1, VectorMul): + v1, m1 = next(iter(vect1.components.items())) + return m1*cross(v1, vect2) + if isinstance(vect2, VectorMul): + v2, m2 = next(iter(vect2.components.items())) + return m2*cross(vect1, v2) + + return Cross(vect1, vect2) + + +def dot(vect1, vect2): + """ + Returns dot product of two vectors. + + Examples + ======== + + >>> from sympy.vector import CoordSys3D + >>> from sympy.vector.vector import dot + >>> R = CoordSys3D('R') + >>> v1 = R.i + R.j + R.k + >>> v2 = R.x * R.i + R.y * R.j + R.z * R.k + >>> dot(v1, v2) + R.x + R.y + R.z + + """ + if isinstance(vect1, Add): + return Add.fromiter(dot(i, vect2) for i in vect1.args) + if isinstance(vect2, Add): + return Add.fromiter(dot(vect1, i) for i in vect2.args) + if isinstance(vect1, BaseVector) and isinstance(vect2, BaseVector): + if vect1._sys == vect2._sys: + return S.One if vect1 == vect2 else S.Zero + from .functions import express + try: + v = express(vect2, vect1._sys) + except ValueError: + return Dot(vect1, vect2) + else: + return dot(vect1, v) + if isinstance(vect1, VectorZero) or isinstance(vect2, VectorZero): + return S.Zero + if isinstance(vect1, VectorMul): + v1, m1 = next(iter(vect1.components.items())) + return m1*dot(v1, vect2) + if isinstance(vect2, VectorMul): + v2, m2 = next(iter(vect2.components.items())) + return m2*dot(vect1, v2) + + return Dot(vect1, vect2) + + +Vector._expr_type = Vector +Vector._mul_func = VectorMul +Vector._add_func = VectorAdd +Vector._zero_func = VectorZero +Vector._base_func = BaseVector +Vector.zero = VectorZero()